Testing Sod tube with Zeus

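CI test for the Sod shock tube with the Zeus solver: the L2 distance between the simulated and the analytic solution must stay within fixed margins.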

  8 import os
  9
 10 import matplotlib.pyplot as plt
 11 import numpy as np
 12
 13 import shamrock
 14
# Cell-count multipliers per axis: the base grid gets base * mult{x,y,z}
# cells along each axis (all of integer size sz), so the domain is 4x
# longer along x than along y and z.
multx = 4
multy = 1
multz = 1

# Integer size of one base-grid cell (1 << 1 == 2), passed to make_base_grid
# below. NOTE(review): presumably expressed in the integer grid units of
# grid_repr="i64_3" — confirm.
sz = 1 << 1
# Base number of cells per axis, before applying the multipliers above.
base = 32
# Adiabatic index of the ideal-gas EOS; the analytic SodTube reference built
# later in this script uses the same value.
gamma = 1.4


ctx = shamrock.Context()
ctx.pdata_layout_new()

# Zeus grid model; vector_type / grid_repr presumably select double-precision
# 3-vectors and 64-bit integer grid coordinates (inferred from the names).
model = shamrock.get_Model_Zeus(context=ctx, vector_type="f64_3", grid_repr="i64_3")


cfg = model.gen_default_config()
# The grid spans sz * base * multx integer units along x; scaling by
# 2 / (that extent) maps it to a physical length of 2, which puts the x = 1
# discontinuity used by rho_map / eint_map in the middle of the tube.
# (Assumes set_scale_factor multiplies integer coordinates — TODO confirm.)
scale_fact = 2 / (sz * base * multx)
cfg.set_scale_factor(scale_fact)

cfg.set_eos_gamma(gamma)
cfg.set_consistent_transport(True)
cfg.set_van_leer(True)
model.set_solver_config(cfg)

# NOTE(review): the meaning of the two scheduler arguments (1e7 and 1) is
# defined inside shamrock and not visible here — document once confirmed.
model.init_scheduler(int(1e7), 1)
# Presumably (origin, cell size, cell count per axis) — consistent with the
# extent used to compute scale_fact above; confirm against shamrock docs.
model.make_base_grid((0, 0, 0), (sz, sz, sz), (base * multx, base * multy, base * multz))
 41
 42
 43 def rho_map(rmin, rmax):
 44     x, y, z = rmin
 45     if x < 1:
 46         return 1
 47     else:
 48         return 0.125
 49
 50
 51 eint_L = 1.0 / (gamma - 1)
 52 eint_R = 0.1 / (gamma - 1)
 53
 54
 55 def eint_map(rmin, rmax):
 56     x, y, z = rmin
 57     if x < 1:
 58         return eint_L
 59     else:
 60         return eint_R
 61
 62
 63 def vel_map(rmin, rmax):
 64     return (0, 0, 0)
 65
 66
# Initialise the simulation fields cell by cell from the callbacks defined
# above (scalar fields via the f64 variant, velocity via the f64_3 variant).
model.set_field_value_lambda_f64("rho", rho_map)
model.set_field_value_lambda_f64("eint", eint_map)
model.set_field_value_lambda_f64_3("vel", vel_map)

# Time at which the numerical solution is compared with the analytic one.
t_target = 0.245
 72
 73
# model.evolve_once(0,0.1)
# NOTE(review): freq is never read in this script.
freq = 50
# Fixed time step used for every evolve_once call below.
dt = 0.0010
t = 0
# Advance with a fixed step until the step start time reaches t_target;
# range(701) is only an upper bound on the number of steps.
# NOTE(review): assuming evolve_once(t, dt) advances the state from t to
# t + dt, the last iteration starts at ~t_target and therefore ends one dt
# past it (~0.246), while the analysis below is evaluated at t_target. `t`
# likewise holds the start time of the last step, not the time reached.
# The L2 pass margins at the end of the script were presumably recorded
# with this exact loop, so changing it requires re-baselining them.
for i in range(701):
    model.evolve_once(i * dt, dt)
    t = i * dt
    if i * dt >= t_target:
        break

# model.evolve_until(t_target)
 85
# model.evolve_once()
# Comparison window: x in [xref - xrange, xref + xrange], centred on the
# initial discontinuity at x = 1 (see rho_map / eint_map).
xref = 1.0
xrange = 0.5
# Analytic Sod-tube reference with the same left (rho=1, P=1) and right
# (rho=0.125, P=0.1) states and the same gamma as the initial conditions.
sod = shamrock.phys.SodTube(gamma=gamma, rho_1=1, P_1=1, rho_5=0.125, P_5=0.1)
# NOTE(review): arguments are presumably (reference solution, tube axis,
# time, centre, lower offset, upper offset) — confirm against the signature
# of make_analysis_sodtube.
sodanalysis = model.make_analysis_sodtube(sod, (1, 0, 0), t_target, xref, -xrange, xrange)
 91
 92
#################
### Plot
#################
# Optional diagnostic plot of the numerical rho / vx / P against the analytic
# solution. Disabled by the literal `if False:` (presumably to keep the CI run
# non-interactive, since it ends in plt.show()); flip it to plot locally.
if False:

    def convert_to_cell_coords(dic):
        """Add per-cell coordinate arrays to ``dic`` and return it.

        Reads the ``cell_min`` / ``cell_max`` entries, queries
        ``model.get_cell_coords`` for 8 sub-indices of each entry, and stores
        the resulting lower/upper corners as NumPy arrays under
        ``xmin``, ``ymin``, ``zmin``, ``xmax``, ``ymax``, ``zmax``.
        ``dic`` is modified in place.
        """
        cmin = dic["cell_min"]
        cmax = dic["cell_max"]

        xmin = []
        ymin = []
        zmin = []
        xmax = []
        ymax = []
        zmax = []

        for i in range(len(cmin)):
            m, M = cmin[i], cmax[i]

            mx, my, mz = m
            Mx, My, Mz = M

            # NOTE(review): j in range(8) presumably enumerates the 2x2x2
            # cells of one block; the field arrays (rho, vel, eint) are later
            # indexed with this same flattened order — confirm they match.
            for j in range(8):
                a, b = model.get_cell_coords(((mx, my, mz), (Mx, My, Mz)), j)

                x, y, z = a
                xmin.append(x)
                ymin.append(y)
                zmin.append(z)

                x, y, z = b
                xmax.append(x)
                ymax.append(y)
                zmax.append(z)

        dic["xmin"] = np.array(xmin)
        dic["ymin"] = np.array(ymin)
        dic["zmin"] = np.array(zmin)
        dic["xmax"] = np.array(xmax)
        dic["ymax"] = np.array(ymax)
        dic["zmax"] = np.array(zmax)

        return dic

    dic = convert_to_cell_coords(ctx.collect_data())

    X = []
    rho = []
    velx = []
    P = []

    for i in range(len(dic["xmin"])):
        # NOTE(review): the analytic curve below is drawn at unshifted x
        # (arr_x around xref = 1), but the numerical points are shifted by
        # -0.5 and use the cell's lower corner rather than its centre, so the
        # two discontinuities will not line up. This looks like a leftover
        # offset — confirm before trusting this plot.
        X.append(dic["xmin"][i] - 0.5)
        rho.append(dic["rho"][i])
        velx.append(dic["vel"][i][0])
        # Pressure from internal energy: P = (gamma - 1) * eint.
        P.append(dic["eint"][i] * (gamma - 1))

    X = np.array(X)
    rho = np.array(rho)
    velx = np.array(velx)
    P = np.array(P)

    # fig / axs are not used afterwards: drawing goes through pyplot's
    # current-figure state.
    fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(9, 6), dpi=125)

    plt.scatter(X, rho, rasterized=True, label="rho")
    plt.scatter(X, velx, rasterized=True, label="v")
    plt.scatter(X, P, rasterized=True, label="P")
    # plt.scatter(X,rhoetot, rasterized=True,label="rhoetot")
    plt.legend()
    plt.grid()

    #### Add the analytic solution
    arr_x = np.linspace(xref - xrange, xref + xrange, 1000)

    arr_rho = []
    arr_P = []
    arr_vx = []

    for i in range(len(arr_x)):
        # The analytic solution is queried relative to the interface xref.
        x_ = arr_x[i] - xref

        _rho, _vx, _P = sod.get_value(t_target, x_)
        arr_rho.append(_rho)
        arr_vx.append(_vx)
        arr_P.append(_P)

    plt.plot(arr_x, arr_rho, color="black", label="analytic")
    plt.plot(arr_x, arr_vx, color="black")
    plt.plot(arr_x, arr_P, color="black")
    plt.ylim(-0.1, 1.1)
    plt.xlim(0.5, 1.5)
    #######
    plt.show()
187
#################
### Test CD
#################
# L2 distance between the simulated and analytic solutions over the
# comparison window configured in make_analysis_sodtube.
rho, v, P = sodanalysis.compute_L2_dist()
vx, vy, vz = v

if shamrock.sys.world_rank() == 0:
    print("L2 norm : rho = ", rho, " v = ", v, " P = ", P)

# Pass margins: fixed reference L2 errors (presumably recorded from a
# known-good run) plus 1e-7 slack. vy / vz must stay ~0 because the initial
# conditions vary only along x.
pass_rho = 0.08027925640209972 + 1e-7
pass_vx = 0.18526690716374897 + 1e-7
pass_vy = 1e-09
pass_vz = 1e-09
pass_P = 0.1263222182067176 + 1e-7

checks = (
    ("rho", rho, pass_rho),
    ("vx", vx, pass_vx),
    ("vy", vy, pass_vy),
    ("vz", vz, pass_vz),
    ("P", P, pass_P),
)

err_log = ""
for name, value, margin in checks:
    # `not (value <= margin)` instead of `value > margin`: a NaN error norm
    # (e.g. a blown-up run) compares False both ways, and must fail the test
    # rather than slip through.
    if not (value <= margin):
        err_log += f"error on {name} is too high {value!s}>{margin!s}\n"

test_pass = not err_log

if not test_pass:
    # SystemExit directly instead of the site-provided exit(), which is meant
    # for the interactive shell and is missing when Python runs with -S.
    raise SystemExit("Test did not pass L2 margins : \n" + err_log)
223     exit("Test did not pass L2 margins : \n" + err_log)

Estimated memory usage: 0 MB

Gallery generated by Sphinx-Gallery