Note
Go to the end to download the full example code.
Testing Extreme Blast Wave with GSPH
CI test for the extreme blast wave problem from Inutsuka 2002 (Section 4.3). This is a severe test, with a Mach number of order $10^5$.
- Initial conditions (from Inutsuka 2002):
$\rho_L = 1,\ \rho_R = 1;\quad P_L = 3000,\ P_R = 10^{-7};\quad v_L = 0,\ v_R = 0$
14 import numpy as np
15
16 import shamrock
17
# ---------------------------------------------------------------------------
# Initial conditions (Inutsuka 2002, Sec. 4.3): equal densities on both sides
# and a pressure ratio P_L / P_R = 3e10 across x = 0.
# ---------------------------------------------------------------------------
gamma = 1.4
rho_L, rho_R = 1.0, 1.0
P_L, P_R = 3000.0, 1e-7

# Specific internal energy from the ideal-gas relation P = (gamma - 1) * rho * u.
u_L = P_L / ((gamma - 1) * rho_L)
u_R = P_R / ((gamma - 1) * rho_R)

# Lattice resolution handed to get_box_dim_fcc_3d (along x, presumably -- confirm).
resol = 100

ctx = shamrock.Context()
ctx.pdata_layout_new()

model = shamrock.get_Model_GSPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

# Solver configuration: HLLC Riemann solver, piecewise-constant reconstruction,
# periodic boundaries and an adiabatic EOS using the gamma defined above.
cfg = model.gen_default_config()
cfg.set_riemann_hllc()
cfg.set_reconstruct_piecewise_constant()
# NOTE(review): with periodic boundaries the wrap-around at x = -xs / +xs is a
# second high/low-pressure interface. The rarefaction launched from it (head
# speed ~ sqrt(gamma * P_L / rho_L) ~ 65) may reach the [-0.5, 0.5] analysis
# window by t_target -- confirm this is acceptable for the reference comparison.
cfg.set_boundary_periodic()
cfg.set_eos_adiabatic(gamma)
cfg.print_status()
model.set_solver_config(cfg)
model.init_scheduler(int(1e8), 1)

# Probe the lattice with unit spacing, then rescale the spacing so the box
# extent along x becomes ~1 (assuming the returned size scales linearly with
# the spacing -- confirm).
xs, ys, zs = model.get_box_dim_fcc_3d(1, resol, 24, 24)
dr = 1 / xs
xs, ys, zs = model.get_box_dim_fcc_3d(dr, resol, 24, 24)

# Corners of the two half-boxes; the left/right interface sits at x = 0.
left_lo, left_hi = (-xs, -ys / 2, -zs / 2), (0, ys / 2, zs / 2)
right_lo, right_hi = (0, -ys / 2, -zs / 2), (xs, ys / 2, zs / 2)

model.resize_simulation_box(left_lo, right_hi)

# Both halves use the same spacing dr (consistent with rho_L == rho_R); only the
# internal energy distinguishes the two states.
model.add_cube_hcp_3d(dr, left_lo, left_hi)
model.add_cube_hcp_3d(dr, right_lo, right_hi)
model.set_field_in_box("uint", "f64", u_L, left_lo, left_hi)
model.set_field_in_box("uint", "f64", u_R, right_lo, right_hi)

# Each half-box spans xs along x, so both have volume xs * ys * zs.
half_volume = xs * ys * zs
total_mass = (rho_R * half_volume) + (rho_L * half_volume)
pmass = model.total_mass_to_part_mass(total_mass)
model.set_particle_mass(pmass)
hfact = model.get_hfact()

model.set_cfl_cour(0.3)
model.set_cfl_force(0.3)

t_target = 0.015
print(f"GSPH Extreme Blast Wave Test (M4, HLLC, t={t_target})")
model.evolve_until(t_target)

# Analytic reference: state 1 is the left (high-pressure) side, state 5 the right.
sod = shamrock.phys.SodTube(gamma=gamma, rho_1=rho_L, P_1=P_L, rho_5=rho_R, P_5=P_R)

