Note
Go to the end to download the full example code.
Comparing a Sedov blast with 8 patches against Phantom
Restart a Sedov blast simulation from a Phantom dump using 8 patches, evolve it in Shamrock, and compare the results with the Phantom reference dumps taken at the same steps. This test checks that the Shamrock solver reproduces Phantom's results even when subdomain decomposition is enabled.
11 import numpy as np
12
13 import shamrock
14
# Bring up the Shamrock runtime unless something already did it
# (log level 1, init argument "0:0").
if not shamrock.sys.is_initialized():
    shamrock.change_loglevel(1)
    shamrock.sys.init("0:0")


# Particle count used to derive the scheduler threshold below
# (presumably the particle count of the reference dumps -- not checked here).
Npart = 174000
# First argument of init_scheduler: half of Npart (integer division).
# NOTE(review): presumably the patch split threshold, chosen so the root
# patch splits once into the 8 patches of the title -- confirm.
split = Npart // 2
22
23
def load_dataset(filename):
    """Read a Phantom dump through a temporary Shamrock SPH model.

    A fresh context and M4-kernel SPH model are created. The solver
    configuration generated from the dump is applied, the scheduler is
    initialised with the module-level ``split``, and the particles are loaded
    from the dump. The data returned by ``ctx.collect_data()`` is handed back
    after the model and context are deleted (model first, then context).

    Parameters
    ----------
    filename : str
        Path of the Phantom dump to load.

    Returns
    -------
    The object produced by ``collect_data()`` on the temporary context.
    """
    print("Loading", filename)

    phantom_dump = shamrock.load_phantom_dump(filename)
    phantom_dump.print_state()

    context = shamrock.Context()
    context.pdata_layout_new()
    sph_model = shamrock.get_Model_SPH(context=context, vector_type="f64_3", sph_kernel="M4")

    # Use the solver configuration stored in the dump, then echo it.
    solver_cfg = sph_model.gen_config_from_phantom_dump(phantom_dump)
    sph_model.set_solver_config(solver_cfg)
    sph_model.get_current_config().print_status()

    sph_model.init_scheduler(split, 1)

    sph_model.init_from_phantom_dump(phantom_dump)
    collected = context.collect_data()

    # Release the temporary model, then its context.
    del sph_model
    del context

    return collected
50
51
def L2diff_relat(arr1, pos1, arr2, pos2):
    """Return the RMS difference of two particle fields after position matching.

    Every particle of the first set is paired with its nearest neighbour (by
    position) in the second set using a k-d tree. The root mean square of
    ``arr1 - arr2[match]`` is then returned.

    Note: despite the ``relat`` in the name, this is an *absolute* RMS
    difference; nothing is divided by the magnitude of the data.

    Parameters
    ----------
    arr1, arr2 : array_like
        Per-particle values of the two datasets.
    pos1, pos2 : array_like
        Particle positions matching ``arr1`` / ``arr2`` row by row.

    Returns
    -------
    float
        ``sqrt(mean((arr1 - arr2[nearest(pos1)]) ** 2))``.
    """
    from scipy.spatial import cKDTree

    pos1 = np.asarray(pos1)
    pos2 = np.asarray(pos2)
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    # Only the neighbour indices are needed; the distances are discarded.
    _, nearest = cKDTree(pos2).query(pos1, k=1)
    matched_arr2 = arr2[nearest]
    return np.sqrt(np.mean((arr1 - matched_arr2) ** 2))
