Note: go to the end of this page to download the full example code.
Comparing a Sedov blast with 1 patch with Phantom
Restart a Sedov blast simulation from a Phantom dump using 1 patch, run it in Shamrock and compare the results with the original Phantom dump. This test is used to check that the Shamrock solver is able to reproduce the same results as Phantom.
11 import numpy as np
12
13 import shamrock
14
# Initialize the Shamrock runtime once per process before any solver calls.
# NOTE(review): "0:0" is passed straight to shamrock.sys.init — presumably a
# device/rank selector string; confirm against the Shamrock documentation.
if not shamrock.sys.is_initialized():
    shamrock.change_loglevel(1)
    shamrock.sys.init("0:0")


# Particle count of the reference Phantom dump. The scheduler split criterion
# is set to twice Npart so the whole run stays in a single patch (this example
# is the "1 patch" comparison, see the header above).
Npart = 174000
split = int(Npart) * 2
23
def load_dataset(filename):
    """Load a Phantom dump into a throwaway Shamrock model and return its data.

    Builds a temporary context/model pair configured from the dump, collects
    the particle data, then tears the model down so only the raw data survives.
    """
    print("Loading", filename)

    phantom_dump = shamrock.load_phantom_dump(filename)
    phantom_dump.print_state()

    context = shamrock.Context()
    context.pdata_layout_new()
    sph_model = shamrock.get_Model_SPH(context=context, vector_type="f64_3", sph_kernel="M4")

    # Configure the solver exactly as the Phantom run was configured.
    solver_cfg = sph_model.gen_config_from_phantom_dump(phantom_dump)
    sph_model.set_solver_config(solver_cfg)
    sph_model.get_current_config().print_status()

    sph_model.init_scheduler(split, 1)
    sph_model.init_from_phantom_dump(phantom_dump)

    collected = context.collect_data()

    # Drop the model/context so successive loads don't accumulate state.
    del sph_model
    del context

    return collected
50
51
def L2diff_relat(arr1, pos1, arr2, pos2):
    """RMS difference between two fields after nearest-neighbour matching.

    Each sample of ``arr1`` (at positions ``pos1``) is compared against the
    ``arr2`` sample whose position in ``pos2`` is closest, so the two datasets
    do not need to share a particle ordering.
    """
    from scipy.spatial import cKDTree

    field_a = np.asarray(arr1)
    field_b = np.asarray(arr2)
    points_a = np.asarray(pos1)
    points_b = np.asarray(pos2)

    # Match every point of dataset 1 to its nearest neighbour in dataset 2.
    _, nearest = cKDTree(points_b).query(points_a, k=1)

    delta = field_a - field_b[nearest]
    return np.sqrt(np.mean(np.square(delta)))
66
67
def compare_datasets(istep, dataset1, dataset2):
    """Plot Phantom vs Shamrock radial profiles and check L2 differences.

    Parameters
    ----------
    istep : int
        Step index; selects the row of reference L2 values and tolerances.
    dataset1 : dict
        Reference (Phantom) dataset with keys "r", "rho", "u", "vr",
        "alpha" and "xyz" (as produced by get_testing_sets).
    dataset2 : dict
        Shamrock dataset with the same keys.

    Side effects
    ------------
    Saves ``sedov_blast_phantom_comp_<istep>.png`` and terminates the process
    (SystemExit) when any tolerance is exceeded. Only rank 0 does anything;
    other ranks return immediately.
    """
    import sys

    if shamrock.sys.world_rank() > 0:
        return

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 6), dpi=125)

    smarker = 1
    print(len(dataset1["r"]), len(dataset1["rho"]))

    # One panel per field: (dataset key, axes position, y label).
    panels = [
        ("rho", (0, 0), r"$\rho$"),
        ("u", (0, 1), r"$u$"),
        ("vr", (1, 0), r"$v_r$"),
        ("alpha", (1, 1), r"$\alpha$"),
    ]
    for key, (row, col), ylabel in panels:
        ax = axs[row, col]
        # Legend entries only once, on the first panel.
        first = (row, col) == (0, 0)
        # Phantom reference as red crosses, Shamrock result as black dots.
        ax.scatter(
            dataset1["r"],
            dataset1[key],
            s=smarker * 5,
            marker="x",
            c="red",
            rasterized=True,
            label="phantom" if first else None,
        )
        ax.scatter(
            dataset2["r"],
            dataset2[key],
            s=smarker,
            c="black",
            rasterized=True,
            label="shamrock" if first else None,
        )
        ax.set_ylabel(ylabel)
        ax.set_xlabel("$r$")
        ax.set_xlim(0, 0.5)

    axs[0, 0].legend()

    plt.tight_layout()

    # L2 (RMS) difference per field, with nearest-neighbour matching on xyz.
    metrics = [
        ("L2r", "r"),
        ("L2rho", "rho"),
        ("L2u", "u"),
        ("L2vr", "vr"),
        ("L2alpha", "alpha"),
    ]
    measured = [
        L2diff_relat(dataset1[key], dataset1["xyz"], dataset2[key], dataset2["xyz"])
        for _, key in metrics
    ]

    for (name, _), value in zip(metrics, measured):
        print(name, value)

    # Reference values per step, ordered [L2r, L2rho, L2u, L2vr, L2alpha].
    expected_L2 = {
        0: [0.0, 9.009242852618063e-08, 0.0, 0.0, 0.0],
        1: [
            1.849032833754011e-15,
            1.1219057799155968e-07,
            2.999040994476978e-05,
            2.77911044692338e-07,
            1.1107580842674083e-06,
        ],
        10: [
            2.362796968279928e-10,
            1.0893822456258663e-07,
            0.0004010174902848735,
            5.000025464452176e-06,
            0.02366432648382834,
        ],
        100: [
            1.5585397967807125e-08,
            6.155202709399902e-07,
            0.0003118113752459928,
            2.9173459165073988e-05,
            6.432345363293235e-05,
        ],
        1000: [0, 0, 0, 0, 0],
    }

    # Allowed absolute deviation from the reference values, same ordering.
    tols = {
        0: [0, 1e-16, 0, 0, 0],
        1: [1e-20, 1e-16, 1e-16, 1e-18, 1e-20],
        10: [1e-20, 1e-16, 1e-15, 1e-17, 1e-17],
        100: [1e-19, 1e-17, 1e-15, 1e-17, 1e-18],
        1000: [0, 0, 0, 0, 0],
    }

    expected = expected_L2[istep]
    tol = tols[istep]
    error = any(abs(m - e) > t for m, e, t in zip(measured, expected, tol))

    plt.savefig("sedov_blast_phantom_comp_" + str(istep) + ".png")

    if error:
        # sys.exit instead of the site-provided exit() builtin: exit() is only
        # injected by the site module and is not guaranteed to exist (e.g. with
        # python -S); sys.exit raises the same SystemExit with the message.
        sys.exit(
            f"Tolerances are not respected, got \n istep={istep}\n"
            + " got: [" + ", ".join(str(float(m)) for m in measured) + "] \n"
            + " expected : [" + ", ".join(str(e) for e in expected) + "]\n"
            + " delta : [" + ", ".join(str(m - e) for m, e in zip(measured, expected)) + "]\n"
            + " tolerance : [" + ", ".join(str(t) for t in tol) + "]"
        )
