Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
DeviceScheduler.cpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
18#include "shambase/memory.hpp"
19#include "shambase/string.hpp"
22#include "shamcomm/logs.hpp"
23
24namespace {
/// @brief Run a minimal kernel on queue @p q and check the values it wrote.
///
/// A 10-element u32 device buffer is created from @p dev_sched, a kernel writes
/// b[i] = i into it, and the buffer is copied back to a std::vector that is compared
/// against {0, 1, ..., 9}.
///
/// @param dev_sched Device scheduler handed to the test buffer constructor.
/// @param q         Queue on which the kernel is submitted.
25 void test_queue(const sham::DeviceScheduler_ptr &dev_sched, sham::DeviceQueue &q) {
26 sham::DeviceBuffer<u32> b(10, dev_sched);
27
    // kernel_call(queue, inputs, outputs, n, functor): no input buffers, `b` as the only
    // output, 10 work items; each item stores its own index.
28 sham::kernel_call(q, sham::MultiRef{}, sham::MultiRef{b}, 10, [](u32 i, u32 *__restrict b) {
29 b[i] = i;
30 });
31
32 std::vector<u32> expected_acc = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
33
    // Bring the buffer content back into a std::vector for the comparison.
34 std::vector<u32> acc = b.copy_to_stdvec();
35 if (acc != expected_acc) {
    // get_check_ref is applied to both pointers before the device name is read for the message.
36 auto &ctx = shambase::get_check_ref(q.ctx);
37 auto &device = shambase::get_check_ref(ctx.device);
    // NOTE(review): listing line 38 is missing from this extract. The format string and
    // arguments below presumably feed an error-reporting call (this page's index lists
    // shambase::throw_with_loc) - confirm against the original file.
39 "The chosen SYCL queue (name={}, device={}) cannot execute a basic kernel\n"
40 " expected acc = {}\n"
41 " actual acc = {}",
42 q.queue_name,
43 device.dev.get_info<sycl::info::device::name>(),
44 expected_acc,
45 acc));
46 }
47 }
48} // namespace
49
50namespace sham {
51
52 DeviceScheduler::DeviceScheduler(std::shared_ptr<DeviceContext> ctx) : ctx(ctx) {
53
54 queues.push_back(std::make_unique<DeviceQueue>("main_queue", ctx, false));
55 }
56
58
60
// NOTE(review): the signature line of this function (listing line 59) is missing from this
// extract; this page's index lists it as `void print_info()`, "Print information about the
// DeviceScheduler".
    // First let the device context print its own information.
61 ctx->print_info();
62
    // Then list every queue owned by the scheduler, one raw log line per queue.
63 shamcomm::logs::raw_ln(" Queue list:");
64 for (auto &q : queues) {
    // {:20s} formats the queue name with a minimum field width of 20; in_order follows.
65 std::string tmp
66 = shambase::format(" - name : {:20s} in order : {}", q->queue_name, q->in_order);
67 shamcomm::logs::raw_ln(tmp);
68 }
69 }
70
71 bool DeviceScheduler::use_direct_comm() { return ctx->use_direct_comm(); }
72
/// @brief Smoke-test every queue of a device scheduler, then probe USM allocations.
///
/// 1. For each queue in dev_sched->queues, run test_queue(). If it throws, an error is
///    logged with the device name and the exception is rethrown.
/// 2. Create and free a size-1024 device USM block and a size-1024 host USM block, using the
///    device's mem_base_addr_align as alignment.
/// 3. If the device reports more than 3 * GBval of global memory (GBval = 1024^3), try a
///    (2 * GBval + 1024)-sized device and host allocation whenever the matching reported
///    maximum allocation size exceeds 2 * GBval. A std::runtime_error during such an attempt
///    is logged as a warning and the matching limit is overwritten with i32_max.
///
/// @param dev_sched Device scheduler whose queues and allocation limits are tested.
73 void test_device_scheduler(const sham::DeviceScheduler_ptr &dev_sched) {
74 for (auto &q : dev_sched->queues) {
75
    // NOTE(review): listing lines 76-77 are missing from this extract; `qref` and `ctxref`
    // used below are presumably declared there - confirm against the original file.
78 sham::Device &deviceref = shambase::get_check_ref(ctxref.device);
79 std::string device_name = deviceref.dev.get_info<sycl::info::device::name>();
80
    // The exception is stored rather than handled in place, so that the error log and the
    // rethrow both happen after the catch block.
81 std::exception_ptr eptr;
82 try {
83 logger::debug_ln("Backends", "[Queue testing] name = ", device_name);
84 test_queue(dev_sched, qref);
85 logger::debug_ln(
86 "Backends", "[Queue testing] name = ", device_name, " -> working !");
87 } catch (...) {
88 eptr = std::current_exception(); // capture whatever test_queue threw
89 }
90
91 if (eptr) {
92 logger::err_ln(
93 "Backends", "[Queue testing] name = ", device_name, " -> not working !");
    // Propagate the original exception unchanged to the caller.
94 std::rethrow_exception(eptr);
95 }
96 }
97
98 logger::debug_ln("Backends", "[Alloc testing] starting...");
99
    // get_queue() is called without an id, i.e. with its default id=0.
100 sham::DeviceQueue &qref = shambase::get_check_ref(dev_sched).get_queue();
    // NOTE(review): listing line 101 is missing from this extract; `ctxref` used below is
    // presumably declared there - confirm against the original file.
102
    // Small allocations first: one device block and one host block, each freed immediately.
103 u32 align = shambase::get_check_ref(ctxref.device).prop.mem_base_addr_align;
104 USMPtrHolder<sham::device> ptr1024_dev
105 = USMPtrHolder<sham::device>::create(1024, dev_sched, align);
106 ptr1024_dev.free_ptr();
107 USMPtrHolder<sham::host> ptr1024_host
108 = USMPtrHolder<sham::host>::create(1024, dev_sched, align);
109 ptr1024_host.free_ptr();
110
111 auto &dev = shambase::get_check_ref(ctxref.device);
112
113 size_t GBval = 1024 * 1024 * 1024;
    // NOTE(review): the comment below mentions 8GB, but the guard that follows tests for more
    // than 3 * GBval of global memory - confirm which threshold is intended.
114 // avoid <8GB card, they won't run at that scale anyway
115 if (dev.prop.global_mem_size > usize(3 * GBval)) {
116
    // Only attempt the large device allocation if the reported limit claims to allow it.
117 if (dev.prop.max_mem_alloc_size_dev > 2 * GBval) {
118 try {
119 USMPtrHolder<sham::device> ptr2G_dev
120 = USMPtrHolder<sham::device>::create(2 * GBval + 1024, dev_sched, align);
121 ptr2G_dev.free_ptr();
    // Only std::runtime_error is handled here; any other exception type propagates.
122 } catch (std::runtime_error &e) {
123 logger::warn_ln(
124 "Backends",
125 " name = ",
126 dev.dev.get_info<sycl::info::device::name>(),
127 " -> large device allocation (>2GB) not working !");
    // Overwrite the recorded device limit with i32_max, presumably so that later code does
    // not request allocations the device just failed to provide - confirm.
128 dev.prop.max_mem_alloc_size_dev = i32_max;
129 }
130 }
131
    // Same probe for host USM allocations.
132 if (dev.prop.max_mem_alloc_size_host > 2 * GBval) {
133 try {
134 USMPtrHolder<sham::host> ptr2G_host
135 = USMPtrHolder<sham::host>::create(2 * GBval + 1024, dev_sched, align);
136 ptr2G_host.free_ptr();
    // Only std::runtime_error is handled here; any other exception type propagates.
137 } catch (std::runtime_error &e) {
138 logger::warn_ln(
139 "Backends",
140 " name = ",
141 dev.dev.get_info<sycl::info::device::name>(),
142 " -> large host allocation (>2GB) not working !");
    // Overwrite the recorded host limit with i32_max after the failed attempt.
143 dev.prop.max_mem_alloc_size_host = i32_max;
144 }
145 }
146 }
147
148 logger::debug_ln("Backends", "[Alloc testing] done !");
149 }
150} // namespace sham
std::uint32_t u32
32 bit unsigned integer
std::size_t usize
size_t alias
A buffer allocated in USM (Unified Shared Memory)
A class that represents a SYCL context.
std::shared_ptr< Device > device
A SYCL queue associated with a device and a context.
std::string queue_name
The name of this queue.
std::shared_ptr< DeviceContext > ctx
The device context of this queue.
bool in_order
Whether the queue is in order.
bool use_direct_comm()
Check if the context corresponding to the device scheduler should use direct communication.
DeviceScheduler(std::shared_ptr< DeviceContext > ctx)
Constructor.
std::shared_ptr< DeviceContext > ctx
Reference to the device context associated with this DeviceScheduler.
void print_info()
Print information about the DeviceScheduler.
DeviceQueue & get_queue(u32 id=0)
Get a reference to a DeviceQueue.
std::vector< std::unique_ptr< DeviceQueue > > queues
Vector of unique pointers to the DeviceQueues associated with this DeviceScheduler.
Represents a SYCL device.
Definition Device.hpp:147
sycl::device dev
The SYCL device object.
Definition Device.hpp:157
static USMPtrHolder create(size_t sz, std::shared_ptr< DeviceScheduler > dev_sched, std::optional< size_t > alignment=std::nullopt)
Create a USM pointer holder.
This header file contains utility functions related to exception handling in the code.
Namespace for backends; it is named only `sham` because `shambackends` is too long to write.
@ device
Device memory.
void kernel_call(sham::DeviceQueue &q, RefIn in, RefOut in_out, u32 n, Functor &&func, SourceLocation &&callsite=SourceLocation{})
Submit a kernel to a SYCL queue.
void throw_with_loc(std::string message, SourceLocation loc=SourceLocation{})
Throw an exception and append the source location to it.
T & get_check_ref(const std::unique_ptr< T > &ptr, SourceLocation loc=SourceLocation())
Takes a std::unique_ptr and returns a reference to the object it holds. It throws a std::runtime_error...
Definition memory.hpp:110
constexpr i32 i32_max
i32 max value
A class that references multiple buffers or similar objects.