Note
Go to the end to download the full example code.
SPH homogeneous benchmark results
Show the benchmark results on various devices.

results from /work/examples/benchmarks/sph_homogeneous_bench_result.json
Apple M4 Max:
- 1 ranks, 851313.6209866619 rate, 8484840 cnt, 9.966761709 step time
NVIDIA GeForce RTX 3070:
- 1 ranks, 1732146.8265657455 rate, 4254912 cnt, 2.4564384120000002 step time
Intel(R) Core(TM) Ultra 9 285K:
- 1 ranks, 1612764.908074928 rate, 8484840 cnt, 5.261051972000001 step time
NVIDIA H100:
- 1 ranks, 25502055.876729604 rate, 33848064 cnt, 1.3272680510000001 step time
AMD EPYC 9654 96-Core Processor :
- 1 ranks, 2353244.022271519 rate, 8464638 cnt, 3.597008181 step time
Intel(R) Data Center GPU Max 1550:
- 1 ranks, 14612536.848131763 rate, 33848064 cnt, 2.316371507 step time
12 x Intel(R) Data Center GPU Max 1550:
- 12 ranks, 124464069.34448674 rate, 404289600 cnt, 3.2482434660000004 step time
9 import json
10 import os
11 import textwrap
12
13 import matplotlib.pyplot as plt
14 import numpy as np
15 from matplotlib.ticker import MaxNLocator
16
# Locate the results JSON next to this script; when __file__ is undefined
# (e.g. interactive execution) fall back to the current working directory.
try:
    base_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_path = os.getcwd()

json_file = os.path.join(base_path, "sph_homogeneous_bench_result.json")

# Context managers close (and, for the write, flush) the handles
# deterministically instead of leaving that to garbage collection.
with open(json_file, encoding="utf-8") as f:
    results = json.load(f)

# Rewrite the same data back pretty-printed with 4-space indentation.
with open(json_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)

print(f"results from {json_file}")

# Best (highest-rate) result per device key; filled in below.
results_per_model = {}
30
31
def key_name(name, world_size):
    """Build the grouping key for a device.

    A single-rank run is keyed by the bare device name; a multi-rank run is
    keyed as ``"<world_size> x <name>"`` so it is reported separately.
    """
    return name if world_size == 1 else f"{world_size} x {name}"
37
38
# Keep only the fastest run per device key; the strict ">" means that on a
# tie the first result encountered is kept.
for result in results:
    name = key_name(result["device_properties"]["name"], result["world_size"])
    best_so_far = results_per_model.get(name)
    if best_so_far is None or result["rate"] > best_so_far["rate"]:
        results_per_model[name] = result

# Summary of the retained result for each device key.
for name, result in results_per_model.items():
    summary = f" - {result['world_size']} ranks, {result['rate']} rate, {result['cnt']} cnt, {result['step_time']} step time"
    print(f"{name}:")
    print(summary)
52
53
54 def _rate_bar_color(device_name: str) -> str:
55 """Color for the rate bar from device name (case-insensitive)."""
56 lower = device_name.lower()
57 if "nvidia" in lower:
58 return "#2ca02c" # green
59 if "amd" in lower or "radeon" in lower:
60 return "#d62728" # red
61 if "intel" in lower:
62 return "#1f77b4" # blue
63 if "apple" in lower:
64 return "#7f7f7f" # grey
65 return "steelblue"
66
67
68 def _micro_bw_and_fma(result):
69 """saxpy f64 -> GB/s; fma_chains f32/f64 -> Gflops (MicroBenchmark raw flop/s, /1e9)."""
70 m = result.get("microbench_results") or {}
71 bw_bs = m.get("saxpy_f64")
72 f64 = m.get("fma_chains_f64")
73 f32 = m.get("fma_chains_f32")
74 bw_gbps = (bw_bs / 1e9) if bw_bs is not None else float("nan")
75 flops_f64 = (f64) if f64 is not None else float("nan")
76 flops_f32 = (f32) if f32 is not None else float("nan")
77 return bw_gbps, flops_f64, flops_f32
78
79
# Fastest device first; sorted() is stable, so equal rates keep the order
# in which the devices were first seen.
items = sorted(results_per_model.items(), key=lambda kv: kv[1]["rate"], reverse=True)
names = [device for device, _ in items]
rates = [res["rate"] for _, res in items]

