Note
Go to the end to download the full example code.
Testing Sod tube with Zeus
CI test for Sod tube with Zeus
import os

import matplotlib.pyplot as plt
import numpy as np

import shamrock

# Grid multipliers along x, y and z (the tube is elongated along x).
multx, multy, multz = 4, 1, 1

base = 32
sz = 2  # == 1 << 1
gamma = 1.4  # EOS gamma passed to the solver config

ctx = shamrock.Context()
ctx.pdata_layout_new()

model = shamrock.get_Model_Zeus(context=ctx, vector_type="f64_3", grid_repr="i64_3")

# Solver configuration. The scale factor presumably maps the integer x extent
# (sz * base * multx) onto a physical length of 2 -- confirm against the Zeus docs.
cfg = model.gen_default_config()
scale_fact = 2 / (sz * base * multx)
cfg.set_scale_factor(scale_fact)

cfg.set_eos_gamma(gamma)
cfg.set_consistent_transport(True)
cfg.set_van_leer(True)
model.set_solver_config(cfg)

model.init_scheduler(int(1e7), 1)
model.make_base_grid(
    (0, 0, 0),
    (sz, sz, sz),
    (base * multx, base * multy, base * multz),
)
41
42
def rho_map(rmin, rmax):
    """Return the initial density for the cell whose lower corner is ``rmin``.

    Left of the interface (x < 1) the density is 1, otherwise 0.125.
    ``rmax`` is accepted for the callback signature but not used.
    """
    x, _y, _z = rmin
    return 1 if x < 1 else 0.125
49
50
# Initial internal energy of the left / right states. With P = (gamma - 1) * eint
# (the relation used when plotting), these give P = 1 and P = 0.1, matching the
# SodTube reference states.
eint_L = 1.0 / (gamma - 1)
eint_R = 0.1 / (gamma - 1)


def eint_map(rmin, rmax):
    """Return the initial internal energy for the cell whose lower corner is ``rmin``.

    Cells with x < 1 get ``eint_L``, all others ``eint_R``; ``rmax`` is unused.
    """
    x, _y, _z = rmin
    return eint_L if x < 1 else eint_R
61
62
def vel_map(rmin, rmax):
    """Return the initial velocity: the gas starts at rest everywhere."""
    return (0,) * 3
65
66
# Initial conditions: per-cell callbacks for density, internal energy and velocity.
model.set_field_value_lambda_f64("rho", rho_map)
model.set_field_value_lambda_f64("eint", eint_map)
model.set_field_value_lambda_f64_3("vel", vel_map)

t_target = 0.245

# Advance with a fixed time step (model.evolve_until(t_target) is an alternative
# driver, not used here).
freq = 50  # NOTE(review): appears unused in this script
dt = 0.0010
t = 0
for i in range(701):
    t_start = i * dt
    model.evolve_once(t_start, dt)
    # Simulation time actually reached once this step has been taken.
    t = t_start + dt
    # NOTE(review): the stop test uses the step's *start* time, so the loop exits
    # after the step beginning at t_target, i.e. at t_target + dt, while the
    # analysis compares against t_target. The step count is left unchanged since
    # the L2 thresholds presumably were calibrated with it -- confirm before fixing.
    if t_start >= t_target:
        break
# Analytic Sod shock-tube reference: same left/right states as the initial
# conditions, centred on the interface at x = xref.
xref, xrange = 1.0, 0.5
sod = shamrock.phys.SodTube(gamma=gamma, rho_1=1, P_1=1, rho_5=0.125, P_5=0.1)

