Note
Go to the end to download the full example code.
Comparing Sedov blast with 8 patches with Phantom#
Restart a Sedov blast simulation from a Phantom dump using 8 patches, run it in Shamrock, and compare the results with the original Phantom dump. This test checks that the Shamrock solver reproduces Phantom's results even when subdomain decomposition is enabled.
11 import numpy as np
12
13 import shamrock
14
# Bring up the Shamrock runtime exactly once (guarded so re-running the
# script in the same interpreter is safe).
if not shamrock.sys.is_initialized():
    shamrock.change_loglevel(1)
    shamrock.sys.init("0:0")


# Particle count of the reference dumps; splitting the scheduler at half of
# it forces the domain decomposition into more than one patch.
Npart = 174000
split = Npart // 2
23
def load_dataset(filename):
    """Load a Phantom dump through a throwaway Shamrock model and return its data.

    The context/model pair is only used to convert the dump into Shamrock's
    internal particle layout; both are deleted before returning.
    """
    print("Loading", filename)

    phdump = shamrock.load_phantom_dump(filename)
    phdump.print_state()

    ctx = shamrock.Context()
    ctx.pdata_layout_new()
    model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

    # Mirror the solver configuration stored in the Phantom dump, then show it.
    model.set_solver_config(model.gen_config_from_phantom_dump(phdump))
    model.get_current_config().print_status()

    model.init_scheduler(split, 1)
    model.init_from_phantom_dump(phdump)

    collected = ctx.collect_data()

    # Drop the temporary objects so they do not outlive this call.
    del model
    del ctx

    return collected
49
50
def L2diff_relat(arr1, pos1, arr2, pos2):
    """RMS difference between two particle fields after nearest-neighbour matching.

    Each particle of the first dataset is matched to its nearest neighbour
    (Euclidean distance) in the second dataset, which makes the comparison
    robust to differing particle ordering between the two dumps.

    Parameters
    ----------
    arr1, arr2 : array-like
        Per-particle scalar field values of each dataset.
    pos1, pos2 : array-like, shape (N, d)
        Per-particle positions of each dataset.

    Returns
    -------
    float
        ``sqrt(mean((arr1 - arr2[nearest]) ** 2))``.
    """
    from scipy.spatial import cKDTree

    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    pos1 = np.asarray(pos1)
    pos2 = np.asarray(pos2)

    # Match every particle of dataset 1 to its closest particle of dataset 2.
    _, idxs = cKDTree(pos2).query(pos1, k=1)
    return np.sqrt(np.mean((arr1 - arr2[idxs]) ** 2))