# Per-device (bandwidth GB/s, fma f64, fma f32) triples, split into columns.
_micro_rows = [_micro_bw_and_fma(res) for _, res in items]
bw_gbps = [row[0] for row in _micro_rows]
flops_f64 = [row[1] for row in _micro_rows]
flops_f32 = [row[2] for row in _micro_rows]
92
# One y slot per device. Figure height grows with the device count
# (note: 0.45 * n + 5 is always >= 5, so the 3.0 floor never binds).
y = np.arange(len(names))
_fig_height = max(3.0, 0.45 * len(names) + 5)

# Long device names are wrapped at 34 characters so the tick labels stay
# inside the left margin; bar colors follow the device vendor.
_wrapped_names = [textwrap.fill(device, 34) for device in names]
_bar_colors = [_rate_bar_color(device) for device in names]

# Left (wide) panel: rate; right (narrow) panel: micro-benchmarks; shared y.
_grid_opts = {"width_ratios": [75, 25], "wspace": 0.025}
fig, (ax_rate, ax_micro) = plt.subplots(
    1,
    2,
    sharey=True,
    figsize=(15, _fig_height),
    gridspec_kw=_grid_opts,
)

_rate_bars = ax_rate.barh(y, rates, color=_bar_colors, edgecolor="white", linewidth=0.5)
ax_rate.set_yticks(y)
ax_rate.set_yticklabels(_wrapped_names)
ax_rate.set_xlabel("rate (solver objects / s)")
ax_rate.set_xscale("log")
ax_rate.set_title("SPH homogeneous - rate by device")
ax_rate.bar_label(_rate_bars, fmt="%.3g", padding=3)
ax_rate.grid(axis="x", linestyle=":", alpha=0.6)
# Put the first (fastest) device at the top; the y axis is shared with ax_micro.
ax_rate.invert_yaxis()
117
# Extra room on the right for the bar-end value labels: extend the upper
# x limit by half of the current span. On this log-scaled axis the padding
# is additive in data units, so the new upper limit is roughly 1.5x the old
# one when the lower limit is small in comparison.
_xmin, _xmax = ax_rate.get_xlim()
ax_rate.set_xlim(_xmin, _xmax + 0.5 * (_xmax - _xmin))
# Right panel: three thin bars per device around that device's y tick:
# saxpy bandwidth one offset before the tick, f32 FMA on the tick, f64 FMA
# one offset after. A spacing larger than the bar height leaves small gaps.
_bar_h = 0.22
_spacing = 0.26
_y_saxpy, _y_f32, _y_f64 = y - _spacing, y, y + _spacing
_bar_style = {"height": _bar_h, "edgecolor": "white", "linewidth": 0.5}

ax_micro.barh(_y_saxpy, bw_gbps, color="coral", label="saxpy f64 (GB/s)", **_bar_style)
ax_micro.set_xlabel("Memory bandwidth saxpy f64 (GB/s)")
ax_micro.grid(axis="x", linestyle=":", alpha=0.6)
ax_micro.tick_params(axis="y", labelleft=False)

# FMA throughput (raw, unscaled values) gets its own twin x axis, because its
# magnitude differs from the bandwidth numbers; f32 and f64 can also differ
# widely, hence the log scale.
ax_micro_top = ax_micro.twiny()
for _ypos, _values, _color, _label in (
    (_y_f32, flops_f32, "mediumpurple", "fma_chains f32 (flops)"),
    (_y_f64, flops_f64, "seagreen", "fma_chains f64 (flops)"),
):
    ax_micro_top.barh(_ypos, _values, color=_color, label=_label, **_bar_style)
ax_micro_top.set_xlabel("Peak FMA f32 / f64 (flops, log scale)")
ax_micro_top.set_xscale("log")
ax_micro.set_xscale("log")
166
# A single legend covering the bars of both x axes of the right panel.
_handles, _labels = [], []
for _axis in (ax_micro, ax_micro_top):
    _axis_handles, _axis_labels = _axis.get_legend_handles_labels()
    _handles += _axis_handles
    _labels += _axis_labels
ax_micro.legend(_handles, _labels, loc="lower right", fontsize=8)

# Make the panels touch: constrained_layout always leaves a gap, whereas a
# manual wspace=0 abuts the axes (this also overrides the gridspec wspace).
ax_rate.spines["right"].set_visible(True)
ax_micro.spines["left"].set_visible(False)
fig.subplots_adjust(left=0.22, right=0.99, top=0.90, bottom=0.12, wspace=0)

plt.show()
Total running time of the script: (0 minutes 0.511 seconds)
Estimated memory usage: 157 MB