Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
StandardPlotHelper.py
1import glob
2import json
3import os
4
5import numpy as np
6
7import shamrock.sys
8
9from .UnitHelper import plot_codeu_to_unit
10
11try:
12 import matplotlib
13 import matplotlib.animation as animation
14 import matplotlib.pyplot as plt
15
16 _HAS_MATPLOTLIB = True
17except ImportError:
18 _HAS_MATPLOTLIB = False
19
20
def analysis_save(iplot, data, metadata, npy_data_filename, json_data_filename):
    """
    Write one analysis dump to disk: the array as ``.npy``, the metadata as JSON.

    Both filename arguments are ``str.format`` templates that receive ``iplot``.
    Only the process for which ``shamrock.sys.world_rank()`` is 0 writes anything;
    every other rank returns without touching the filesystem.
    """
    if shamrock.sys.world_rank() != 0:
        return

    npy_path = npy_data_filename.format(iplot)
    print(f"Saving data to {npy_path}")
    np.save(npy_path, data)

    json_path = json_data_filename.format(iplot)
    with open(json_path, "w") as fp:
        print(f"Saving metadata to {json_path}")
        json.dump(metadata, fp)
32
33
def load_analysis(iplot, json_data_filename, npy_data_filename):
    """
    Read back one analysis dump and return it as ``(data, metadata)``.

    Both filename arguments are ``str.format`` templates that receive ``iplot``.
    ``data`` is whatever ``np.load`` yields for the ``.npy`` file and ``metadata``
    is the decoded content of the JSON file.
    """
    json_path = json_data_filename.format(iplot)
    with open(json_path, "r") as fp:
        metadata = json.load(fp)

    data = np.load(npy_data_filename.format(iplot))
    return data, metadata
41
42
def get_list_analysis_id(glob_str_data):
    """
    Return the analysis ids found on disk, in increasing numeric order.

    ``glob_str_data`` is a glob pattern; every matching file name is expected to
    end in ``_<id>.<ext>`` and the integer ``<id>`` is extracted from it.

    The ids are sorted as integers rather than by file name, so ids that are
    longer than the zero padding used when writing (e.g. ``10000000`` next to
    ``9999999``) still come back in the right order.

    Raises ``ValueError`` if a matching file name does not carry an integer id.
    """

    def _id_from_path(path):
        # Only the file name matters: take what follows its last underscore and
        # drop the extension(s).
        tail = os.path.basename(path).rsplit("_", 1)[-1]
        return int(tail.split(".", 1)[0])

    return sorted(_id_from_path(path) for path in glob.glob(glob_str_data))
53
54
def field_normalize(field, normalization, min_normalization=1e-9):
    """
    Divide ``field`` by ``normalization`` element-wise and mask weak cells.

    Every cell whose normalization is below ``min_normalization`` is set to NaN
    in the result. A new array is returned; the inputs are not modified.
    ``field`` and ``normalization`` may be any array-likes of matching shape.
    """
    normalization = np.asarray(normalization)

    # Cells with a zero normalization are overwritten with NaN just below, so the
    # divide-by-zero / invalid-value warnings they would trigger are only noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        ret = np.asarray(field) / normalization

    # set to nan below min_normalization
    ret[normalization < min_normalization] = np.nan

    return ret
65
66
def figure_init(aspect, holywood_mode=False, dpi=200, base_size=6, nx=None, ny=None):
    """
    Create (or reset) matplotlib figure number 1 for the next plot.

    In the regular mode the figure is ``aspect * base_size + 1`` inches wide and
    ``base_size`` inches tall (the extra inch is presumably room for the colorbar).
    In holywood mode the axes fill the whole figure, the axis decorations are
    turned off and the figure is resized to ``nx / dpi`` by ``ny / dpi`` inches.

    Raises ``ValueError`` in holywood mode when ``nx`` or ``ny`` is missing.
    """
    width = aspect * base_size
    height = 1.0 * base_size
    if not holywood_mode:
        width = width + 1

    # Figure number 1 is cleared and reused rather than allocating a new one.
    plt.figure(figsize=(width, height), num=1, clear=True, dpi=dpi)

    if not holywood_mode:
        return

    if nx is None or ny is None:
        raise ValueError("nx and ny must be provided in holywood mode")

    plt.gca().set_position((0, 0, 1, 1))
    plt.gcf().set_size_inches(nx / dpi, ny / dpi)
    plt.axis("off")
