Note
Go to the end to download the full example code.
SPH benchmark for homogeneous density box
This example tests the performance of the SPH solver for a homogeneous density box. The resolution is automatically adapted to the available memory and the number of processes.
9 import datetime
10 import json
11 import math
12 from statistics import mean, stdev
13
14 import shamrock
15
# Query the compute device once; its memory size drives the resolution below.
device_properties = shamrock.sys.get_compute_device_properties()

# The microbenchmark results are part of the reported output, so abort early
# when none are available.
microbench_results = shamrock.sys.get_microbench_results()
if len(microbench_results) == 0:
    print("no microbench results, please run with --benchmark-mpi")
    raise ValueError("no microbench results")

memory_gb = device_properties["global_mem_size"] / 1e9

# Base particle count: a power of two obtained by truncating
# log2(memory_gb * 1e6 / 1.5).
N_target_base = 2 ** int(math.log2(memory_gb * 1e6 / 1.5))
# These prints show the value *before* the caps below are applied.
print(f"N_target_base = {N_target_base}")
print(f"memory_gb = {memory_gb}")
print(f"device_properties = {device_properties}")

# Upper bounds on the base particle count: 2**25 in general, 2**23 on CPU.
N_target_base = min(N_target_base, 2**25)
if device_properties["type"] == "CPU":
    N_target_base = min(N_target_base, 2**23)

shamrock.backends.reset_mem_info_max()
38
# Gas parameters.
gamma = 5.0 / 3.0
rho_g = 1
target_tot_u = 1

# Requested simulation box, given by two opposite corners.
bmin = (-0.6, -0.6, -0.6)
bmax = (0.6, 0.6, 0.6)

# The total particle count scales with the world size.
compute_multiplier = shamrock.sys.world_size()
# compute_multiplier = 12
scheduler_split_val = 20_000_000
scheduler_merge_val = 1

N_target = N_target_base * compute_multiplier
(xm, ym, zm), (xM, yM, zM) = bmin, bmax
vol_b = (xM - xm) * (yM - ym) * (zM - zm)

if shamrock.sys.world_rank() == 0:
    for _label, _value in (
        ("N_target_base", N_target_base),
        ("compute_multiplier", compute_multiplier),
        ("scheduler_split_val", scheduler_split_val),
        ("scheduler_merge_val", scheduler_merge_val),
        ("N_target", N_target),
        ("vol_b", vol_b),
    ):
        print(_label, _value)

# Volume available to each particle.
part_vol = vol_b / N_target

# lattice volume (0.74 is presumably the close-packing filling fraction
# -- TODO confirm)
part_vol_lattice = 0.74 * part_vol

# Radius of the sphere whose volume is part_vol_lattice: V = (4/3) * pi * r**3.
dr = (part_vol_lattice / ((4.0 / 3.0) * 3.1416)) ** (1.0 / 3.0)

# Placeholder; the particle mass is computed once the particles are set up.
pmass = -1
72
# Simulation context and SPH model (double-precision 3D vectors, M4 kernel).
ctx = shamrock.Context()
ctx.pdata_layout_new()

model = shamrock.get_Model_SPH(
    context=ctx,
    vector_type="f64_3",
    sph_kernel="M4",
)

# Start from the default solver configuration and override what the
# benchmark needs.
cfg = model.gen_default_config()

# Alternative artificial-viscosity setups, kept for reference:
# cfg.set_artif_viscosity_Constant(alpha_u = 1, alpha_AV = 1, beta_AV = 2)
# cfg.set_artif_viscosity_VaryingMM97(alpha_min = 0.1,alpha_max = 1,sigma_decay = 0.1, alpha_u = 1, beta_AV = 2)
cfg.set_artif_viscosity_VaryingCD10(
    alpha_min=0.0,
    alpha_max=1,
    sigma_decay=0.1,
    alpha_u=1,
    beta_AV=2,
)
cfg.set_boundary_periodic()
cfg.set_eos_adiabatic(gamma)
cfg.print_status()

