diff --git a/scripts/benchmark/plot.py b/scripts/benchmark/plot.py index 6337b78f..b0b63df3 100644 --- a/scripts/benchmark/plot.py +++ b/scripts/benchmark/plot.py @@ -10,6 +10,7 @@ import matplotlib matplotlib.use('Agg') from matplotlib import rcParams import matplotlib.pyplot as plt +import numpy as np # Colors BLUEISH = [c / 255.0 for c in [71, 101, 177]] # #4765b1 @@ -24,7 +25,7 @@ def plot_graphs(results, file_name, num_rows, num_cols, x_keys, y_keys, titles, x_labels, y_labels, label_names, title, tight_plot, verbose): assert len(results) == num_rows * num_cols - assert len(results) != 1 + assert len(results) >= 1 assert len(x_keys) == len(results) assert len(y_keys) == len(results) assert len(titles) == len(results) @@ -64,6 +65,9 @@ def plot_graphs(results, file_name, num_rows, num_cols, size_y = plot_size * num_rows rcParams.update({'font.size': font_size}) fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k') + if len(results) == 1 and not type(axes) is np.ndarray: + axes = np.full((1,1), axes) + assert type(axes) is np.ndarray fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title) plt.subplots_adjust(wspace=w_space, hspace=h_space) @@ -72,7 +76,7 @@ def plot_graphs(results, file_name, num_rows, num_cols, for col in range(num_cols): index = row * num_cols + col result = results[index] - ax = axes.flat[index] + ax = axes[row, col] plt.sca(ax) print("[plot] Plotting subplot %d" % index)