Note
Go to the end to download the full example code.
Comparing Sedov blast with 1 patch with Phantom
Restart a Sedov blast simulation from a Phantom dump using 1 patch, run it in Shamrock and compare the results with the original Phantom dump. This test is used to check that the Shamrock solver is able to reproduce the same results as Phantom.
11 import numpy as np
12
13 import shamrock
14
# Initialize the Shamrock runtime once (log level 1, device selector "0:0")
# unless a previous import/run has already done so.
if not shamrock.sys.is_initialized():
    shamrock.change_loglevel(1)
    shamrock.sys.init("0:0")


# Particle count of the reference run; the scheduler split criterion is set
# above it so the domain stays in a single patch (cf. the example title).
Npart = 174000
split = int(Npart) * 2
22
23
def load_dataset(filename):
    """Load a Phantom dump into a throwaway Shamrock model and return the
    collected particle data.

    The context/model built here exist only long enough to ingest the dump;
    both are deleted before returning so every call starts from scratch.
    """
    print("Loading", filename)

    phantom_dump = shamrock.load_phantom_dump(filename)
    phantom_dump.print_state()

    context = shamrock.Context()
    context.pdata_layout_new()
    sph_model = shamrock.get_Model_SPH(context=context, vector_type="f64_3", sph_kernel="M4")

    # Mirror the solver configuration stored in the Phantom dump, then echo it.
    solver_cfg = sph_model.gen_config_from_phantom_dump(phantom_dump)
    sph_model.set_solver_config(solver_cfg)
    sph_model.get_current_config().print_status()

    sph_model.init_scheduler(split, 1)

    sph_model.init_from_phantom_dump(phantom_dump)
    collected = context.collect_data()

    # Drop the temporary solver objects before handing back the raw data.
    del sph_model
    del context

    return collected
49
50
def L2diff_relat(arr1, pos1, arr2, pos2):
    """Return the RMS difference between two particle fields sampled at
    (possibly differently ordered) positions.

    Each value of *arr1* (living at ``pos1[i]``) is compared against the value
    of *arr2* carried by the nearest neighbour in *pos2*, located with a k-d
    tree, so the two datasets do not need matching particle ordering.
    """
    from scipy.spatial import cKDTree

    ref_vals = np.asarray(arr1)
    ref_pos = np.asarray(pos1)
    other_vals = np.asarray(arr2)
    other_pos = np.asarray(pos2)

    # Match every reference particle to its closest counterpart in pos2.
    _, nearest = cKDTree(other_pos).query(ref_pos, k=1)

    delta = ref_vals - other_vals[nearest]
    return np.sqrt(np.mean(delta**2))