model.set_solver_config(cfg)
model.init_scheduler(scheduler_split_val, scheduler_merge_val)
89
# Replace the requested box by the one the model reports as ideal for an HCP
# lattice of spacing dr, then refresh the per-axis bounds.
bmin, bmax = model.get_ideal_hcp_box(dr, bmin, bmax)
(xm, ym, zm), (xM, yM, zM) = bmin, bmax

model.resize_simulation_box(bmin, bmax)

setup = model.get_setup()
gen = setup.make_generator_lattice_hcp(dr, bmin, bmax)

# Generation step and maximum message size share the same value.
setup_chunk = int(scheduler_split_val / 8)

# Kind of optimized for Aurora
setup.apply_setup(
    gen,
    gen_step=setup_chunk,
    insert_step=int(scheduler_split_val * 2),
    msg_count_limit=1024,
    rank_comm_size_limit=int(scheduler_split_val) * 2,
    max_msg_size=setup_chunk,
    do_setup_log=False,
)
109
origin = (0, 0, 0)

xc, yc, zc = model.get_closest_part_to(origin)

if shamrock.sys.world_rank() == 0:
    print("closest part to (0,0,0) is in :", xc, yc, zc)

# Box volume recomputed from the current (lattice-adjusted) bounds.
vol_b = (xM - xm) * (yM - ym) * (zM - zm)

# Total mass for a uniform density rho_g, converted to a per-particle mass.
totmass = rho_g * vol_b
pmass = model.total_mass_to_part_mass(totmass)

# Zero the "uint" field everywhere in the box ...
model.set_value_in_a_box("uint", "f64", 0, bmin, bmax)

# ... then add a kernel-shaped contribution of amplitude u_inj and radius
# rinj around the origin.
rinj = 16 * dr
u_inj = 1
model.add_kernel_value("uint", "f64", u_inj, origin, rinj)

tot_u = pmass * model.get_sum("uint", "f64")
if shamrock.sys.world_rank() == 0:
    print("total u :", tot_u)

model.set_particle_mass(pmass)

model.set_cfl_cour(0.1)
model.set_cfl_force(0.1)

shamrock.backends.reset_mem_info_max()

# converge smoothing length and compute initial dt
model.timestep()
142
# Now run the actual benchmark. One entry is appended to each of these lists
# per measured step.
res_rates, res_cnts, res_system_metrics, res_mpi_timers = [], [], [], []

# Callbacks are inserted to measure the solver's MPI usage by fetching the
# timers twice: at the beginning and at the end of the step.
before_mpi_timers = None
after_mpi_timers = None
153
154
def callback_before_mpi_timer() -> None:
    """Store the current MPI timers in the module-level ``before_mpi_timers``.

    Intended to be registered as the ``step_begin`` timestep callback.
    """
    global before_mpi_timers
    before_mpi_timers = shamrock.comm.get_timers()
159
160
def callback_after_mpi_timer() -> None:
    """Store the current MPI timers in the module-level ``after_mpi_timers``.

    Intended to be registered as the ``step_end`` timestep callback.
    """
    global after_mpi_timers
    after_mpi_timers = shamrock.comm.get_timers()
165
166
# Imported once, ahead of the loop, instead of on every iteration.
import time

# Number of measured steps (single source of truth for the loop and its log).
n_steps = 10

# Snapshot the MPI timers at the beginning and end of every timestep.
model.add_timestep_callback(
    step_begin=callback_before_mpi_timer,
    step_end=callback_after_mpi_timer,
)