66
67
def compare_datasets(istep, dataset1, dataset2):
    """Plot and check the Shamrock run against a Phantom reference at one step.

    Only world rank 0 does any work; other ranks return immediately.

    ``dataset1`` (plotted as "phantom") and ``dataset2`` (plotted as
    "shamrock") are dicts holding the arrays "r", "rho", "u", "vr", "alpha"
    and "xyz".

    A 2x2 figure of rho, u, vr and alpha against r is saved as
    ``sedov_blast_phantom_comp_<istep>.png``. For each of r, rho, u, vr and
    alpha, the position-matched RMS difference (``L2diff_relat``) is printed
    and compared with the expected value and tolerance stored for ``istep``.

    Raises
    ------
    KeyError
        If ``istep`` has no entry in the expected/tolerance tables.
    SystemExit
        If any tolerance is not respected (the figure is saved first).
    """
    if shamrock.sys.world_rank() > 0:
        return

    import sys

    import matplotlib.pyplot as plt

    # Field order shared by the printed metrics and the tables below.
    fields = ["r", "rho", "u", "vr", "alpha"]

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 6), dpi=125)

    smarker = 1
    print(len(dataset1["r"]), len(dataset1["rho"]))

    # (axis, plotted field, y label); only the rho panel gets legend labels.
    panels = [
        (axs[0, 0], "rho", r"$\rho$"),
        (axs[0, 1], "u", r"$u$"),
        (axs[1, 0], "vr", r"$v_r$"),
        (axs[1, 1], "alpha", r"$\alpha$"),
    ]
    for ax, field, ylabel in panels:
        ref_label = {"label": "phantom"} if field == "rho" else {}
        run_label = {"label": "shamrock"} if field == "rho" else {}
        ax.scatter(
            dataset1["r"],
            dataset1[field],
            s=smarker * 5,
            marker="x",
            c="red",
            rasterized=True,
            **ref_label,
        )
        ax.scatter(
            dataset2["r"], dataset2[field], s=smarker, c="black", rasterized=True, **run_label
        )
        ax.set_ylabel(ylabel)
        ax.set_xlabel("$r$")
        ax.set_xlim(0, 0.5)

    axs[0, 0].legend()

    plt.tight_layout()

    l2 = [
        L2diff_relat(dataset1[field], dataset1["xyz"], dataset2[field], dataset2["xyz"])
        for field in fields
    ]
    for field, value in zip(fields, l2):
        print("L2" + field, value)

    expected_L2 = {
        0: [0, 9.00924285345295e-08, 0, 0, 0],
        1: [
            1.849032833754011e-15,
            1.1219057799666405e-07,
            2.999040994475206e-05,
            2.779110446924334e-07,
            1.110758084267404e-06,
        ],
        10: [2.36279697e-10, 1.08938225e-07, 4.01017490e-04, 5.00002547e-06, 2.36643265e-02],
        100: [1.55853980e-08, 6.15520271e-07, 3.11811375e-04, 2.91734592e-05, 6.43234536e-05],
        1000: [0, 0, 0, 0, 0],
    }

    tols = {
        0: [0, 1e-16, 0, 0, 0],
        1: [0, 1e-16, 1e-16, 1e-18, 1e-19],
        10: [1e-18, 1e-15, 1e-12, 1e-14, 1e-10],
        100: [1e-16, 1e-16, 1e-12, 1e-13, 1e-13],
        1000: [0, 0, 0, 0, 0],
    }

    expected = expected_L2[istep]
    tol = tols[istep]
    deltas = [value - ref for value, ref in zip(l2, expected)]

    # Written as "not (<=)" so that a NaN metric (empty or diverged data)
    # counts as a failure; a plain "> tol" test is False for NaN and would
    # silently pass.
    error = any(not (abs(delta) <= t) for delta, t in zip(deltas, tol))

    plt.savefig("sedov_blast_phantom_comp_" + str(istep) + ".png")

    def _fmt(values):
        return ", ".join(str(v) for v in values)

    if error:
        # sys.exit rather than the site-provided exit(), which is meant for the
        # interactive shell only.
        sys.exit(
            f"Tolerances are not respected, got \n istep={istep}\n"
            + f" got: [{_fmt(float(v) for v in l2)}] \n"
            + f" expected : [{_fmt(expected)}]\n"
            + f" delta : [{_fmt(deltas)}]\n"
            + f" tolerance : [{_fmt(tol)}]"
        )
180
181
# Phantom reference datasets, each loaded through load_dataset (one temporary
# model per dump), in ascending dump order.
step0000, step0001, step0010, step0100, step1000 = (
    load_dataset("reference-files/sedov_blast_phantom/blast_00000"),
    load_dataset("reference-files/sedov_blast_phantom/blast_00001"),
    load_dataset("reference-files/sedov_blast_phantom/blast_00010"),
    load_dataset("reference-files/sedov_blast_phantom/blast_00100"),
    load_dataset("reference-files/sedov_blast_phantom/blast_01000"),
)

print(step0000)
189
190
# Live Shamrock run that is evolved and compared below; it restarts from the
# same first Phantom dump that provides the step-0 reference.
filename_start = "reference-files/sedov_blast_phantom/blast_00000"

