Plotly Fundamentals - Subplots

In this chapter we will learn how to produce subplots, that is, several different charts displayed in one figure. To start off, we import a few packages. Besides pandas, Plotly Express and Plotly Graph Objects, we will use the make_subplots function, which sets up the grid (the canvas) on which the charts are displayed.

import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

For the examples that follow, we will work with the stock market data set provided by Plotly Express. Once again, we can use the head() method to look at the first few rows of the dataframe.

stocks = px.data.stocks()
stocks.head()
date GOOG AAPL AMZN FB NFLX MSFT
0 2018-01-01 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000
1 2018-01-08 1.018172 1.011943 1.061881 0.959968 1.053526 1.015988
2 2018-01-15 1.032008 1.019771 1.053240 0.970243 1.049860 1.020524
3 2018-01-22 1.066783 0.980057 1.140676 1.016858 1.307681 1.066561
4 2018-01-29 1.008773 0.917143 1.163374 1.018357 1.273537 1.040708

For our first example, we will use the make_subplots function to plot two charts side by side in one figure. To do that, we create a figure object by passing the grid layout, defined by a number of rows and columns (1 by 2), to the function. We then use the add_trace method to add two line charts and place each one on the grid by passing its row and col position to add_trace. Here, the GOOG chart goes into the first row and first column, and the AAPL chart into the first row and second column of our 1 by 2 grid.

fig = make_subplots(rows=1, cols=2)

fig.add_trace(
    go.Scatter(x=stocks['date'], y=stocks['GOOG'], name='GOOG'),
    row=1, col=1
)

fig.add_trace(
    go.Scatter(x=stocks['date'], y=stocks['AAPL'], name='AAPL'),
    row=1, col=2
)

fig.update_layout(height=400, width=1200)
fig.show()

Awesome, but what about four or more charts? As the number of charts grows, it's pretty cumbersome to add all the traces manually. Instead, we can start by defining a Python dictionary that maps each data column we want to display to its chart position (row, column) in our grid. We then create our figure object, loop through the dictionary and add one trace for each position in the grid.

# Positions are (row, col); list them row by row, because make_subplots
# assigns subplot_titles to the grid cells in that order
pos = {'GOOG': (1, 1), 'AMZN': (1, 2), 'AAPL': (2, 1), 'FB': (2, 2)}

fig = make_subplots(rows=2, cols=2, subplot_titles=list(pos.keys()))

for stock, (row, col) in pos.items():
    fig.add_trace(go.Scatter(x=stocks['date'], y=stocks[stock], name=stock), row=row, col=col)

fig.update_layout(height=800, width=1200)
fig.show()

Maybe you want charts that do not take up equal space on the grid. To achieve this, you can pass a list to the column_widths parameter that contains one value per column. The values are relative widths that make_subplots normalizes, so [0.75, 0.25] gives the first column three quarters of the available width and the second column the remaining quarter. For rows, the corresponding parameter is row_heights.

fig = make_subplots(rows=1, cols=2, column_widths=[0.75, 0.25])

fig.add_trace(
    go.Scatter(x=stocks['date'], y=stocks['GOOG'], name = 'GOOG'),
    row=1, col=1
)

fig.add_trace(
    go.Scatter(x=stocks['date'], y=stocks['AAPL'], name = 'AAPL'), 
    row=1, col=2
)

fig.update_layout(height=400, width=1200)
fig.show()

A very powerful feature is cross filtering, which gives your subplots a look and feel similar to a simple dashboard. By calling update_xaxes(matches='x') and update_yaxes(matches='y'), you link the axes of every subplot to those of the first one, so zooming or panning in any individual chart applies the same axis ranges to all the other charts on the grid.

# Listed row by row so that subplot_titles match the charts
pos = {'GOOG': (1, 1), 'AMZN': (1, 2), 'AAPL': (2, 1), 'FB': (2, 2)}

fig = make_subplots(rows=2, cols=2, subplot_titles=list(pos.keys()))

for stock, (row, col) in pos.items():
    fig.add_trace(go.Scatter(x=stocks['date'], y=stocks[stock], name=stock), row=row, col=col)


fig.update_xaxes(matches='x')
fig.update_yaxes(matches='y')

fig.update_layout(height=800, width=1200)
fig.show()

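In some cases you might want to link subplots only along a certain axis. For this, make_subplots offers the shared_xaxes and shared_yaxes parameters: shared_xaxes=True links the x axes of the charts within each column, and shared_yaxes=True links the y axes of the charts within each row. The example also uses update_xaxes with row and col to target a single subplot, here to set an axis title.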

# Listed row by row so that subplot_titles match the charts
pos = {'GOOG': (1, 1), 'AMZN': (1, 2), 'AAPL': (2, 1), 'FB': (2, 2)}

fig = make_subplots(rows=2, cols=2, shared_yaxes=True, shared_xaxes=True, subplot_titles=list(pos.keys()))

for stock, (row, col) in pos.items():
    fig.add_trace(go.Scatter(x=stocks['date'], y=stocks[stock], name=stock), row=row, col=col)
    # Set the x axis title of this subplot only
    fig.update_xaxes(row=row, col=col, title='subplot axis title')


fig.update_layout(height=800, width=1200)
fig.show()

In a final example we will customize our grid further. Besides the usual dictionary of data column labels and chart positions, we create a specs list of lists with the same dimensions as our grid, holding one dictionary per cell (an empty dictionary means a regular chart). The "colspan" key lets the chart in the second row extend across the space normally occupied by the chart at position (2,2). In place of that chart we provide None to indicate that no separate chart is drawn in that cell.

# Listed row by row so that subplot_titles match: (1,1), (1,2), then (2,1)
pos = {'GOOG': (1, 1), 'AMZN': (1, 2), 'AAPL': (2, 1)}
spec = [[{}, {}], [{"colspan": 2}, None]]

fig = make_subplots(rows=2, cols=2, specs=spec, subplot_titles=list(pos.keys()))

for stock, (row, col) in pos.items():
    fig.add_trace(go.Scatter(x=stocks['date'], y=stocks[stock], name=stock), row=row, col=col)


fig.update_layout(height=800, width=1200)
fig.show()

We have already covered a lot of ground in this tutorial series. However, there is much more to come in the following chapter on 3D plots. Have fun!

fistofgeek.com

coding - data science - finance