for step in range(1, n_steps + 1):
    if shamrock.sys.world_rank() == 0:
        print("running step ", step, "/", n_steps, " ...")

    shamrock.sys.mpi_barrier()

    # To replay the same step
    model.set_next_dt(0.0)
    model.timestep()

    if shamrock.sys.world_rank() == 0:
        print("collecting results ...")

    # Record this step's rate, object count, system metrics and the MPI timer
    # difference between the two callback snapshots.
    res_rates.append(model.solver_logs_last_rate())
    res_cnts.append(model.solver_logs_last_obj_count())
    res_system_metrics.append(model.solver_logs_last_system_metrics())
    res_mpi_timers.append(shamrock.comm.mpi_timers_delta(before_mpi_timers, after_mpi_timers))

    if shamrock.sys.world_rank() == 0:
        print("sleeping 1 second ...")

    # NOTE(review): presumably a pause so consecutive measurements do not
    # influence each other -- confirm the intent.
    time.sleep(1)

    if shamrock.sys.world_rank() == 0:
        print("done sleeping 1 second ...")
201
# The reported result is the measured step with the best (highest) rate.
res_rate = max(res_rates)

# index of the max rate (first occurrence on ties)
max_rate_index = res_rates.index(res_rate)

# Take every per-step quantity from that same step so that the count, the
# system metrics and the MPI timers all describe the step whose rate is
# reported (the count was previously always read from step 0).
res_cnt = res_cnts[max_rate_index]
max_rate_system_metrics = res_system_metrics[max_rate_index]
max_mpi_timers = res_mpi_timers[max_rate_index]

# Duration of the best step: objects processed divided by objects per second.
step_time = res_cnt / res_rate
210
# Only rank 0 assembles and prints the final report.
if shamrock.sys.world_rank() == 0:
    world_size = shamrock.sys.world_size()

    # Collect the report line by line and join once at the end. Every entry
    # ends with a newline, including the header (it previously ran into the
    # "world size" line).
    report_lines = [
        f"--- final score for N_target_base={N_target_base} ---\n",
        f"world size : {world_size}\n",
        f"result rate : {res_rate}\n",
        f"result cnt : {res_cnt}\n",
        f"cnt/rank : {res_cnt / world_size}\n",
        f"result rate per rank : {res_rate / world_size}\n",
        (
            f"rates infos : max={max(res_rates)}, min={min(res_rates)}, "
            f"mean={mean(res_rates)}, stddev={stdev(res_rates)}\n"
        ),
        f"res_rates = {res_rates}\n",
        f"res_cnts = {res_cnts}\n",
        f"step time = {step_time}\n",
    ]

    # Machine-readable summary that is dumped as JSON below.
    dic_out = {
        "device_properties": device_properties,
        "microbench_results": shamrock.sys.get_microbench_results(),
        "shamrock_version": shamrock.version_string(),
        "shamrock_compiler_id_string": shamrock.get_compiler_id_string(),
        "shamrock_compile_flags": shamrock.get_compile_arg(),
        "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "world_size": world_size,
        "rate": res_rate,
        "cnt": res_cnt,
        "step_time": step_time,
        "mpi_timers": max_mpi_timers,
    }

    # print the system metrics
    # "duration" is the time base; every other entry is reported as-is (in J)
    # and as an average power over that duration (in W).
    metrics_duration = max_rate_system_metrics["duration"]
    energy_metrics = {
        key: value for key, value in max_rate_system_metrics.items() if key != "duration"
    }

    report_lines.append("system metrics:\n")
    for key, value in energy_metrics.items():
        report_lines.append(f"{key}: {value} J\n")
        dic_out[key] = value

    # Second pass so all raw entries precede the power entries, both in the
    # text and in the JSON key order.
    for key, value in energy_metrics.items():
        avg_power = value / metrics_duration
        report_lines.append(f"avg power {key} / step time : {avg_power} W\n")
        dic_out[f"power_{key}"] = avg_power

    dic_out["system_metric_duration"] = metrics_duration

    report_lines.append("---------submit this result--------\n")
    report_lines.append(f"{json.dumps(dic_out, indent=4)}\n")
    report_lines.append("-----------------------------------\n")

    result_text = "".join(report_lines)

    print("current results:")
    print(result_text)
Estimated memory usage: 0 MB