Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
PerfHistory.py
1import json
2import os
3
4import numpy as np
5
6import shamrock.sys
7
8try:
9 import matplotlib
10 import matplotlib.pyplot as plt
11
12 _HAS_MATPLOTLIB = True
13except ImportError:
14 _HAS_MATPLOTLIB = False
15
16
18 """
19 Analysis utility to report performance during the simulation as well as some metrics regarding walltime and step counts.
20 """
21
def __init__(self, model, analysis_folder, analysis_prefix):
    """Store the model and derive the output file names.

    Parameters
    ----------
    model : object
        Simulation model whose solver logs, particle count and time are
        sampled by ``analysis_save``.
    analysis_folder : str
        Directory that receives the JSON history and the plots.
    analysis_prefix : str
        Base name of the outputs: the history goes to
        ``<analysis_folder>/<analysis_prefix>.json`` and plot files are
        prefixed with ``<analysis_folder>/plot_<analysis_prefix>``.
    """
    self.model = model

    data_prefix = os.path.join(analysis_folder, analysis_prefix)
    plot_name = "plot_" + analysis_prefix
    plot_prefix = os.path.join(analysis_folder, plot_name)

    self.analysis_prefix = data_prefix
    self.plot_prefix = plot_prefix
    self.json_data_filename = data_prefix + ".json"
    self.plot_filename = plot_prefix
30
def analysis_save(self, iplot):
    """Append the performance counters gathered since the previous call to the history.

    The solver's cumulated step time and step count are read and then reset on
    every rank, so each entry only covers the interval since the last call. On
    rank 0 the entry is written to ``self.json_data_filename``; history entries
    recorded for analysis index ``iplot`` or later (e.g. left over from a run
    that is being restarted) are dropped first. Nothing is written when no step
    was taken since the previous call.

    Parameters
    ----------
    iplot : int
        Index of the current analysis output.
    """
    sim_time_delta = self.model.solver_logs_cumulated_step_time()
    scount = self.model.solver_logs_step_count()
    part_count = self.model.get_total_part_count()

    # Reset on every rank so the next call measures only the next interval.
    self.model.solver_logs_reset_cumulated_step_time()
    self.model.solver_logs_reset_step_count()

    if shamrock.sys.world_rank() != 0:
        return

    # Checked before touching the file: nothing would be written anyway.
    if scount == 0:
        print("Warning: step count is 0, skipping save of perf history")
        return

    perf_hist_new = {
        "time": self.model.get_time(),
        "sim_time_delta": sim_time_delta,
        "world_size": shamrock.sys.world_size(),
        "sim_step_count_delta": scount,
        "part_count": part_count,
        "iplot": iplot,
    }

    try:
        with open(self.json_data_filename, "r") as fp:
            perf_hist = json.load(fp)
    except (FileNotFoundError, json.JSONDecodeError):
        # Start a fresh history if none exists yet or the file is unreadable.
        perf_hist = {}

    # Truncate by the recorded analysis index rather than by list position:
    # zero-step calls leave no entry, so positions and iplot values drift apart.
    # Entries written without an "iplot" key fall back to their position, which
    # reproduces the former ``history[:iplot]`` slicing for old files.
    history = perf_hist.get("history", [])
    perf_hist["history"] = [
        entry for pos, entry in enumerate(history) if entry.get("iplot", pos) < iplot
    ] + [perf_hist_new]

    print(f"Saving perf history to {self.json_data_filename}")
    # Write to a temporary file and rename it over the target so an interrupted
    # dump cannot corrupt the existing history (a corrupt file would be discarded
    # by the JSONDecodeError branch above on the next call).
    tmp_filename = self.json_data_filename + ".tmp"
    with open(tmp_filename, "w") as fp:
        json.dump(perf_hist, fp, indent=4)
    os.replace(tmp_filename, self.json_data_filename)
63
def load_analysis(self):
    """Read the performance history stored in ``self.json_data_filename``.

    Returns
    -------
    dict
        The decoded JSON document, as written by ``analysis_save``
        (a mapping holding a ``"history"`` list).

    Raises
    ------
    FileNotFoundError
        If the history file does not exist.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """
    json_path = self.json_data_filename
    with open(json_path, "r") as handle:
        return json.load(handle)