# (1, 0, 0) is presumably the tube axis and (-xrange, xrange) the sampled window
# relative to xref -- confirm against make_analysis_sodtube.
sodanalysis = model.make_analysis_sodtube(
    sod,
    (1, 0, 0),
    t_target,
    xref,
    -xrange,
    xrange,
)
91
92
#################
### Plot
#################
# Set to True to compare the simulation with the analytic solution visually.
do_plot = False
if do_plot:

    def convert_to_cell_coords(dic):
        """Attach per-cell coordinate arrays to a collected-data dict.

        For each block ``(dic["cell_min"][i], dic["cell_max"][i])`` this queries
        ``model.get_cell_coords`` for indices ``j = 0..7``. It stores the returned
        lower / upper corners as numpy arrays under ``"xmin"``, ``"ymin"``,
        ``"zmin"``, ``"xmax"``, ``"ymax"`` and ``"zmax"``.

        The dict is updated in place and also returned.
        """
        lower = []
        upper = []
        for (mx, my, mz), (Mx, My, Mz) in zip(dic["cell_min"], dic["cell_max"]):
            block = ((mx, my, mz), (Mx, My, Mz))
            for j in range(8):
                a, b = model.get_cell_coords(block, j)
                lower.append(a)
                upper.append(b)

        # reshape keeps an (n, 3) layout even when there are no cells at all.
        lower = np.array(lower).reshape(-1, 3)
        upper = np.array(upper).reshape(-1, 3)
        for axis, name in enumerate("xyz"):
            dic[name + "min"] = lower[:, axis]
            dic[name + "max"] = upper[:, axis]

        return dic

    dic = convert_to_cell_coords(ctx.collect_data())

    # NOTE(review): numerical x positions are shifted by -0.5 while the analytic
    # curve below is drawn unshifted on [xref - xrange, xref + xrange]; confirm
    # the offset is intended.
    X = dic["xmin"] - 0.5
    n = len(X)
    rho = np.asarray(dic["rho"])[:n]
    velx = np.asarray(dic["vel"])[:n, 0]
    P = np.asarray(dic["eint"])[:n] * (gamma - 1)  # pressure from eint

    plt.figure(figsize=(9, 6), dpi=125)

    plt.scatter(X, rho, rasterized=True, label="rho")
    plt.scatter(X, velx, rasterized=True, label="v")
    plt.scatter(X, P, rasterized=True, label="P")
    plt.grid()

    #### add analytical solution on [xref - xrange, xref + xrange]
    arr_x = np.linspace(xref - xrange, xref + xrange, 1000)
    arr_rho, arr_vx, arr_P = zip(*(sod.get_value(t_target, x - xref) for x in arr_x))

    plt.plot(arr_x, arr_rho, color="black", label="analytic")
    plt.plot(arr_x, arr_vx, color="black")
    plt.plot(arr_x, arr_P, color="black")
    # The legend is built only after every labelled artist exists, so that the
    # "analytic" curve is listed too.
    plt.legend()
    plt.ylim(-0.1, 1.1)
    plt.xlim(0.5, 1.5)
    #######
    plt.show()
190
#################
### Test CD
#################
rho, v, P = sodanalysis.compute_L2_dist()
vx, vy, vz = v

if shamrock.sys.world_rank() == 0:
    print("L2 norm : rho = ", rho, " v = ", v, " P = ", P)

# Upper bounds on the L2 distance to the analytic solution: reference values
# plus a small slack.
pass_rho = 0.08027925640209972 + 1e-7
pass_vx = 0.18526690716374897 + 1e-7
pass_vy = 1e-09
pass_vz = 1e-09
pass_P = 0.1263222182067176 + 1e-7

checks = (
    ("rho", rho, pass_rho),
    ("vx", vx, pass_vx),
    ("vy", vy, pass_vy),
    ("vz", vz, pass_vz),
    ("P", P, pass_P),
)

# Written as ``not (value <= limit)`` rather than ``value > limit`` so that a NaN
# error (every comparison with NaN is False) counts as a failure instead of
# silently passing.
err_log = "".join(
    "error on " + name + " is too high " + str(value) + ">" + str(limit) + "\n"
    for name, value, limit in checks
    if not (value <= limit)
)
test_pass = not err_log

if not test_pass:
    # SystemExit does not rely on the interactive-only ``exit`` helper from ``site``.
    raise SystemExit("Test did not pass L2 margins : \n" + err_log)
Estimated memory usage: 0 MB