Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
VelocityPlots.py
1import numpy as np
2
3import shamrock.sys
4
5from .StandardPlotHelper import StandardPlotHelper
6
7try:
8 from numba import njit
9
10 _HAS_NUMBA = True
11except ImportError:
12 _HAS_NUMBA = False
13
14
def SliceVzPlot(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    do_normalization=True,
    min_normalization=1e-9,
):
    """Build a plot helper that renders a slice of the z velocity component.

    The ``"vxyz"`` field is rendered as ``"f64_3"`` by ``helper.slice_render``
    and only index 2 of the third axis of the rendered array is kept.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        do_normalization: Forwarded to ``helper.slice_render``.
        min_normalization: Forwarded to ``helper.slice_render``.

    Returns:
        The ``StandardPlotHelper`` wrapping the slice computation.
    """

    def compute_v_z_slice(helper):
        # Render the full vector field, then select component 2 (v_z).
        return helper.slice_render(
            "vxyz",
            "f64_3",
            do_normalization,
            min_normalization,
            lambda rendered: rendered[:, :, 2],
        )

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_v_z_slice,
    )
50
51
def ColumnAverageVzPlot(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    min_normalization=1e-9,
):
    """Build a plot helper that renders a column average of the z velocity.

    A custom ``"f64"`` per-particle quantity (column 2 of ``"vxyz"``) is fed
    to ``helper.column_average_render``.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        min_normalization: Forwarded to ``helper.column_average_render``.

    Returns:
        The ``StandardPlotHelper`` wrapping the column-average computation.
    """

    def vz_getter(size: int, dic_out: dict) -> np.ndarray:
        # Per-particle v_z: third column of the velocity array.
        velocities = dic_out["vxyz"]
        return velocities[:, 2]

    def compute_v_z_column_average(helper):
        return helper.column_average_render(
            "custom", "f64", min_normalization, custom_getter=vz_getter
        )

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_v_z_column_average,
    )
86
87
def SliceDiffVthetaProfile(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    velocity_profile,
    do_normalization=True,
    min_normalization=1e-9,
):
    """Build a plot helper rendering a slice of ``v_theta - velocity_profile(r)``.

    For every particle the cylindrical radius ``r = sqrt(x**2 + y**2)`` and
    the azimuthal velocity ``v_theta = (-y * vx + x * vy) / (r + 1e-9)`` are
    computed from the ``"xyz"`` and ``"vxyz"`` fields. The reference value
    ``velocity_profile(r)`` is then subtracted, and the result is rendered
    with ``helper.slice_render("custom", "f64", ...)``.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        velocity_profile: Callable of the radius ``r``. When numba is
            available, it is compiled with ``njit`` and called once on the
            whole array of radii, so it must accept arrays. Otherwise, it is
            wrapped in ``np.vectorize`` (float64 output) and called element by
            element.
        do_normalization: Forwarded to ``helper.slice_render``.
        min_normalization: Forwarded to ``helper.slice_render``.

    Returns:
        The ``StandardPlotHelper`` wrapping the computation.
    """

    def compute_diff_vtheta_profile(helper):
        if _HAS_NUMBA:
            if shamrock.sys.world_rank() == 0:
                print("Using numba for velocity profile in SliceDiffVthetaProfile")
            vel_profile_jit = njit(velocity_profile)
        else:
            # Without otypes, np.vectorize has two problems:
            # - it raises ValueError on size-0 inputs;
            # - it infers the output dtype from the first element, so an int
            #   first value would truncate every later value to int.
            # Pinning float64 avoids both.
            vel_profile_jit = np.vectorize(velocity_profile, otypes=[np.float64])

        def internal(x, y, vx, vy):
            r = np.sqrt(x**2 + y**2)
            # Offset keeps the division finite for particles at x = y = 0.
            r_safe = r + 1e-9
            v_theta = (-y * vx + x * vy) / r_safe
            return v_theta - vel_profile_jit(r)

        if _HAS_NUMBA:
            internal = njit(internal)

        def custom_getter(size: int, dic_out: dict) -> np.ndarray:
            xyz = dic_out["xyz"]
            vxyz = dic_out["vxyz"]
            return internal(xyz[:, 0], xyz[:, 1], vxyz[:, 0], vxyz[:, 1])

        return helper.slice_render(
            "custom",
            "f64",
            do_normalization,
            min_normalization,
            custom_getter=custom_getter,
        )

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_diff_vtheta_profile,
    )