dump = shamrock.load_phantom_dump(filename_start)
dump.print_state()


ctx = shamrock.Context()
ctx.pdata_layout_new()
model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

cfg = model.gen_config_from_phantom_dump(dump)
# Unlike load_dataset, switch to free boundaries on top of the dump's config;
# the intent (per the original note) is to trigger some h iterations.
cfg.set_boundary_free()  # try to force some h iterations
# Set the solver config to be the one stored in cfg
model.set_solver_config(cfg)
# Print the solver config
model.get_current_config().print_status()

# Same scheduler arguments as used by load_dataset.
model.init_scheduler(split, 1)

model.init_from_phantom_dump(dump)

# Particle mass of the loaded particles; read by hpart_to_rho.
pmass = model.get_particle_mass()
213
214
def hpart_to_rho(hpart_array, particle_mass=None, hfact=None):
    """Convert smoothing lengths to densities: ``rho = m * (hfact / h) ** 3``.

    Parameters
    ----------
    hpart_array : numpy.ndarray or float
        Smoothing lengths ``h``.
    particle_mass : float, optional
        Particle mass ``m``. Defaults to the module-level ``pmass`` (read
        from the live model).
    hfact : float, optional
        Smoothing-length factor. Defaults to ``model.get_hfact()``.

    Returns
    -------
    Densities, with the same shape as ``hpart_array`` for NumPy input.
    """
    # Defaults keep the original behaviour of reading the live run's globals.
    if particle_mass is None:
        particle_mass = pmass
    if hfact is None:
        hfact = model.get_hfact()
    return particle_mass * (hfact / hpart_array) ** 3
217
218
def get_testing_sets(dataset, cutoff=50000):
    """Build the comparison arrays consumed by compare_datasets.

    Only world rank 0 builds anything; other ranks get an empty dict.

    Parameters
    ----------
    dataset : dict
        Particle data with at least the keys "xyz", "vxyz" (indexed as
        ``[:, 0..2]``), "hpart", "uint" and "alpha_AV".
    cutoff : int, optional
        Number of particles kept, those closest to the origin first
        (default 50000, the previously hard-coded value).

    Returns
    -------
    dict
        Arrays "r", "rho", "u", "vr", "alpha" and "xyz", sorted by increasing
        ``r`` and truncated to the first ``cutoff`` entries.
    """
    if shamrock.sys.world_rank() > 0:
        return {}

    xyz = dataset["xyz"]
    vxyz = dataset["vxyz"]

    print("making test dataset, Npart={}".format(len(xyz)))

    ret = {
        "r": np.sqrt(xyz[:, 0] ** 2 + xyz[:, 1] ** 2 + xyz[:, 2] ** 2),
        "rho": hpart_to_rho(dataset["hpart"]),
        "u": dataset["uint"],
        # Note: this is the speed |v|, not the projection of v onto r_hat.
        "vr": np.sqrt(vxyz[:, 0] ** 2 + vxyz[:, 1] ** 2 + vxyz[:, 2] ** 2),
        "alpha": dataset["alpha_AV"],
        "xyz": xyz,
    }

    # Even though the comparison uses neighbour matching, only the innermost
    # particles are kept: take the first `cutoff` indices of the radius sort
    # and apply them to every array (same result as sorting, then slicing).
    keep = np.argsort(ret["r"])[:cutoff]
    return {key: values[keep] for key, values in ret.items()}
259
260
# Initial call with a zero time step before the comparison loop.
model.evolve_once_override_time(0, 0)

# Phantom reference dataset for each step at which the live run is checked.
# Step 1000 is never reached by range(101) below; it is listed for symmetry.
reference_steps = {
    0: step0000,
    1: step0001,
    10: step0010,
    100: step0100,
    1000: step1000,
}

dt = 1e-5
t = 0
for i in range(101):
    # Membership test (not .get) so that ctx.collect_data() is called at
    # exactly the same steps, on every rank, as before.
    if i in reference_steps:
        reference_set = get_testing_sets(reference_steps[i])
        live_set = get_testing_sets(ctx.collect_data())
        compare_datasets(i, reference_set, live_set)

    # NOTE(review): `t` is accumulated but the first argument passed here is
    # always 0 -- confirm this is intended.
    model.evolve_once_override_time(0, dt)
    t += dt
Estimated memory usage: 0 MB