65
66
def compare_datasets(istep, dataset1, dataset2):
    """Plot Phantom (*dataset1*) against Shamrock (*dataset2*) radial profiles
    at step *istep* and check neighbour-matched L2 differences against
    hard-coded expected values.

    Only MPI rank 0 runs the comparison. The figure is saved to
    ``sedov_blast_phantom_comp_<istep>.png``; if any tolerance is exceeded the
    process terminates via ``exit`` with a diagnostic message.
    """
    if shamrock.sys.world_rank() > 0:
        return

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 6), dpi=125)

    smarker = 1
    print(len(dataset1["r"]), len(dataset1["rho"]))
    # Phantom reference: red crosses on all four panels (rho, u, vr, alpha).
    axs[0, 0].scatter(
        dataset1["r"],
        dataset1["rho"],
        s=smarker * 5,
        marker="x",
        c="red",
        rasterized=True,
        label="phantom",
    )
    axs[0, 1].scatter(
        dataset1["r"], dataset1["u"], s=smarker * 5, marker="x", c="red", rasterized=True
    )
    axs[1, 0].scatter(
        dataset1["r"], dataset1["vr"], s=smarker * 5, marker="x", c="red", rasterized=True
    )
    axs[1, 1].scatter(
        dataset1["r"], dataset1["alpha"], s=smarker * 5, marker="x", c="red", rasterized=True
    )

    # Shamrock result: small black dots on the same panels.
    axs[0, 0].scatter(
        dataset2["r"], dataset2["rho"], s=smarker, c="black", rasterized=True, label="shamrock"
    )
    axs[0, 1].scatter(dataset2["r"], dataset2["u"], s=smarker, c="black", rasterized=True)
    axs[1, 0].scatter(dataset2["r"], dataset2["vr"], s=smarker, c="black", rasterized=True)
    axs[1, 1].scatter(dataset2["r"], dataset2["alpha"], s=smarker, c="black", rasterized=True)

    axs[0, 0].set_ylabel(r"$\rho$")
    axs[1, 0].set_ylabel(r"$v_r$")
    axs[0, 1].set_ylabel(r"$u$")
    axs[1, 1].set_ylabel(r"$\alpha$")

    axs[0, 0].set_xlabel("$r$")
    axs[1, 0].set_xlabel("$r$")
    axs[0, 1].set_xlabel("$r$")
    axs[1, 1].set_xlabel("$r$")

    axs[0, 0].set_xlim(0, 0.5)
    axs[1, 0].set_xlim(0, 0.5)
    axs[0, 1].set_xlim(0, 0.5)
    axs[1, 1].set_xlim(0, 0.5)

    axs[0, 0].legend()

    plt.tight_layout()

    # Neighbour-matched L2 differences for each field (see L2diff_relat).
    L2r = L2diff_relat(dataset1["r"], dataset1["xyz"], dataset2["r"], dataset2["xyz"])
    L2rho = L2diff_relat(dataset1["rho"], dataset1["xyz"], dataset2["rho"], dataset2["xyz"])
    L2u = L2diff_relat(dataset1["u"], dataset1["xyz"], dataset2["u"], dataset2["xyz"])
    L2vr = L2diff_relat(dataset1["vr"], dataset1["xyz"], dataset2["vr"], dataset2["xyz"])
    L2alpha = L2diff_relat(dataset1["alpha"], dataset1["xyz"], dataset2["alpha"], dataset2["xyz"])

    print("L2r", L2r)
    print("L2rho", L2rho)
    print("L2u", L2u)
    print("L2vr", L2vr)
    print("L2alpha", L2alpha)

    # Reference values [L2r, L2rho, L2u, L2vr, L2alpha] per compared step.
    # NOTE(review): the istep == 1000 entries look like zero placeholders —
    # the driver loop only runs 101 iterations, so they are never checked.
    expected_L2 = {
        0: [0.0, 9.009242852618063e-08, 0.0, 0.0, 0.0],
        1: [
            1.849032833754011e-15,
            1.1219057799155968e-07,
            2.999040994476978e-05,
            2.77911044692338e-07,
            1.1107580842674083e-06,
        ],
        10: [
            2.362796968279928e-10,
            1.0893822456258663e-07,
            0.0004010174902848735,
            5.000025464452176e-06,
            0.02366432648382834,
        ],
        100: [
            1.5585397967807125e-08,
            6.155202709399902e-07,
            0.0003118113752459928,
            2.9173459165073988e-05,
            6.432345363293235e-05,
        ],
        1000: [0, 0, 0, 0, 0],
    }

    # Absolute tolerances on the deviation from the expected values above.
    tols = {
        0: [0, 1e-16, 0, 0, 0],
        1: [1e-20, 1e-16, 1e-16, 1e-18, 1e-20],
        10: [1e-20, 1e-16, 1e-15, 1e-17, 1e-17],
        100: [1e-19, 1e-17, 1e-15, 1e-17, 1e-18],
        1000: [0, 0, 0, 0, 0],
    }

    error = False
    if abs(L2r - expected_L2[istep][0]) > tols[istep][0]:
        error = True
    if abs(L2rho - expected_L2[istep][1]) > tols[istep][1]:
        error = True
    if abs(L2u - expected_L2[istep][2]) > tols[istep][2]:
        error = True
    if abs(L2vr - expected_L2[istep][3]) > tols[istep][3]:
        error = True
    if abs(L2alpha - expected_L2[istep][4]) > tols[istep][4]:
        error = True

    # Save the figure before possibly aborting, so it is available for debug.
    plt.savefig("sedov_blast_phantom_comp_" + str(istep) + ".png")

    if error:
        exit(
            f"Tolerances are not respected, got \n istep={istep}\n"
            + f" got: [{float(L2r)}, {float(L2rho)}, {float(L2u)}, {float(L2vr)}, {float(L2alpha)}] \n"
            + f" expected : [{expected_L2[istep][0]}, {expected_L2[istep][1]}, {expected_L2[istep][2]}, {expected_L2[istep][3]}, {expected_L2[istep][4]}]\n"
            + f" delta : [{(L2r - expected_L2[istep][0])}, {(L2rho - expected_L2[istep][1])}, {(L2u - expected_L2[istep][2])}, {(L2vr - expected_L2[istep][3])}, {(L2alpha - expected_L2[istep][4])}]\n"
            + f" tolerance : [{tols[istep][0]}, {tols[istep][1]}, {tols[istep][2]}, {tols[istep][3]}, {tols[istep][4]}]"
        )