192
193
# Load the reference Phantom dumps (steps 0, 1, 10, 100 and 1000) that the
# Shamrock run below is compared against.
step0000 = load_dataset("reference-files/sedov_blast_phantom/blast_00000")
step0001 = load_dataset("reference-files/sedov_blast_phantom/blast_00001")
step0010 = load_dataset("reference-files/sedov_blast_phantom/blast_00010")
step0100 = load_dataset("reference-files/sedov_blast_phantom/blast_00100")
step1000 = load_dataset("reference-files/sedov_blast_phantom/blast_01000")

print(step0000)


# Restart the Shamrock run from the very first Phantom dump.
filename_start = "reference-files/sedov_blast_phantom/blast_00000"

dump = shamrock.load_phantom_dump(filename_start)
dump.print_state()


ctx = shamrock.Context()
ctx.pdata_layout_new()
model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

cfg = model.gen_config_from_phantom_dump(dump)
cfg.set_boundary_free()  # try to force some h iterations
# Set the solver config to be the one stored in cfg
model.set_solver_config(cfg)
# Print the solver config
model.get_current_config().print_status()

model.init_scheduler(split, 1)

model.init_from_phantom_dump(dump)

# Particle mass, used by hpart_to_rho to convert smoothing lengths to density.
pmass = model.get_particle_mass()
225
226
def hpart_to_rho(hpart_array):
    """Convert smoothing lengths to densities: rho = pmass * (hfact / h)**3."""
    ratio = model.get_hfact() / hpart_array
    return pmass * ratio**3
229
230
def get_testing_sets(dataset):
    """Build the comparison dictionary from a raw collected dataset.

    Returns keys "r", "rho", "u", "vr", "alpha" and "xyz", sorted by radius
    and truncated to the innermost 50000 particles. Non-root ranks return an
    empty dict.
    """
    if shamrock.sys.world_rank() > 0:
        return {}

    print("making test dataset, Npart={}".format(len(dataset["xyz"])))

    xyz = dataset["xyz"]
    vxyz = dataset["vxyz"]

    ret = {
        "r": np.sqrt(xyz[:, 0] ** 2 + xyz[:, 1] ** 2 + xyz[:, 2] ** 2),
        "rho": hpart_to_rho(dataset["hpart"]),
        "u": dataset["uint"],
        "vr": np.sqrt(vxyz[:, 0] ** 2 + vxyz[:, 1] ** 2 + vxyz[:, 2] ** 2),
        "alpha": dataset["alpha_AV"],
        "xyz": xyz,
    }

    # Even though we have neigh matching to compare the datasets
    # We still need the cutoff, hence the sorting + cutoff
    order = np.argsort(ret["r"])
    cutoff = 50000
    for key in ret:
        ret[key] = ret[key][order][:cutoff]

    return ret
271
272
# Prime the solver with a zero-length step before starting the comparison loop.
model.evolve_once_override_time(0, 0)

dt = 1e-5
t = 0
for i in range(101):

    # At selected steps, compare the live Shamrock state against the matching
    # Phantom dump (comparison happens BEFORE the step that follows it).
    if i == 0:
        compare_datasets(i, get_testing_sets(step0000), get_testing_sets(ctx.collect_data()))
    if i == 1:
        compare_datasets(i, get_testing_sets(step0001), get_testing_sets(ctx.collect_data()))
    if i == 10:
        compare_datasets(i, get_testing_sets(step0010), get_testing_sets(ctx.collect_data()))
    if i == 100:
        compare_datasets(i, get_testing_sets(step0100), get_testing_sets(ctx.collect_data()))
    # NOTE(review): i never reaches 1000 with range(101), so this branch (and
    # the step1000 dataset) is currently dead — confirm whether the loop bound
    # should be larger.
    if i == 1000:
        compare_datasets(i, get_testing_sets(step1000), get_testing_sets(ctx.collect_data()))

    model.evolve_once_override_time(0, dt)
    t += dt
292
293
294 # plt.show()
Estimated memory usage: 0 MB