68
def digest_perf_history(self):
    """Turn the stored performance history into arrays ready for plotting.

    Returns
    -------
    dict
        ``t``, ``part_count``, ``world_size``, ``sim_time_delta``,
        ``sim_step_count_delta`` : numpy arrays, one value per history entry.
        ``dt_code`` : simulation-time difference between consecutive entries
        (one element shorter than ``t``).
        ``cum_sim_time_delta``, ``cum_sim_step_count_delta`` : running sums.
        ``cum_sim_time_delta_all_proc`` : running sum of ``sim_time_delta``
        multiplied by the entry's ``world_size``.
        ``time_per_step``, ``rate`` : lists holding ``sim_time_delta`` per step
        and ``part_count`` divided by that per-step time; NaN where the step
        count is 0.
        ``tsim_per_hour`` : ``dt_code`` divided by ``sim_time_delta`` in hours,
        aligned with ``t[1:]``.
        A history without a ``"history"`` key yields empty results.
    """
    history = self.load_analysis().get("history", [])

    def column(key):
        # One numpy array per recorded field, in history order.
        return np.array([entry[key] for entry in history])

    t = column("time")
    sim_time_delta = column("sim_time_delta")
    world_size = column("world_size")
    sim_step_count_delta = column("sim_step_count_delta")
    part_count = column("part_count")

    dt_code = np.diff(t)
    sim_time_delta_all_proc = sim_time_delta * world_size

    # cumulative sim_time & step_count
    cum_sim_time_delta = np.cumsum(sim_time_delta)
    cum_sim_time_delta_all_proc = np.cumsum(sim_time_delta_all_proc)
    cum_sim_step_count_delta = np.cumsum(sim_step_count_delta)

    # Counters are reset at each save, so sim_time_delta[i] is the time
    # accumulated before entry i was recorded; it pairs with
    # dt_code[i - 1] = t[i] - t[i - 1].
    tsim_per_hour = dt_code / (sim_time_delta[1:] / 3600)

    time_per_step = []
    rate = []
    for td, sc, pc in zip(sim_time_delta, sim_step_count_delta, part_count):
        if sc > 0:
            step_time = td / sc
            time_per_step.append(step_time)
            rate.append(pc / step_time)
        else:
            # No step recorded for this entry: per-step quantities are undefined.
            time_per_step.append(np.nan)
            rate.append(np.nan)

    return {
        "t": t,
        "dt_code": dt_code,
        "part_count": part_count,
        "world_size": world_size,
        "sim_time_delta": sim_time_delta,
        "sim_step_count_delta": sim_step_count_delta,
        "cum_sim_time_delta": cum_sim_time_delta,
        "cum_sim_time_delta_all_proc": cum_sim_time_delta_all_proc,
        "cum_sim_step_count_delta": cum_sim_step_count_delta,
        "time_per_step": time_per_step,
        "rate": rate,
        "tsim_per_hour": tsim_per_hour,
    }
126
def plot_perf_history(self, close_plots=True, figsize=(8, 5), dpi=200):
    """Write PNG plots of the recorded performance history (rank 0 only).

    Each figure is saved as ``self.plot_filename + "_<metric>.png"``. Only a
    warning is printed when matplotlib is unavailable or when no history file
    has been written yet.

    Parameters
    ----------
    close_plots : bool, optional
        Close every figure after saving it. Pass ``False`` to keep them open.
    figsize : tuple, optional
        Size passed to ``plt.figure`` for every plot.
    dpi : int, optional
        Resolution passed to ``plt.figure`` for every plot.
    """
    if not _HAS_MATPLOTLIB:
        print("Warning: matplotlib is not installed, plot_perf_history is a no-op")
        return

    if shamrock.sys.world_rank() == 0:
        try:
            perf_hist = self.digest_perf_history()
        except FileNotFoundError:
            # analysis_save writes nothing while the step count is 0, so the
            # history file may legitimately be absent.
            print(f"Warning: {self.json_data_filename} not found, skipping perf history plots")
            return

        print(f"Plotting perf history from {self.json_data_filename}")

        t = perf_hist["t"]
        xlabel = "t [code unit] (simulation)"

        def save_current(suffix):
            # Save the active figure under the common plot prefix, then optionally free it.
            plt.savefig(self.plot_filename + suffix)
            if close_plots:
                plt.close()

        def line_plot(x, y, ylabel, suffix, yscale=None):
            # One "+-" curve against simulation time, saved to its own file.
            plt.figure(figsize=figsize, dpi=dpi)
            plt.plot(x, y, "+-")
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            if yscale is not None:
                plt.yscale(yscale)
            save_current(suffix)

        line_plot(t, perf_hist["cum_sim_time_delta"], "t [s] (real time)", "_sim_time.png")

        # Compute time summed over processes (left axis) against world size (right axis).
        plt.figure(figsize=figsize, dpi=dpi)
        plt.plot(
            t,
            perf_hist["cum_sim_time_delta_all_proc"] / 3600.0,
            "+-",
            label="Used compute time",
        )
        plt.xlabel(xlabel)
        plt.ylabel("$\\sum_{processes} t$ [h] (real time)")
        ax1 = plt.gca()
        ax2 = ax1.twinx()
        ax2.plot(t, perf_hist["world_size"], "+-", color="tab:orange", label="World size")
        ax2.set_ylabel("World size")
        # Merge both axes' handles into a single legend.
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc="best")
        save_current("_sim_time_all_proc.png")

        line_plot(t, perf_hist["cum_sim_step_count_delta"], "$N_\\mathrm{step}$", "_step_count.png")
        line_plot(
            t,
            perf_hist["sim_time_delta"],
            "$d t_\\mathrm{real} / d i_\\mathrm{analysis}$ [s]",
            "_sim_time_delta.png",
        )
        line_plot(
            t,
            perf_hist["sim_step_count_delta"],
            "$d N_\\mathrm{step} / d i_\\mathrm{analysis}$",
            "_step_count_delta.png",
        )
        # tsim_per_hour is built from consecutive differences, so it starts at t[1].
        line_plot(
            t[1:],
            perf_hist["tsim_per_hour"],
            "$d t_\\mathrm{sim} / d t_\\mathrm{realtime}$ [code unit (time) / hour]",
            "_tsim_per_hour.png",
        )
        line_plot(t, perf_hist["time_per_step"], "time per step [s]", "_time_per_step.png")
        line_plot(t, perf_hist["rate"], "Particles / second", "_rate.png", yscale="log")