87
88
def figure_add_colorbar(imshow_result, label, holywood_mode=False):
    """
    Attach a labelled colorbar for ``imshow_result`` to the current figure.

    The regular mode uses matplotlib's default colorbar placement. Holywood mode
    draws a small horizontal colorbar inside the current axes, with its outline,
    ticks and labels all coloured white.
    """
    if not holywood_mode:
        cbar = plt.colorbar(imshow_result, extend="both")
        cbar.set_label(label)
        return

    inset = plt.gca().inset_axes([0.73, 0.1, 0.25, 0.025])
    cbar = plt.colorbar(imshow_result, cax=inset, orientation="horizontal", extend="both")
    cbar.set_label(label, color="white")

    # Set colorbar elements to white
    cbar.outline.set_edgecolor("white")
    for get_tick_labels in (cbar.ax.get_yticklabels, cbar.ax.get_xticklabels):
        plt.setp(get_tick_labels(), color="white")
    cbar.ax.tick_params(color="white", labelcolor="white", length=6, width=1)
108
109
def figure_render_sinks(sink_pos_screen, ax, scale_factor, color, linewidth, fill):
    """
    Draw one circle per sink on ``ax``.

    ``sink_pos_screen`` is an iterable of ``(x, y, sink)`` triples where ``sink``
    is a mapping holding an ``"accretion_radius"`` entry. Each circle is centred
    on ``(x, y)`` with radius ``accretion_radius * scale_factor``.
    """
    # Build every circle first, then attach them to the axes.
    circles = [
        plt.Circle(
            (x, y),
            sink["accretion_radius"] * scale_factor,
            linewidth=linewidth,
            color=color,
            fill=fill,
        )
        for x, y, sink in sink_pos_screen
    ]
    for sink_circle in circles:
        ax.add_artist(sink_circle)
127
128
def figure_add_time_info(text, holywood_mode=False):
    """
    Show ``text`` on the current figure.

    Regular mode uses it as the plot title; holywood mode places it in an anchored
    text box inside the current axes (``loc=2``).
    """
    if not holywood_mode:
        plt.title(text)
        return

    from matplotlib.offsetbox import AnchoredText

    text_box = AnchoredText(text, loc=2)
    current_axes = plt.gca()
    current_axes.add_artist(text_box)