156
157
def VerticalShearGradient(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    do_normalization=True,
    min_normalization=1e-9,
):
    """Build a plot helper rendering the axis-0 gradient of a v_theta slice.

    Steps:
    1. Compute the per-particle azimuthal velocity
       ``v_theta = (-y * vx + x * vy) / (sqrt(x**2 + y**2) + 1e-9)``.
    2. Render it with ``helper.slice_render("custom", "f64", ...)``.
    3. Differentiate the rendered image along axis 0 with ``np.gradient``.
       The pixel spacing is ``(extent[3] - extent[2]) / helper.ny``.
       Presumably axis 0 is the ``ey`` direction of the slice; confirm
       against ``StandardPlotHelper``.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        do_normalization: Forwarded to ``helper.slice_render``.
        min_normalization: Forwarded to ``helper.slice_render``.

    Returns:
        The ``StandardPlotHelper`` wrapping the computation.
    """

    def compute_vertical_shear_gradient(helper):
        if _HAS_NUMBA and shamrock.sys.world_rank() == 0:
            print("Using numba for custom getter in VerticalShearGradient")

        def azimuthal_velocity(size, x, y, vx, vy, vz):
            # The 1e-9 offset keeps the division finite at x = y = 0.
            r_safe = np.sqrt(x**2 + y**2) + 1e-9
            return (-y * vx + x * vy) / r_safe

        kernel = njit(azimuthal_velocity) if _HAS_NUMBA else azimuthal_velocity

        def custom_getter(size: int, dic_out: dict) -> np.ndarray:
            positions = dic_out["xyz"]
            velocities = dic_out["vxyz"]
            return kernel(
                size,
                positions[:, 0],
                positions[:, 1],
                velocities[:, 0],
                velocities[:, 1],
                velocities[:, 2],
            )

        v_theta_map = helper.slice_render(
            "custom",
            "f64",
            do_normalization,
            min_normalization,
            custom_getter=custom_getter,
        )

        extent = helper.get_extent()
        pixel_dy = (extent[3] - extent[2]) / helper.ny

        # np.gradient already divides by the supplied spacing.
        return np.gradient(v_theta_map, pixel_dy, axis=0)

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_vertical_shear_gradient,
    )
224
225
def gen_angular_momt_custom_getter(model, velocity_profile):
    """Create a ``custom_getter`` computing a per-particle angular momentum
    transport coefficient.

    For each particle the getter evaluates::

        rho     = pmass * (hfact / hpart) ** 3
        P       = cs**2 * rho                  (approximation, see TODO)
        r       = sqrt(x**2 + y**2)
        v_r     = (x * vx + y * vy) / (r + 1e-9)
        v_theta = (-y * vx + x * vy) / (r + 1e-9)
        alpha   = rho * v_r * (v_theta - velocity_profile(r)) / P

    Args:
        model: Object providing ``get_particle_mass()`` and ``get_hfact()``.
            Both are read once, when the getter is created.
        velocity_profile: Callable of the radius ``r``. When numba is
            available, it is compiled with ``njit`` and called on the whole
            array of radii. Otherwise, it is wrapped in ``np.vectorize`` with
            a float64 output.

    Returns:
        ``custom_getter(size, dic_out)``. It reads the ``"xyz"``, ``"vxyz"``,
        ``"hpart"`` and ``"soundspeed"`` entries of ``dic_out`` and returns
        ``alpha`` for each particle.

    Raises:
        ValueError: If ``velocity_profile`` is None. None is the default in
            the plot functions that call this helper, but a callable is
            required.
    """
    if velocity_profile is None:
        raise ValueError(
            "gen_angular_momt_custom_getter: velocity_profile must be a callable "
            "of the cylindrical radius r, got None"
        )

    pmass = model.get_particle_mass()
    hfact = model.get_hfact()

    if _HAS_NUMBA:
        if shamrock.sys.world_rank() == 0:
            print(
                "Using numba for velocity profile in SliceAngularMomentumTransportCoefficientPlot"
            )
        vel_profile_jit = njit(velocity_profile)
    else:
        # Without otypes, np.vectorize has two problems:
        # - it raises ValueError on size-0 inputs (e.g. a render handed no
        #   particles);
        # - it infers the output dtype from the first element, so an int
        #   first value would truncate every later value to int.
        # Pinning float64 avoids both.
        vel_profile_jit = np.vectorize(velocity_profile, otypes=[np.float64])

    def internal(x, y, vx, vy, hpart, cs):
        rho = pmass * (hfact / hpart) ** 3
        P = cs**2 * rho  # TODO: use true pressure

        r = np.sqrt(x**2 + y**2)
        # Offset keeps the divisions finite for particles at x = y = 0.
        r_safe = r + 1e-9
        v_r = (x * vx + y * vy) / r_safe
        v_theta = (-y * vx + x * vy) / r_safe

        delta_vtheta = v_theta - vel_profile_jit(r)
        return rho * v_r * delta_vtheta / P

    if _HAS_NUMBA:
        if shamrock.sys.world_rank() == 0:
            print("Using numba for custom getter in SliceAngularMomentumTransportCoefficientPlot")
        internal = njit(internal)

    def custom_getter(size: int, dic_out: dict) -> np.ndarray:
        xyz = dic_out["xyz"]
        vxyz = dic_out["vxyz"]
        return internal(
            xyz[:, 0],
            xyz[:, 1],
            vxyz[:, 0],
            vxyz[:, 1],
            dic_out["hpart"],
            dic_out["soundspeed"],
        )

    return custom_getter
