commit
e6e2519eaa
|
@ -10,6 +10,7 @@ import matplotlib
|
|||
matplotlib.use('Agg')
|
||||
from matplotlib import rcParams
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# Colors
|
||||
BLUEISH = [c / 255.0 for c in [71, 101, 177]] # #4765b1
|
||||
|
@ -24,7 +25,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
|
|||
x_keys, y_keys, titles, x_labels, y_labels,
|
||||
label_names, title, tight_plot, verbose):
|
||||
assert len(results) == num_rows * num_cols
|
||||
assert len(results) != 1
|
||||
assert len(results) >= 1
|
||||
assert len(x_keys) == len(results)
|
||||
assert len(y_keys) == len(results)
|
||||
assert len(titles) == len(results)
|
||||
|
@ -64,6 +65,9 @@ def plot_graphs(results, file_name, num_rows, num_cols,
|
|||
size_y = plot_size * num_rows
|
||||
rcParams.update({'font.size': font_size})
|
||||
fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k')
|
||||
if len(results) == 1 and not type(axes) is np.ndarray:
|
||||
axes = np.full((1,1), axes)
|
||||
assert type(axes) is np.ndarray
|
||||
fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title)
|
||||
plt.subplots_adjust(wspace=w_space, hspace=h_space)
|
||||
|
||||
|
@ -72,7 +76,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
|
|||
for col in range(num_cols):
|
||||
index = row * num_cols + col
|
||||
result = results[index]
|
||||
ax = axes.flat[index]
|
||||
ax = axes[row, col]
|
||||
plt.sca(ax)
|
||||
print("[plot] Plotting subplot %d" % index)
|
||||
|
||||
|
|
Loading…
Reference in New Issue