Testing Sod tube with Zeus

CI test for Sod tube with Zeus

  8 import os
  9
 10 import matplotlib.pyplot as plt
 11 import numpy as np
 12
 13 import shamrock
 14
# Per-axis multipliers applied to `base` in the third argument of
# make_base_grid below; only the x axis is extended.
multx = 4
multy = 1
multz = 1

sz = 1 << 1  # == 2; used for all three components of make_base_grid's second argument
base = 32
gamma = 1.4  # passed to the solver EOS here and to the analytic SodTube solution later


ctx = shamrock.Context()
ctx.pdata_layout_new()

model = shamrock.get_Model_Zeus(context=ctx, vector_type="f64_3", grid_repr="i64_3")


cfg = model.gen_default_config()
# 2 / (sz * base * multx): presumably maps the integer extent of the grid along
# x to a physical length of 2, which would put the x < 1 discontinuity used by
# the initial conditions at mid-domain -- TODO confirm against set_scale_factor.
scale_fact = 2 / (sz * base * multx)
cfg.set_scale_factor(scale_fact)

cfg.set_eos_gamma(gamma)
cfg.set_consistent_transport(True)
cfg.set_van_leer(True)
model.set_solver_config(cfg)

# NOTE(review): the meaning of the two init_scheduler arguments is not visible here.
model.init_scheduler(int(1e7), 1)
model.make_base_grid((0, 0, 0), (sz, sz, sz), (base * multx, base * multy, base * multz))
 41
 42
 43 def rho_map(rmin, rmax):
 44     x, y, z = rmin
 45     if x < 1:
 46         return 1
 47     else:
 48         return 0.125
 49
 50
 51 eint_L = 1.0 / (gamma - 1)
 52 eint_R = 0.1 / (gamma - 1)
 53
 54
 55 def eint_map(rmin, rmax):
 56     x, y, z = rmin
 57     if x < 1:
 58         return eint_L
 59     else:
 60         return eint_R
 61
 62
 63 def vel_map(rmin, rmax):
 64     return (0, 0, 0)
 65
 66
# Initialise the fields from the per-cell callbacks rho_map, eint_map and vel_map.
model.set_field_value_lambda_f64("rho", rho_map)
model.set_field_value_lambda_f64("eint", eint_map)
model.set_field_value_lambda_f64_3("vel", vel_map)

# Time handed to the Sod analysis below, and the threshold that stops the loop.
t_target = 0.245


# model.evolve_once(0,0.1)
freq = 50  # NOTE(review): not read anywhere in the visible script
dt = 0.0010
t = 0  # NOTE(review): updated in the loop but not read afterwards in the visible script
# Fixed-step integration; 701 only caps the iteration count, the break below
# is what normally ends the loop.
for i in range(701):
    model.evolve_once(i * dt, dt)
    t = i * dt
    # NOTE(review): the comparison uses the time passed to the call just made,
    # so the loop stops after a step that *starts* at or beyond t_target.
    # Presumably the solver then sits slightly past t_target while the analysis
    # uses t_target exactly; the reference norms in the test section were
    # recorded with this behaviour, so confirm before changing it.
    if i * dt >= t_target:
        break

# model.evolve_until(t_target)

# model.evolve_once()
# xref equals the x position of the initial discontinuity (the x < 1 test in
# the maps); xrange is passed as -xrange / +xrange to the analysis, presumably
# as the window around xref -- TODO confirm against make_analysis_sodtube.
xref = 1.0
xrange = 0.5
# Analytic reference built with the same left/right densities and pressures
# as the initial conditions.
sod = shamrock.phys.SodTube(gamma=gamma, rho_1=1, P_1=1, rho_5=0.125, P_5=0.1)
sodanalysis = model.make_analysis_sodtube(sod, (1, 0, 0), t_target, xref, -xrange, xrange)
 91
 92
 93 #################
 94 ### Plot
 95 #################
 96 # do plot or not
 97 if False:
 98
 99     def convert_to_cell_coords(dic):