# Particle fields consumed by the error analysis below.
data = ctx.collect_data()
64
65
def compute_L2_errors(
    data,
    sod,
    t,
    x_min,
    x_max,
    *,
    particle_mass=None,
    h_factor=None,
    adiabatic_index=None,
):
    """Compute L2 errors using ctx.collect_data() (no pyvista dependency).

    Compares the particle state in ``[x_min, x_max]`` against an analytic
    solution and returns normalised RMS ("L2") errors.

    Parameters
    ----------
    data : mapping
        Particle fields with keys ``"xyz"`` and ``"vxyz"`` (N x 3) and
        ``"hpart"`` and ``"uint"`` (length N).
    sod : object
        Analytic solution exposing ``get_value(t, x) -> (rho, vx, P)``.
    t : float
        Time at which the analytic solution is evaluated.
    x_min, x_max : float
        Inclusive x-interval of the particles taken into account.
    particle_mass, h_factor, adiabatic_index : float, optional
        Particle mass, smoothing-length factor and adiabatic index used to
        rebuild density and pressure. When omitted, the module-level
        ``pmass``, ``hfact`` and ``gamma`` are used.

    Returns
    -------
    tuple
        ``(err_rho, (err_vx, err_vy, err_vz), err_P)``. ``err_rho``, ``err_vx``
        and ``err_P`` are RMS differences divided by the mean analytic
        magnitude, floored at 1e-10, 0.1 and 1e-10 respectively.
        ``err_vy`` and ``err_vz`` are plain RMS values, compared against zero.

    Raises
    ------
    RuntimeError
        If no particle lies in ``[x_min, x_max]`` (including empty ``data``).
    """
    # Fall back to the module-level run parameters when not given explicitly.
    if particle_mass is None:
        particle_mass = pmass
    if h_factor is None:
        h_factor = hfact
    if adiabatic_index is None:
        adiabatic_index = gamma

    # reshape(-1, 3) keeps an empty dataset 2-D so it reaches the
    # "no particles" check below instead of failing on column indexing.
    points = np.asarray(data["xyz"]).reshape(-1, 3)
    velocities = np.asarray(data["vxyz"]).reshape(-1, 3)
    hpart = np.asarray(data["hpart"])
    uint = np.asarray(data["uint"])

    # SPH density from the smoothing length, then the adiabatic-EOS pressure.
    rho_sim = particle_mass * (h_factor / hpart) ** 3
    P_sim = (adiabatic_index - 1) * rho_sim * uint

    x = points[:, 0]
    mask = (x >= x_min) & (x <= x_max)
    if not mask.any():
        raise RuntimeError("No particles in analysis region")

    x_f = x[mask]
    rho_f = rho_sim[mask]
    P_f = P_sim[mask]
    vel_f = velocities[mask]
    vx_f, vy_f, vz_f = vel_f[:, 0], vel_f[:, 1], vel_f[:, 2]

    # get_value is called with one position at a time; stack the (rho, vx, P)
    # triples and split them into contiguous per-quantity arrays.
    analytic = np.array([sod.get_value(t, xi) for xi in x_f], dtype=float)
    rho_ana, vx_ana, P_ana = np.ascontiguousarray(analytic.T)

    # Normalise by the mean analytic magnitude; the floors guard against
    # dividing by a near-zero mean.
    rho_norm = max(np.mean(rho_ana), 1e-10)
    vx_norm = max(np.mean(np.abs(vx_ana)), 0.1)
    P_norm = max(np.mean(P_ana), 1e-10)

    err_rho = np.sqrt(np.mean((rho_f - rho_ana) ** 2)) / rho_norm
    err_vx = np.sqrt(np.mean((vx_f - vx_ana) ** 2)) / vx_norm
    # Transverse velocity components are compared against zero.
    err_vy = np.sqrt(np.mean(vy_f**2))
    err_vz = np.sqrt(np.mean(vz_f**2))
    err_P = np.sqrt(np.mean((P_f - P_ana) ** 2)) / P_norm
    return err_rho, (err_vx, err_vy, err_vz), err_P
104
105
if shamrock.sys.world_rank() == 0:
    rho, v, P = compute_L2_errors(data, sod, t_target, -0.5, 0.5)
    vx, vy, vz = v

    print("current errors :")
    print(f"err_rho = {rho}")
    print(f"err_vx = {vx}")
    print(f"err_vy = {vy}")
    print(f"err_vz = {vz}")
    print(f"err_P = {P}")

    # Expected L2 error values (calibrated from CI run with M4 kernel)
    expect_rho = 10.688658207003348
    expect_vx = 1.0420471749025182
    expect_vy = 0.11766417324542999
    expect_vz = 0.0027436730451881886
    expect_P = 1.6660643954434153

    # Tight relative tolerance: the expected values are calibrated regression
    # references, not physical accuracy bounds.
    tol = 1e-8

    error_checks = {
        "rho": (rho, expect_rho),
        "vx": (vx, expect_vx),
        "vy": (vy, expect_vy),
        "vz": (vz, expect_vz),
        "P": (P, expect_P),
    }

    failures = []
    for name, (value, expected) in error_checks.items():
        # Phrased as "not within tolerance" so that a NaN error (e.g. from a run
        # that blew up) fails the check: every comparison involving NaN is
        # False, so a plain `abs(...) > tol * expected` test would let it pass.
        if not abs(value - expected) <= tol * abs(expected):
            failures.append(f"error on {name} is outside of tolerances:\n")
            failures.append(f" expected error = {expected} +- {tol * expected}\n")
            failures.append(
                f" obtained error = {value} (relative error = {(value - expected) / expected})\n"
            )

    err_log = "".join(failures)
    test_pass = not failures

    if test_pass:
        print("\n" + "=" * 50)
        print("GSPH Extreme Blast Wave Test: PASSED")
        print("=" * 50)
    else:
        # Raise SystemExit directly: the `exit` helper is injected by the site
        # module and is missing when Python runs without it (e.g. `python -S`).
        raise SystemExit("Test did not pass L2 margins : \n" + err_log)
Estimated memory usage: 0 MB