Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
reorder_scan_dtt_result.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18#include "shambase/assert.hpp"
24
25namespace shamtree::details {
26
/// @brief Regroups the interaction pairs in `in_out` by their first index (the "sender") and
/// builds the matching offset table in `offsets`.
///
/// Each entry of `in_out` is a `u32_2` whose `x()` is asserted to be `< N` and is used as the
/// bucket index.
///
/// On return:
///  - `offsets` has `N + 1` entries. When `in_out` is empty they are all zero and nothing is
///    launched on the device.
///  - Otherwise, `in_out` has been reordered so that entries with equal `x()` are contiguous, and
///    `offsets[k]` is the start of bucket `k`. This assumes the scan dropped from this extract
///    (source line 63) is an exclusive prefix sum, which would also make `offsets[N]` the total
///    entry count.
///
/// NOTE(review): this is a Doxygen source-browser extract. Source lines 12-17 (includes),
/// 28 (parameter list), 30, 46, 63 and 75 are missing here. From the body, the function uses
/// `N`, `in_out` and `offsets`. The dropped lines presumably hold the signature, the stack-entry
/// macro, the two `sham::kernel_call(` openers and the exclusive scan. Confirm against the real
/// header.
///
/// NOTE(review): `interact_count` is `size_t`, while `kernel_call` takes `u32 n`, and the
/// per-bucket counters are `u32`. More than 2^32 - 1 interactions would silently truncate.
27 inline void reorder_scan_dtt_result(
29
31
32 size_t interact_count = in_out.get_size();
33 size_t offsets_count = N + 1;
34
35 offsets.resize(offsets_count);
36 offsets.fill(0);
37
38 if (in_out.get_size() == 0) {
39 return; // nothing to launch when there are no interactions; offsets is already the
40 // required all-zero table of size N + 1
41 }
42
43 auto &q = in_out.get_dev_scheduler().get_queue();
44
45 // Count how many interactions each sender (in_out[i].x()) has, using relaxed device-scope atomics.
47 q,
48 sham::MultiRef{in_out},
49 sham::MultiRef{offsets},
50 interact_count,
51 [N](u32 i, const u32_2 *__restrict__ in_out, u32 *__restrict__ offsets) {
52 SHAM_ASSERT(in_out[i].x() < N);
53
54 sycl::atomic_ref<
55 u32,
56 sycl::memory_order_relaxed,
57 sycl::memory_scope_device,
58 sycl::access::address_space::global_space>
59 atom(offsets[in_out[i].x()]);
60 atom += 1_u32;
61 });
62
// NOTE(review): source line 63 is missing here. Because the next kernel uses a copy of `offsets`
// as per-sender write heads, it presumably turns the counts into exclusive-scan start offsets.
64
65 // Group in_out by sender. The active branch scatters on the device, then sorts each segment.
66 // The else branch is a host-side std::sort; it is unreachable while the condition is `true`.
67 if (true) {
68 sham::DeviceBuffer<u32_2> in_out_sorted(
69 in_out.get_size(), in_out.get_dev_scheduler_ptr());
70
71 sham::DeviceBuffer<u32> offset2 = offsets.copy();
72
73 // Scatter each entry into its sender's segment by atomically bumping that segment's head.
74 // The order inside a segment depends on thread execution order, so it is not deterministic yet.
76 q,
77 sham::MultiRef{in_out},
78 sham::MultiRef{in_out_sorted, offset2},
79 interact_count,
80 [N](u32 i,
81 const u32_2 *__restrict__ in_out,
82 u32_2 *__restrict__ in_out_sorted,
83 u32 *__restrict__ local_head) {
84 SHAM_ASSERT(in_out[i].x() < N);
85
86 sycl::atomic_ref<
87 u32,
88 sycl::memory_order_relaxed,
89 sycl::memory_scope_device,
90 sycl::access::address_space::global_space>
91 atom(local_head[in_out[i].x()]);
92
93 u32 ret = atom.fetch_add(1_u32);
94
95 in_out_sorted[ret] = in_out[i];
96 });
97
98 // Sort each per-sender segment locally; this makes the final order deterministic.
// NOTE(review): presumably the same (x, y) ordering as the host fallback below. Confirm the
// comparator used by segmented_sort_in_place.
99 shamalgs::primitives::segmented_sort_in_place(in_out_sorted, offsets);
100
101 in_out = std::move(in_out_sorted);
102 } else {
103
// Host fallback: copy to a std::vector, sort lexicographically by (x, y), then copy back.
104 std::vector<u32_2> in_out_stdvec = in_out.copy_to_stdvec();
105 std::sort(in_out_stdvec.begin(), in_out_stdvec.end(), [](u32_2 a, u32_2 b) {
106 return (a.x() == b.x()) ? (a.y() < b.y()) : (a.x() < b.x());
107 });
108 in_out.copy_from_stdvec(in_out_stdvec);
109 }
110 }
111
112} // namespace shamtree::details
std::uint32_t u32
32-bit unsigned integer
Shamrock assertion utility.
#define SHAM_ASSERT(x)
Shorthand for SHAM_ASSERT_NAMED without a message.
Definition assert.hpp:67
A buffer allocated in USM (Unified Shared Memory).
void resize(size_t new_size, bool keep_data=true)
Resizes the buffer to a given size.
void fill(T value, std::array< size_t, 2 > idx_range)
Fill a subpart of the buffer with a given value.
size_t get_size() const
Gets the number of elements in the buffer.
DeviceScheduler & get_dev_scheduler() const
Gets the Device scheduler corresponding to the held allocation.
DeviceBuffer< T, target > copy() const
Copy the current buffer.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void scan_exclusive_sum_in_place(sham::DeviceBuffer< T > &buf1, u32 len)
Compute exclusive prefix sum in-place on a device buffer.
In-place exclusive scan (prefix sum) algorithm for device buffers.
This file contains the definitions for the stacktrace-related functionality.
#define __shamrock_stack_entry()
Macro to create a stack entry.
A class that references multiple buffers or similar objects.