100
101         cmin = dic["cell_min"]
102         cmax = dic["cell_max"]
103
104         xmin = []
105         ymin = []
106         zmin = []
107         xmax = []
108         ymax = []
109         zmax = []
110
111         for i in range(len(cmin)):
112
113             m, M = cmin[i], cmax[i]
114
115             mx, my, mz = m
116             Mx, My, Mz = M
117
118             for j in range(8):
119                 a, b = model.get_cell_coords(((mx, my, mz), (Mx, My, Mz)), j)
120
121                 x, y, z = a
122                 xmin.append(x)
123                 ymin.append(y)
124                 zmin.append(z)
125
126                 x, y, z = b
127                 xmax.append(x)
128                 ymax.append(y)
129                 zmax.append(z)
130
131         dic["xmin"] = np.array(xmin)
132         dic["ymin"] = np.array(ymin)
133         dic["zmin"] = np.array(zmin)
134         dic["xmax"] = np.array(xmax)
135         dic["ymax"] = np.array(ymax)
136         dic["zmax"] = np.array(zmax)
137
138         return dic
139
140     dic = convert_to_cell_coords(ctx.collect_data())
141
142     X = []
143     rho = []
144     velx = []
145     P = []
146
147     for i in range(len(dic["xmin"])):
148
149         X.append(dic["xmin"][i] - 0.5)
150         rho.append(dic["rho"][i])
151         velx.append(dic["vel"][i][0])
152         P.append(dic["eint"][i] * (gamma - 1))
153
154     X = np.array(X)
155     rho = np.array(rho)
156     velx = np.array(velx)
157     P = np.array(P)
158
159     fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(9, 6), dpi=125)
160
161     plt.scatter(X, rho, rasterized=True, label="rho")
162     plt.scatter(X, velx, rasterized=True, label="v")
163     plt.scatter(X, P, rasterized=True, label="P")
164     # plt.scatter(X,rhoetot, rasterized=True,label="rhoetot")
165     plt.legend()
166     plt.grid()
167
168     #### add analytical soluce
169     arr_x = np.linspace(xref - xrange, xref + xrange, 1000)
170
171     arr_rho = []
172     arr_P = []
173     arr_vx = []
174
175     for i in range(len(arr_x)):
176         x_ = arr_x[i] - xref
177
178         _rho, _vx, _P = sod.get_value(t_target, x_)
179         arr_rho.append(_rho)
180         arr_vx.append(_vx)
181         arr_P.append(_P)
182
183     plt.plot(arr_x, arr_rho, color="black", label="analytic")
184     plt.plot(arr_x, arr_vx, color="black")
185     plt.plot(arr_x, arr_P, color="black")
186     plt.ylim(-0.1, 1.1)
187     plt.xlim(0.5, 1.5)
188     #######
189     plt.show()
190
191 #################
192 ### Test CD
193 #################
# L2 distance between the simulation and the analytic Sod solution, for the
# density, the three velocity components and the pressure.
rho, v, P = sodanalysis.compute_L2_dist()
vx, vy, vz = v

if shamrock.sys.world_rank() == 0:
    print("L2 norm : rho = ", rho, " v = ", v, " P = ", P)

# Upper bounds on each norm: recorded reference values plus a 1e-7 margin for
# rho, vx and P, and a fixed 1e-9 for the transverse velocities.
test_pass = True
pass_rho = 0.08027925640209972 + 1e-7
pass_vx = 0.18526690716374897 + 1e-7
pass_vy = 1e-09
pass_vz = 1e-09
pass_P = 0.1263222182067176 + 1e-7

err_log = ""

for field_name, measured, allowed in (
    ("rho", rho, pass_rho),
    ("vx", vx, pass_vx),
    ("vy", vy, pass_vy),
    ("vz", vz, pass_vz),
    ("P", P, pass_P),
):
    # Written as "not (measured <= allowed)" rather than "measured > allowed"
    # so that a NaN norm (every ordering comparison with NaN is False) is
    # reported as a failure instead of silently passing.
    if not (measured <= allowed):
        err_log += f"error on {field_name} is too high {measured!s}>{allowed!s}\n"
        test_pass = False

if not test_pass:
    # SystemExit with a string prints it to stderr and exits with status 1;
    # unlike exit(), it does not depend on the site module being loaded.
    raise SystemExit("Test did not pass L2 margins : \n" + err_log)

Estimated memory usage: 0 MB

Gallery generated by Sphinx-Gallery