280
281
def SliceAngularMomentumTransportCoefficientPlot(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    do_normalization=True,
    min_normalization=1e-9,
    velocity_profile=None,
):
    """Build a plot helper rendering a slice of the per-particle angular
    momentum transport coefficient.

    The per-particle quantity comes from
    ``gen_angular_momt_custom_getter(model, velocity_profile)``. It is
    rendered with ``helper.slice_render("custom", "f64", ...)``.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        do_normalization: Forwarded to ``helper.slice_render``.
        min_normalization: Forwarded to ``helper.slice_render``.
        velocity_profile: Reference v_theta as a function of cylindrical
            radius. A callable is expected; None is not handled by the getter
            despite being the default here.

    Returns:
        The ``StandardPlotHelper`` wrapping the computation.
    """

    def compute_angular_momentum_transport_coefficient(helper):
        # The getter is built lazily, only when the helper asks for the data.
        getter = gen_angular_momt_custom_getter(model, velocity_profile)
        return helper.slice_render(
            "custom",
            "f64",
            do_normalization,
            min_normalization,
            custom_getter=getter,
        )

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_angular_momentum_transport_coefficient,
    )
321
322
def ColumnAverageAngularMomentumTransportCoefficientPlot(
    model,
    ext_r,
    nx,
    ny,
    ex,
    ey,
    center,
    analysis_folder,
    analysis_prefix,
    min_normalization=1e-9,
    velocity_profile=None,
):
    """Build a plot helper rendering a column average of the per-particle
    angular momentum transport coefficient.

    The per-particle quantity comes from
    ``gen_angular_momt_custom_getter(model, velocity_profile)``. It is
    rendered with ``helper.column_average_render("custom", "f64", ...)``.

    Args:
        model, ext_r, nx, ny, ex, ey, center, analysis_folder, analysis_prefix:
            Forwarded positionally to ``StandardPlotHelper``.
        min_normalization: Forwarded to ``helper.column_average_render``.
        velocity_profile: Reference v_theta as a function of cylindrical
            radius. A callable is expected; None is not handled by the getter
            despite being the default here.

    Returns:
        The ``StandardPlotHelper`` wrapping the computation.
    """

    def compute_angular_momentum_transport_coefficient(helper):
        # The getter is built lazily, only when the helper asks for the data.
        getter = gen_angular_momt_custom_getter(model, velocity_profile)
        return helper.column_average_render(
            "custom",
            "f64",
            min_normalization,
            custom_getter=getter,
        )

    return StandardPlotHelper(
        model,
        ext_r,
        nx,
        ny,
        ex,
        ey,
        center,
        analysis_folder,
        analysis_prefix,
        compute_function=compute_angular_momentum_transport_coefficient,
    )