from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.facecolor"] = (1, 1, 1, 0) # RGBA tuple with alpha=0
plt.rcParams["axes.facecolor"] = (1, 1, 1, 0) # RGBA tuple with alpha=0
PROJECT_ROOT = Path().cwd().parent.parent
Matplotlib
# Load some sample data
medals = pd.read_csv(f"{PROJECT_ROOT}/data/medals_by_country_2016.csv", index_col=[0])
climate_change = pd.read_csv(f"{PROJECT_ROOT}/data/climate_change.csv", parse_dates=["date"])
medals
 | Bronze | Gold | Silver |
---|---|---|---|
United States | 67 | 137 | 52 |
Germany | 67 | 47 | 43 |
Great Britain | 26 | 64 | 55 |
Russia | 35 | 50 | 28 |
China | 35 | 44 | 30 |
France | 21 | 20 | 55 |
Australia | 25 | 23 | 34 |
Italy | 24 | 8 | 38 |
Canada | 61 | 4 | 4 |
Japan | 34 | 17 | 13 |
climate_change
 | date | co2 | relative_temp |
---|---|---|---|
0 | 1958-03-06 | 315.71 | 0.10 |
1 | 1958-04-06 | 317.45 | 0.01 |
2 | 1958-05-06 | 317.50 | 0.08 |
3 | 1958-06-06 | NaN | -0.05 |
4 | 1958-07-06 | 315.86 | 0.06 |
... | ... | ... | ... |
701 | 2016-08-06 | 402.27 | 0.98 |
702 | 2016-09-06 | 401.05 | 0.87 |
703 | 2016-10-06 | 401.59 | 0.89 |
704 | 2016-11-06 | 403.55 | 0.93 |
705 | 2016-12-06 | 404.45 | 0.81 |
706 rows × 3 columns
fig, ax = plt.subplots()
We created two variables, fig and ax. The names are arbitrary, but they are the established convention.
These two variables hold the core objects used for every kind of plotting operation. The first object, fig, short for figure, is the frame of your plot. You can resize and reshape the frame, but you cannot draw on it directly. A single notebook or script can have multiple figures, and each figure can have multiple subplots. Here, subplot is synonymous with axes. The second object, ax, short for axes, is the canvas you draw on. In other words, it is the blank sheet that holds your plotted data. An axes object can belong to only one figure.
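The relationship between the frame and the canvas can be checked directly on the two objects. The following is a minimal sketch, not part of the original notebook; the Agg backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# The figure keeps a list of the axes it contains.
print(fig.axes == [ax])  # True

# Each axes points back to the single figure that owns it.
print(ax.figure is fig)  # True

# Resizing acts on the frame (the figure), not on the canvas (the axes).
fig.set_size_inches(8, 4)
print(fig.get_size_inches())

plt.close(fig)
```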
Axes methods vs. pyplot, understanding further
Pyplot is the more beginner-friendly way to interact with Matplotlib. Compared to axes methods, pyplot offers a quicker and more concise way of plotting, with fewer local variables and less syntax. So why do most people prefer the object-oriented way?
Let’s see the concise use of pyplot:
plt.bar(medals.index, medals["Gold"])
plt.xticks(rotation=90)
plt.ylabel("# of gold medals")
Text(0, 0.5, '# of gold medals')
This was easy. It only took 3 lines. Now let’s see what happens if we try to plot the (completely unrelated) climate change data next to it:
# plt.bar(medals.index, medals["Gold"])
# plt.xticks(rotation=90)
# plt.ylabel("# of gold medals")
# plt.plot(climate_change["date"], climate_change["co2"])
In this case, we get a TypeError. pyplot, on its own, cannot create a new axes or a new figure and intelligently plot the new data. As we get to more complex plots like this one, we are going to need a more flexible approach.
Matplotlib has a concept called the current figure. By default, pyplot creates a current figure and a current axes and plots on them. Every later pyplot command keeps targeting that same current axes until a new figure is created, so when we tried to add the climate data in the last example, pyplot drew it onto the bar chart's axes instead of giving it an axes of its own.
To avoid this, let’s see the approach where we are in full control of each figure and axes:
fig, ax = plt.subplots()
ax.plot(climate_change["date"], climate_change["co2"])
ax.set(title="Amount of CO2 (ppm) in each year", xlabel="Year", ylabel="Amount of CO2 (ppm)")
[Text(0.5, 1.0, 'Amount of CO2 (ppm) in each year'),
Text(0.5, 0, 'Year'),
Text(0, 0.5, 'Amount of CO2 (ppm)')]
Here we explicitly state that we are working on this fig object. Any plotting command we write is applied to the axes (ax) object that belongs to fig. Unless we create a new figure with another plt.subplots() call, the current figure remains the one stored in fig. This approach is convenient because we can now create as many axes, or subplots, as we like in a single figure and work with each of them.
From now on, I will use the terms subplot and axes interchangeably, as they are synonyms.
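The current figure discussed above can be inspected with plt.gcf() and plt.gca(). The snippet below is a small sketch with made-up numbers (the names fig_a and fig_b are illustrative only). It shows that pyplot calls land on whichever axes is current, while axes methods always land on the axes you name.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt

fig_a, ax_a = plt.subplots()
fig_b, ax_b = plt.subplots()  # creating a figure makes it the current one

print(plt.gcf() is fig_b)  # True
print(plt.gca() is ax_b)   # True

# A pyplot call goes to the current axes, which is ax_b here...
plt.plot([1, 2, 3])
# ...whereas an axes method goes to the axes it is called on.
ax_a.plot([3, 2, 1])

print(len(ax_a.lines), len(ax_b.lines))  # 1 1
plt.close("all")
```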
plt.subplots() grid system
We saw an example of creating one subplot. Let’s see how we can create more in a single figure:
fig, ax = plt.subplots(nrows=1, ncols=2)
print(ax)
print(type(ax))
[<AxesSubplot: > <AxesSubplot: >]
<class 'numpy.ndarray'>
Among other parameters, .subplots() has two parameters that specify the grid size: nrows and ncols set the number of rows and columns, respectively. If you paid attention, our second variable now contains not one but two axes, and it is returned as a numpy.ndarray. So, we have to unpack or index this array to use our plotting commands:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
fig.tight_layout(pad=3)
ax1.plot(climate_change["date"], climate_change["co2"])
ax2.plot(climate_change["date"], climate_change["relative_temp"])
plt.show()
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
fig.tight_layout(pad=3)
ax[0].plot(climate_change["date"], climate_change["co2"])
ax[1].plot(climate_change["date"], climate_change["relative_temp"])
plt.show()
The two methods are equivalent, and it is up to you which one to use. Let’s see one more example, slightly more difficult:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 5))
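For larger grids, naming every axes in a nested tuple gets unwieldy. One alternative is to keep the returned array and loop over it. This is a minimal sketch on made-up sine curves rather than the notebook's data, so it runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)  # made-up data for illustration

fig, axes = plt.subplots(2, 2, figsize=(10, 5))
print(axes.shape)  # (2, 2)

# axes.flat walks the grid row by row, left to right.
for i, ax in enumerate(axes.flat, start=1):
    ax.plot(x, np.sin(i * x))
    ax.set(title=f"sin({i}x)")

# Individual cells can still be reached with [row, column] indexing.
axes[1, 0].set(xlabel="x")

fig.tight_layout()
plt.close(fig)
```

Looping keeps the code the same length no matter how many rows and columns the grid has, while unpacking reads better when each subplot gets a different kind of plot.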
fig, ax = plt.subplots()
ax.bar(medals.index, medals["Gold"])
ax.set(title="# gold medals by country", ylabel="# of medals", xlabel="Country")
ax.set_xticklabels(medals.index, rotation=90)
plt.show()
/var/folders/qz/3pl78vgn6t7cltx35ttht03r0000gn/T/ipykernel_48226/3128316306.py:4: UserWarning: FixedFormatter should only be used together with FixedLocator
ax.set_xticklabels(medals.index, rotation=90)
The pyplot functions for titles, labels, and limits have equivalents on the axes object whose names start with ax.set_. If you use the general ax.set(), you avoid repetition when you have multiple subplots. However, if you need to pass additional parameters to specific parts of your plot, use the individual ax.set_ methods:
fig, ax = plt.subplots()
ax.bar(medals.index, medals["Gold"])
ax.set_xticklabels(medals.index, rotation=90)
ax.set_xlabel("Country", fontsize=15)
ax.set_title("# gold medals by country", fontsize=20)
plt.show()
/var/folders/qz/3pl78vgn6t7cltx35ttht03r0000gn/T/ipykernel_48226/3042962666.py:3: UserWarning: FixedFormatter should only be used together with FixedLocator
ax.set_xticklabels(medals.index, rotation=90)
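The UserWarning in the two outputs above appears because set_xticklabels() replaces the label text without fixing the tick positions. If the goal is only to rotate the existing labels, one option is ax.tick_params() with labelrotation, which leaves the labels themselves alone. This is a sketch of that alternative, not the approach used in the cells above; the four countries and their gold counts are copied from the medals table.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt

# A few gold-medal counts copied from the medals table
countries = ["United States", "Germany", "Great Britain", "Russia"]
gold = [137, 47, 64, 50]

fig, ax = plt.subplots()
ax.bar(countries, gold)

# Rotate the existing tick labels instead of replacing them.
ax.tick_params(axis="x", labelrotation=90)

ax.set(title="# gold medals by country", ylabel="# of medals", xlabel="Country")
fig.canvas.draw()  # render once so the tick labels are fully laid out
plt.close(fig)
```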
Doubling axis
Sometimes, we want a single subplot to have more than one x axis or y axis. While this is less convenient with the plain pyplot interface, it is very easy with the object-oriented API. Let’s say we want to plot the relative_temp and co2 columns of climate_change in a single plot. We want them to share an x axis since the data covers the same time period:
# Create a figure and an axis
fig, ax = plt.subplots()
# Plot CO2 emissions with a blue line
ax.plot(climate_change["date"], climate_change["co2"], color="blue")
# Specify that we will be using a twin x axis
ax2 = ax.twinx()
ax2.plot(climate_change["date"], climate_change["relative_temp"], color="red")
# Change the color of ticks
ax.tick_params("y", colors="blue") # 'y' because we want to change the y axis
ax2.tick_params("y", colors="red")
ax.set(
    title="CO2 emissions and relative temperature by year", xlabel="Year"
) # Does not matter which one you pick, ax or ax2
plt.show()
We wanted a common x axis, the date column, so we created a second axes with ax.twinx(). If you want a common y axis instead, the method is ax.twiny().
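The names are easy to mix up: twinx() shares the x axis and adds a second y axis on the right, while twiny() shares the y axis and adds a second x axis on top. The sketch below puts both side by side using made-up numbers (the variable names are illustrative only).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt

fig, (left, right) = plt.subplots(1, 2, figsize=(10, 4))

# twinx(): the new axes shares x and gets its own y axis on the right.
left.plot([0, 1, 2], [10, 20, 30], color="blue")
left_twin = left.twinx()
left_twin.plot([0, 1, 2], [0.3, 0.1, 0.2], color="red")

# twiny(): the new axes shares y and gets its own x axis on top.
right.plot([10, 20, 30], [0, 1, 2], color="blue")
right_twin = right.twiny()
right_twin.plot([0.3, 0.1, 0.2], [0, 1, 2], color="red")

# The shared axis has identical limits on both members of each pair.
print(left.get_xlim() == left_twin.get_xlim())    # True
print(right.get_ylim() == right_twin.get_ylim())  # True

plt.close(fig)
```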
Working with figure object
You have probably noticed that once you create a figure object with the .subplots() command or another method, pretty much everything happens on the axes objects. One method of the figure object that is used often is savefig(). So, let’s explore it. We will go back to our double-axis plot of CO2 and save it to disk:
# Create a figure and an axis
fig, ax = plt.subplots()
# Plot CO2 emissions with a blue line
ax.plot(climate_change["date"], climate_change["co2"], color="blue")
# Specify that we will be using a twin x axis
ax2 = ax.twinx()
ax2.plot(climate_change["date"], climate_change["relative_temp"], color="red")
# Change the color of ticks
ax.tick_params("y", colors="blue") # 'y' because we want to change the y axis
ax2.tick_params("y", colors="red")
ax.set(
    title="CO2 emissions and relative temperature by year", xlabel="Year"
) # Does not matter which one you pick, ax or ax2
# fig.savefig("co2_relative_temp.png")
[Text(0.5, 1.0, 'CO2 emissions and relative temperature by year'),
Text(0.5, 0, 'Year')]
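The savefig() call is commented out in the cell above. A self-contained sketch of what saving might look like follows. It writes to a temporary directory so nothing lands in the project folder, and the dpi and bbox_inches values are illustrative choices, not settings taken from the original notebook. The two data points are the first and last CO2 readings from the climate_change table.

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# First and last CO2 readings from the climate_change table
ax.plot([1958, 2016], [315.71, 404.45], color="blue")
ax.set(title="CO2 (ppm)", xlabel="Year")

# Save into a throwaway directory; the file extension selects the format.
out_dir = Path(tempfile.mkdtemp())
out_path = out_dir / "co2_relative_temp.png"
fig.savefig(out_path, dpi=150, bbox_inches="tight")

print(out_path.exists())  # True
plt.close(fig)
```

Passing bbox_inches="tight" trims the empty margin around the plot, which is often useful when the image is going into a report or a slide.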