-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
74 lines (64 loc) · 3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
# source: https://stackoverflow.com/questions/43064524/plotting-shaded-uncertainty-region-in-line-plot-in-matplotlib-when-data-has-nans
def plot_final(final_means, final_stds, xs, title, xlabel, fname, legends=None, dpi=1200):
    """Plot several mean curves with shaded +/- std bands and save the figure.

    Args:
        final_means: sequence of per-series y values; ``final_means[i]`` is
            plotted against ``xs``.
        final_stds: sequence of per-series spreads; the band for series ``i``
            spans ``final_means[i] - final_stds[i]`` to
            ``final_means[i] + final_stds[i]``.
        xs: shared x values for every series.
        title: axes title.
        xlabel: x-axis label.
        fname: output file path passed to ``Figure.savefig``.
        legends: optional per-series labels. When non-empty, a legend is drawn
            to the right of the axes. Series 1 is then drawn opaque and the
            others with ``alpha=0.7``.
        dpi: resolution used when saving (default 1200, as before).
    """
    clrs = sns.color_palette("husl", len(final_means))
    if len(clrs) > 1:
        # Swap the first two hues. Series 1 is the highlighted (opaque) one
        # when legends are given, so presumably it is meant to get the
        # palette's first colour.
        clrs[0], clrs[1] = clrs[1], clrs[0]
    legend_available = legends is not None and len(legends) > 0

    # Seaborn styles are rcParams that matplotlib reads when axes are created
    # and drawn. The figure must therefore be built and saved inside the
    # context, or the style is silently ignored.
    with sns.axes_style("whitegrid"):
        fig, ax = plt.subplots()
        try:
            for i, color in enumerate(clrs):
                # asarray lets plain lists work with the +/- arithmetic below.
                mean = np.asarray(final_means[i])
                std = np.asarray(final_stds[i])
                line_kw = {"c": color}
                if legend_available:
                    line_kw["label"] = legends[i]
                    if i != 1:
                        line_kw["alpha"] = 0.7
                ax.plot(xs, mean, **line_kw)
                # Shaded band covering mean - std .. mean + std.
                ax.fill_between(xs, mean - std, mean + std,
                                alpha=0.2, facecolor=color)
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            if legend_available:
                # Anchor the legend just outside the right edge of the axes.
                ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            fig.savefig(fname, dpi=dpi, bbox_inches='tight')
        finally:
            # Release the figure even if plotting fails, so repeated calls
            # do not accumulate open figures.
            plt.close(fig)
def plot_final_discrete(final_means, final_stds, xs, title, xlabel, fname, legends=None, dpi=1200):
    """Plot several series as dashed error-bar plots and save the figure.

    Args:
        final_means: sequence of per-series y values; ``final_means[i]`` is
            plotted against ``xs``.
        final_stds: sequence of per-series error magnitudes, passed as
            ``yerr`` to ``Axes.errorbar``.
        xs: shared x values for every series.
        title: axes title.
        xlabel: x-axis label.
        fname: output file path passed to ``Figure.savefig``.
        legends: optional per-series labels. When non-empty, a legend is drawn
            to the right of the axes.
        dpi: resolution used when saving (default 1200, as before).
    """
    clrs = sns.color_palette("husl", len(final_means))
    legend_available = legends is not None and len(legends) > 0

    # Seaborn styles are rcParams read at axes creation and draw time. The
    # figure must be created and saved inside the context for "darkgrid" to
    # take effect.
    with sns.axes_style("darkgrid"):
        fig, ax = plt.subplots()
        try:
            for i, color in enumerate(clrs):
                label_kw = {"label": legends[i]} if legend_available else {}
                ax.errorbar(xs, final_means[i], yerr=final_stds[i], fmt='o',
                            linestyle='dashed', capsize=3, c=color, **label_kw)
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            if legend_available:
                # Anchor the legend just outside the right edge of the axes.
                ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            fig.savefig(fname, dpi=dpi, bbox_inches='tight')
        finally:
            # Release the figure even if plotting fails.
            plt.close(fig)
def plot_network(Network, pos=None, parent=None, fname=None, node_color=None):
    """Draw a networkx graph with node labels and save it to a file.

    Args:
        Network: graph to draw. Its ``name`` attribute is used to build the
            default filename.
        pos: optional node-position mapping for ``nx.draw_networkx``. When
            None, positions are computed with ``nx.spring_layout``.
        parent: directory used to build the default filename
            ``f"{parent}/{Network.name}.pdf"`` when ``fname`` is None.
        fname: output file path. It takes precedence over ``parent``.
        node_color: optional colour(s) forwarded to ``nx.draw_networkx``.
            When None, networkx's own default colour is used.

    Raises:
        ValueError: if neither ``fname`` nor ``parent`` is given, since there
            is then nowhere to save the drawing.
    """
    if fname is None:
        if parent is None:
            raise ValueError("plot_network requires either fname or parent")
        # NOTE(review): if Network.name is empty this yields "<parent>/.pdf".
        fname = f"{parent}/{Network.name}.pdf"
    if pos is None:
        pos = nx.spring_layout(Network)

    draw_kw = {"with_labels": True, "pos": pos, "node_size": 100, "font_size": 8}
    # Only forward node_color when given, so networkx keeps its own default.
    if node_color is not None:
        draw_kw["node_color"] = node_color

    # A fresh figure, rather than a fixed figure number, avoids drawing over
    # (and then closing) a figure that is already open under that number.
    fig = plt.figure()
    try:
        nx.draw_networkx(Network, **draw_kw)
        fig.savefig(fname, bbox_inches='tight')
    finally:
        plt.close(fig)