65
66
def compare_datasets(istep, dataset1, dataset2):
    """Plot phantom vs shamrock profiles at step ``istep`` and check L2 errors.

    Only MPI rank 0 performs the comparison. A 2x2 figure of rho, u, vr and
    alpha versus radius is saved to ``sedov_blast_phantom_comp_<istep>.png``.
    The L2 differences (nearest-neighbour matched) are compared against
    reference values recorded from a known-good run; the process exits with an
    error message if any deviation exceeds the per-step tolerance.

    Parameters
    ----------
    istep : int
        Step index; must be a key of the ``expected_L2`` / ``tols`` tables.
    dataset1 : dict
        Phantom reference dataset (output of ``get_testing_sets``).
    dataset2 : dict
        Shamrock dataset (output of ``get_testing_sets``).
    """
    if shamrock.sys.world_rank() > 0:
        return

    import sys

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 6), dpi=125)

    smarker = 1
    print(len(dataset1["r"]), len(dataset1["rho"]))
    axs[0, 0].scatter(
        dataset1["r"],
        dataset1["rho"],
        s=smarker * 5,
        marker="x",
        c="red",
        rasterized=True,
        label="phantom",
    )
    axs[0, 1].scatter(
        dataset1["r"], dataset1["u"], s=smarker * 5, marker="x", c="red", rasterized=True
    )
    axs[1, 0].scatter(
        dataset1["r"], dataset1["vr"], s=smarker * 5, marker="x", c="red", rasterized=True
    )
    axs[1, 1].scatter(
        dataset1["r"], dataset1["alpha"], s=smarker * 5, marker="x", c="red", rasterized=True
    )

    axs[0, 0].scatter(
        dataset2["r"], dataset2["rho"], s=smarker, c="black", rasterized=True, label="shamrock"
    )
    axs[0, 1].scatter(dataset2["r"], dataset2["u"], s=smarker, c="black", rasterized=True)
    axs[1, 0].scatter(dataset2["r"], dataset2["vr"], s=smarker, c="black", rasterized=True)
    axs[1, 1].scatter(dataset2["r"], dataset2["alpha"], s=smarker, c="black", rasterized=True)

    axs[0, 0].set_ylabel(r"$\rho$")
    axs[1, 0].set_ylabel(r"$v_r$")
    axs[0, 1].set_ylabel(r"$u$")
    axs[1, 1].set_ylabel(r"$\alpha$")

    # Common x axis on all four panels.
    for ax in axs.flat:
        ax.set_xlabel("$r$")
        ax.set_xlim(0, 0.5)

    axs[0, 0].legend()

    plt.tight_layout()

    # L2 differences between the datasets, matched by particle position.
    L2r = L2diff_relat(dataset1["r"], dataset1["xyz"], dataset2["r"], dataset2["xyz"])
    L2rho = L2diff_relat(dataset1["rho"], dataset1["xyz"], dataset2["rho"], dataset2["xyz"])
    L2u = L2diff_relat(dataset1["u"], dataset1["xyz"], dataset2["u"], dataset2["xyz"])
    L2vr = L2diff_relat(dataset1["vr"], dataset1["xyz"], dataset2["vr"], dataset2["xyz"])
    L2alpha = L2diff_relat(dataset1["alpha"], dataset1["xyz"], dataset2["alpha"], dataset2["xyz"])

    print("L2r", L2r)
    print("L2rho", L2rho)
    print("L2u", L2u)
    print("L2vr", L2vr)
    print("L2alpha", L2alpha)

    # Reference L2 values [r, rho, u, vr, alpha] recorded from a known-good run.
    expected_L2 = {
        0: [0, 9.00924285345295e-08, 0, 0, 0],
        1: [
            1.849032833754011e-15,
            1.1219057799666405e-07,
            2.999040994475206e-05,
            2.779110446924334e-07,
            1.110758084267404e-06,
        ],
        10: [2.36279697e-10, 1.08938225e-07, 4.01017490e-04, 5.00002547e-06, 2.36643265e-02],
        100: [1.55853980e-08, 6.15520271e-07, 3.11811375e-04, 2.91734592e-05, 6.43234536e-05],
        1000: [0, 0, 0, 0, 0],
    }

    # Allowed absolute deviation from the reference values, per step index.
    tols = {
        0: [0, 1e-16, 0, 0, 0],
        1: [0, 1e-16, 1e-16, 1e-18, 1e-19],
        10: [1e-18, 1e-15, 1e-12, 1e-14, 1e-10],
        100: [1e-16, 1e-16, 1e-12, 1e-13, 1e-13],
        1000: [0, 0, 0, 0, 0],
    }

    got = [L2r, L2rho, L2u, L2vr, L2alpha]
    error = any(
        abs(val - ref) > tol for val, ref, tol in zip(got, expected_L2[istep], tols[istep])
    )

    plt.savefig("sedov_blast_phantom_comp_" + str(istep) + ".png")

    if error:
        # sys.exit rather than the site-provided exit() builtin, which is not
        # guaranteed to exist (e.g. under `python -S` or in frozen builds).
        sys.exit(
            f"Tolerances are not respected, got \n istep={istep}\n"
            + f" got: [{float(L2r)}, {float(L2rho)}, {float(L2u)}, {float(L2vr)}, {float(L2alpha)}] \n"
            + f" expected : [{expected_L2[istep][0]}, {expected_L2[istep][1]}, {expected_L2[istep][2]}, {expected_L2[istep][3]}, {expected_L2[istep][4]}]\n"
            + f" delta : [{(L2r - expected_L2[istep][0])}, {(L2rho - expected_L2[istep][1])}, {(L2u - expected_L2[istep][2])}, {(L2vr - expected_L2[istep][3])}, {(L2alpha - expected_L2[istep][4])}]\n"
            + f" tolerance : [{tols[istep][0]}, {tols[istep][1]}, {tols[istep][2]}, {tols[istep][3]}, {tols[istep][4]}]"
        )