140
141
def init_analysis_plot_paths(obj, analysis_folder, analysis_prefix):
    """
    Create ``<analysis_folder>/plots`` and store the derived file patterns on ``obj``.

    The following attributes are set on ``obj``:
    ``analysis_prefix``, ``plot_prefix`` (common path prefixes, ending in ``_``),
    ``npy_data_filename``, ``json_data_filename``, ``plot_filename`` (``str.format``
    templates taking the 7-digit zero-padded index) and ``glob_str_plot``,
    ``glob_str_data`` (glob patterns matching the plots and the JSON metadata).
    """
    plots_dir = os.path.join(analysis_folder, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    data_stem = os.path.join(plots_dir, analysis_prefix) + "_"
    plot_stem = os.path.join(plots_dir, "plot_" + analysis_prefix) + "_"

    derived_attributes = {
        "analysis_prefix": data_stem,
        "plot_prefix": plot_stem,
        "npy_data_filename": data_stem + "{:07}.npy",
        "json_data_filename": data_stem + "{:07}.json",
        "plot_filename": plot_stem + "{:07}.png",
        "glob_str_plot": plot_stem + "*.png",
        "glob_str_data": data_stem + "*.json",
    }
    for attribute, value in derived_attributes.items():
        setattr(obj, attribute, value)
154
155
157 def __init__(
158 self,
159 model,
160 ext_r,
161 nx,
162 ny,
163 ex,
164 ey,
165 center,
166 analysis_folder,
167 analysis_prefix,
168 compute_function=None,
169 ):
170 self.model = model
171 self.ext_r = ext_r
172 self.nx = nx
173 self.ny = ny
174 self.ex = ex
175 self.ey = ey
176 self.center = center
177 self.aspect = float(self.nx) / float(self.ny)
178 self.compute_function = compute_function
179 init_analysis_plot_paths(self, analysis_folder, analysis_prefix)
180
181 def get_dx_dy(self):
182 ext_x = 2 * self.ext_r * self.aspect
183 ext_y = 2 * self.ext_r
184
185 dx = (self.ex[0] * ext_x, self.ex[1] * ext_x, self.ex[2] * ext_x)
186 dy = (self.ey[0] * ext_y, self.ey[1] * ext_y, self.ey[2] * ext_y)
187
188 return dx, dy
189
190 def column_integ_render(self, field_name, field_type, custom_getter=None):
191 dx, dy = self.get_dx_dy()
192 arr_field = self.model.render_cartesian_column_integ(
193 field_name,
194 field_type,
195 center=(self.center[0], self.center[1], self.center[2]),
196 delta_x=dx,
197 delta_y=dy,
198 nx=self.nx,
199 ny=self.ny,
200 custom_getter=custom_getter,
201 )
202
203 return arr_field
204
205 def column_average_render(
206 self, field_name, field_type, min_normalization=1e-9, custom_getter=None
207 ):
208 dx, dy = self.get_dx_dy()
209 arr_field = self.model.render_cartesian_column_integ(
210 field_name,
211 field_type,
212 center=(self.center[0], self.center[1], self.center[2]),
213 delta_x=dx,
214 delta_y=dy,
215 nx=self.nx,
216 ny=self.ny,
217 custom_getter=custom_getter,
218 )
219
220 normalisation = self.model.render_cartesian_column_integ(
221 "unity",
222 "f64",
223 center=(self.center[0], self.center[1], self.center[2]),
224 delta_x=dx,
225 delta_y=dy,
226 nx=self.nx,
227 ny=self.ny,
228 )
229
230 return field_normalize(arr_field, normalisation, min_normalization)
231
232 def slice_render(
233 self,
234 field_name,
235 field_type,
236 do_normalization=True,
237 min_normalization=1e-9,
238 field_transform=None,
239 custom_getter=None,
240 ):
241 dx, dy = self.get_dx_dy()
242 arr_field_data = self.model.render_cartesian_slice(
243 field_name,
244 field_type,
245 center=(self.center[0], self.center[1], self.center[2]),
246 delta_x=dx,
247 delta_y=dy,
248 nx=self.nx,
249 ny=self.ny,
250 custom_getter=custom_getter,
251 )
252
253 if field_transform is not None:
254 arr_field_data = field_transform(arr_field_data)
255
256 if not do_normalization:
257 return arr_field_data
258
259 arr_field_normalization = self.model.render_cartesian_slice(
260 "unity",
261 "f64",
262 center=(self.center[0], self.center[1], self.center[2]),
263 delta_x=dx,
264 delta_y=dy,
265 nx=self.nx,
266 ny=self.ny,
267 )
268
269 return field_normalize(arr_field_data, arr_field_normalization, min_normalization)
270
271 def get_extent(self):
272 x_e_x = (
273 self.ex[0] * self.center[0] + self.ex[1] * self.center[1] + self.ex[2] * self.center[2]
274 )
275 y_e_y = (
276 self.ey[0] * self.center[0] + self.ey[1] * self.center[1] + self.ey[2] * self.center[2]
277 )
278 return [
279 -self.ext_r * self.aspect + x_e_x,
280 self.ext_r * self.aspect + x_e_x,
281 -self.ext_r + y_e_y,
282 self.ext_r + y_e_y,
283 ]
284
285 def compute(self):
286 if self.compute_function is None:
287 raise ValueError("compute_function is not set")
288 return self.compute_function(self)
289
290 def analysis_save(self, iplot, data=None):
291 metadata = {
292 "extent": self.get_extent(),
293 "time": self.model.get_time(),
294 "sinks": self.model.get_sinks(),
295 }
296 if data is None:
297 data = self.compute()
298 analysis_save(iplot, data, metadata, self.npy_data_filename, self.json_data_filename)
299
300 def load_analysis(self, iplot):
301 return load_analysis(iplot, self.json_data_filename, self.npy_data_filename)
302
303 def get_list_analysis_id(self):
304 return get_list_analysis_id(self.glob_str_data)
305
306 def metadata_to_screen_sink_pos(self, metadata):
307 output_list = []
308 for s in metadata["sinks"]:
309 # print(s)
310 x, y, z = s["pos"]
311
312 x_e_x = self.ex[0] * x + self.ex[1] * y + self.ex[2] * z
313 y_e_y = self.ey[0] * x + self.ey[1] * y + self.ey[2] * z
314
315 output_list.append((x_e_x, y_e_y, s))
316 return output_list
317
318 def figure_init(self, holywood_mode=False, dpi=200):
319 figure_init(self.aspect, holywood_mode, dpi, base_size=6, nx=self.nx, ny=self.ny)
320
321 def figure_render_sinks(self, metadata, ax, scale_factor, color, linewidth, fill):
322 sink_list_plot = self.metadata_to_screen_sink_pos(metadata)
323 figure_render_sinks(sink_list_plot, ax, scale_factor, color, linewidth, fill)
324
325 def figure_add_time_info(self, text, holywood_mode=False):
326 figure_add_time_info(text, holywood_mode)
327
328 def figure_add_colorbar(self, imshow_result, label, holywood_mode=False):
329 figure_add_colorbar(imshow_result, label, holywood_mode)
330
331 def make_plot(
332 self,
333 iplot,
334 x_unit=None,
335 y_unit=None,
336 time_unit=None,
337 field_unit=None,
338 x_label=None,
339 y_label=None,
340 field_label=None,
341 holywood_mode=False,
342 cmap="magma",
343 cmap_bad_color="black",
344 contour_list=None,
345 add_sinks=True,
346 sink_scale_factor=1,
347 sink_color="green",
348 sink_linewidth=1,
349 sink_fill=False,
350 save_plot=True,
351 **kwargs,
352 ):
353 if shamrock.sys.world_rank() == 0:
354 field_render, metadata = self.load_analysis(iplot)
355
356 dist_label_x, dist_conv_x = plot_codeu_to_unit(self.model.get_units(), x_unit)
357 dist_label_y, dist_conv_y = plot_codeu_to_unit(self.model.get_units(), y_unit)
358
359 metadata["extent"][0] *= dist_conv_x
360 metadata["extent"][1] *= dist_conv_x
361 metadata["extent"][2] *= dist_conv_y
362 metadata["extent"][3] *= dist_conv_y
363
364 time_label, time_conv = plot_codeu_to_unit(self.model.get_units(), time_unit)
365 metadata["time"] *= time_conv
366
367 field_unit_label, field_conv = plot_codeu_to_unit(self.model.get_units(), field_unit)
368 field_render *= field_conv
369
370 self.figure_init(holywood_mode)
371
372 import copy
373
374 my_cmap = matplotlib.colormaps[cmap].copy() # copy the default cmap
375 my_cmap.set_bad(color=cmap_bad_color)
376
377 # Draw contours and add labels
378 if contour_list is not None:
379 # Create coordinate arrays matching the extent for contour alignment
380 ny, nx = field_render.shape
381 x = np.linspace(metadata["extent"][0], metadata["extent"][1], nx)
382 y = np.linspace(metadata["extent"][2], metadata["extent"][3], ny)
383 X, Y = np.meshgrid(x, y)
384
385 contour_set = plt.contour(
386 X, Y, field_render, levels=contour_list, colors="white", linewidths=0.5
387 )
388
389 plt.clabel(contour_set, inline=True, fontsize=8, fmt="%g")
390
391 res = plt.imshow(
392 field_render, cmap=my_cmap, origin="lower", extent=metadata["extent"], **kwargs
393 )
394
395 ax = plt.gca()
396
397 if add_sinks:
399 metadata, ax, sink_scale_factor, sink_color, sink_linewidth, sink_fill
400 )
401
402 plt.xlabel(f"{x_label} {dist_label_x}")
403 plt.ylabel(f"{y_label} {dist_label_y}")
404
405 text = f"t = {metadata['time']:0.3f} {time_label}"
406 self.figure_add_time_info(text, holywood_mode)
407
408 cmap_label = f"{field_label} {field_unit_label}"
409 self.figure_add_colorbar(res, cmap_label, holywood_mode)
410
411 print(f"Saving plot to {self.plot_filename.format(iplot)}")
412 plt.savefig(self.plot_filename.format(iplot))
413 plt.close()
414
415 def render_all(self, **kwargs):
416 for iplot in self.get_list_analysis_id():
417 self.make_plot(iplot, **kwargs)
418
419 def render_gif(self, gif_filename, save_animation=False, fps=15, bitrate=1800):
420 if shamrock.sys.world_rank() == 0:
421 ani = shamrock.utils.plot.show_image_sequence(self.glob_str_plot, render_gif=True)
422 if save_animation:
423 # To save the animation using Pillow as a gif
424 writer = animation.PillowWriter(
425 fps=fps, metadata=dict(artist="Me"), bitrate=bitrate
426 )
427 ani.save(self.analysis_prefix + gif_filename, writer=writer)
428 return ani
429 return None
430
431
433 def __init__(
434 self,
435 analysis_folder,
436 analysis_prefix,
437 ):
438 os.makedirs(analysis_folder, exist_ok=True)
439
440 self.analysis_prefix = os.path.join(analysis_folder, analysis_prefix) + "_"
441 self.npy_data_filename = self.analysis_prefix + "{:07}.npy"
442 self.glob_str_data = self.analysis_prefix + "*.npy"
443
444 def analysis_save(self, iplot, data):
445 """
446 Save the analysis data npy file
447 """
448 if shamrock.sys.world_rank() == 0:
449 print(f"Saving data to {self.npy_data_filename.format(iplot)}")
450 np.save(self.npy_data_filename.format(iplot), data)
451
452 def load_analysis(self, iplot):
453 """
454 Load the analysis data from the json and npy files
455 """
456 return np.load(self.npy_data_filename.format(iplot), allow_pickle=True)
457
458 def get_list_analysis_id(self):
459 return get_list_analysis_id(self.glob_str_data)
460
461 def make_plot(self, iplot, plot_func):
462
463 if shamrock.sys.world_rank() == 0:
464 plot_func(iplot, self.load_analysis(iplot))
465
466 def render_all(self, plot_func):
467 for iplot in self.get_list_analysis_id():
468 self.make_plot(iplot, plot_func)
469
470 def render_gif(self, glob_str_plot, gif_filename, save_animation=False, fps=15, bitrate=1800):
471 if shamrock.sys.world_rank() == 0:
472 ani = shamrock.utils.plot.show_image_sequence(glob_str_plot, render_gif=True)
473 if save_animation:
474 # To save the animation using Pillow as a gif
475 writer = animation.PillowWriter(
476 fps=fps, metadata=dict(artist="Me"), bitrate=bitrate
477 )
478 ani.save(self.analysis_prefix + gif_filename, writer=writer)
479 return ani
480 return None
make_plot(self, iplot, x_unit=None, y_unit=None, time_unit=None, field_unit=None, x_label=None, y_label=None, field_label=None, holywood_mode=False, cmap="magma", cmap_bad_color="black", contour_list=None, add_sinks=True, sink_scale_factor=1, sink_color="green", sink_linewidth=1, sink_fill=False, save_plot=True, **kwargs)
figure_render_sinks(self, metadata, ax, scale_factor, color, linewidth, fill)
figure_add_colorbar(self, imshow_result, label, holywood_mode=False)
show_image_sequence(glob_str, render_gif=True, dpi=200, interval=50, repeat_delay=10)
Definition __init__.py:37