190
191
# Reference Phantom dumps, loaded through Shamrock once each.
step0000, step0001, step0010, step0100, step1000 = [
    load_dataset(dump_path)
    for dump_path in (
        "reference-files/sedov_blast_phantom/blast_00000",
        "reference-files/sedov_blast_phantom/blast_00001",
        "reference-files/sedov_blast_phantom/blast_00010",
        "reference-files/sedov_blast_phantom/blast_00100",
        "reference-files/sedov_blast_phantom/blast_01000",
    )
]

print(step0000)
199
200
filename_start = "reference-files/sedov_blast_phantom/blast_00000"

# Reload the initial dump; unlike the throwaway models in load_dataset, the
# model built here is the one actually evolved by the driver loop below.
dump = shamrock.load_phantom_dump(filename_start)
dump.print_state()


ctx = shamrock.Context()
ctx.pdata_layout_new()
model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

cfg = model.gen_config_from_phantom_dump(dump)
cfg.set_boundary_free()  # try to force some h iterations
# Set the solver config to be the one stored in cfg
model.set_solver_config(cfg)
# Print the solver config
model.get_current_config().print_status()

model.init_scheduler(split, 1)

model.init_from_phantom_dump(dump)

# Mass of a single SPH particle; used by hpart_to_rho.
pmass = model.get_particle_mass()
223
224
def hpart_to_rho(hpart_array):
    """Convert smoothing lengths to densities via rho = pmass * (hfact / h)**3.

    Reads the module-level ``pmass`` and ``model`` set up above.
    """
    hfact = model.get_hfact()
    return pmass * (hfact / hpart_array) ** 3
227
228
def get_testing_sets(dataset):
    """Derive the comparison fields from a collected particle dataset.

    Builds a dict with keys ``r`` (radius), ``rho`` (density from smoothing
    length), ``u`` (internal energy), ``vr`` (velocity magnitude), ``alpha``
    (AV coefficient) and ``xyz`` (positions). Only MPI rank 0 builds the
    dataset; other ranks get an empty dict. Particles are sorted by radius and
    truncated to the innermost ``cutoff`` ones so both compared datasets cover
    the same region.
    """
    if shamrock.sys.world_rank() > 0:
        return {}

    print("making test dataset, Npart={}".format(len(dataset["xyz"])))

    xyz = dataset["xyz"]
    vxyz = dataset["vxyz"]

    ret = {
        "r": np.sqrt(xyz[:, 0] ** 2 + xyz[:, 1] ** 2 + xyz[:, 2] ** 2),
        "rho": hpart_to_rho(dataset["hpart"]),
        "u": dataset["uint"],
        "vr": np.sqrt(vxyz[:, 0] ** 2 + vxyz[:, 1] ** 2 + vxyz[:, 2] ** 2),
        "alpha": dataset["alpha_AV"],
        "xyz": xyz,
    }

    # Even though we have neigh matching to compare the datasets
    # we still need the cutoff, hence the sorting + cutoff.
    index = np.argsort(ret["r"])
    cutoff = 50000
    # Apply the same sort order and truncation to every field at once instead
    # of repeating it key by key.
    for key in ret:
        ret[key] = ret[key][index][:cutoff]

    return ret
269
270
# Run one zero-length step first so the solver state is consistent before the
# first comparison.
model.evolve_once_override_time(0, 0)

dt = 1e-5
t = 0

# Reference datasets keyed by the iteration at which they were dumped.
# NOTE(review): the loop below only runs i = 0..100, so the 1000 entry is
# currently never compared — kept so a future extension of the loop picks it
# up, matching the original unreachable `if i == 1000` branch.
checkpoints = {
    0: step0000,
    1: step0001,
    10: step0010,
    100: step0100,
    1000: step1000,
}

for i in range(101):
    if i in checkpoints:
        # Compare the Phantom reference against the current Shamrock state.
        compare_datasets(
            i, get_testing_sets(checkpoints[i]), get_testing_sets(ctx.collect_data())
        )

    model.evolve_once_override_time(0, dt)
    t += dt
289
290
291 # plt.show()
Estimated memory usage: 0 MB