178
179
# Reference datasets collected from the Phantom dumps (same load order as the
# comparisons performed below).
step0000, step0001, step0010, step0100, step1000 = (
    load_dataset("reference-files/sedov_blast_phantom/blast_" + tag)
    for tag in ("00000", "00001", "00010", "00100", "01000")
)

print(step0000)


filename_start = "reference-files/sedov_blast_phantom/blast_00000"

dump = shamrock.load_phantom_dump(filename_start)
dump.print_state()


# Build the Shamrock model that will be evolved and compared against Phantom.
ctx = shamrock.Context()
ctx.pdata_layout_new()
model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")

cfg = model.gen_config_from_phantom_dump(dump)
cfg.set_boundary_free()  # try to force some h iterations
# Install the solver config derived from the dump, then show it.
model.set_solver_config(cfg)
model.get_current_config().print_status()

model.init_scheduler(split, 1)

model.init_from_phantom_dump(dump)

pmass = model.get_particle_mass()
212
def hpart_to_rho(hpart_array):
    """Convert smoothing lengths to densities: rho = pmass * (hfact / h)**3."""
    ratio = model.get_hfact() / hpart_array
    return pmass * ratio**3
215
216
def get_testing_sets(dataset, cutoff=50000):
    """Derive the comparison fields (r, rho, u, vr, alpha, xyz) from raw data.

    Particles are sorted by radius and truncated to the innermost ``cutoff``
    ones so both compared datasets cover the same region.

    Parameters
    ----------
    dataset : dict
        Raw collected data; keys used: "xyz", "vxyz", "hpart", "uint",
        "alpha_AV" (arrays of shape (N, 3) or (N,)).
    cutoff : int, optional
        Number of innermost particles kept (default 50000, the historical
        hard-coded value).

    Returns
    -------
    dict
        Empty on MPI ranks > 0; otherwise the derived fields, radius-sorted
        and truncated. "vr" is the velocity magnitude.
    """
    if shamrock.sys.world_rank() > 0:
        return {}

    print("making test dataset, Npart={}".format(len(dataset["xyz"])))

    ret = {}
    ret["r"] = np.sqrt(
        dataset["xyz"][:, 0] ** 2 + dataset["xyz"][:, 1] ** 2 + dataset["xyz"][:, 2] ** 2
    )
    ret["rho"] = hpart_to_rho(dataset["hpart"])
    ret["u"] = dataset["uint"]
    ret["vr"] = np.sqrt(
        dataset["vxyz"][:, 0] ** 2 + dataset["vxyz"][:, 1] ** 2 + dataset["vxyz"][:, 2] ** 2
    )
    ret["alpha"] = dataset["alpha_AV"]
    ret["xyz"] = dataset["xyz"]

    # Even though the comparison uses neighbour matching, a common cutoff
    # region is still needed, hence the radial sort + truncation.
    order = np.argsort(ret["r"])
    for key in ret:
        ret[key] = ret[key][order][:cutoff]

    return ret
257
258
# Zero-dt step to settle the state before the comparison loop.
model.evolve_once_override_time(0, 0)

dt = 1e-5
t = 0
# Step index -> phantom reference dataset to compare against.
# NOTE(review): with range(101) below, i never reaches 1000, so the
# step-1000 comparison never runs — confirm whether that is intended.
reference_steps = {0: step0000, 1: step0001, 10: step0010, 100: step0100, 1000: step1000}

for i in range(101):
    if i in reference_steps:
        compare_datasets(
            i, get_testing_sets(reference_steps[i]), get_testing_sets(ctx.collect_data())
        )

    model.evolve_once_override_time(0, dt)
    t += dt


# plt.show()
Estimated memory usage: 0 MB