Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ForwardEuler.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
18
21#include "shamcomm/logs.hpp"
27
// Edge declarations for ForwardEuler, written as an X-macro. It is expanded by
// EXPAND_NODE_EDGES inside the class and #undef'd at the end of this header.
//   X_RO entries (listed under "inputs"):
//     dt              - timestep, a Tscal value read through edges.dt.data
//     time_derivative - rate of change of `field`, spans of the same type T
//     sizes           - u32 count per id; the basis of the kernel thread counts
//   X_RW entry (listed under "outputs"):
//     field           - updated in place as field += dt * time_derivative
// NOTE(review): the entry order presumably fixes the order of the generated
// edges; confirm against EXPAND_NODE_EDGES before reordering.
28#define NODE_EDGES(X_RO, X_RW) \
29 /* ------------------- inputs ------------------- */ \
30 X_RO(shamrock::solvergraph::IDataEdge<Tscal>, dt) \
31 X_RO(shamrock::solvergraph::IFieldSpan<T>, time_derivative) \
32 X_RO(shamrock::solvergraph::Indexes<u32>, sizes) \
33 \
34 /* ------------------- outputs ------------------- */ \
35 X_RW(shamrock::solvergraph::IFieldSpan<T>, field)
36
37namespace shammodels::common::modules {
/**
 * @brief Solver-graph node that advances `field` by one forward (explicit) Euler step.
 *
 * For every id in the `sizes` edge the node updates the field in place:
 *   field[i] = field[i] + dt * time_derivative[i]
 * where i covers `count` entries when nvar == 1 and `count * nvar` entries otherwise.
 *
 * @tparam T value type stored in `field` and `time_derivative`. The timestep type
 *           Tscal is shambase::VecComponent<T> (presumably the scalar component
 *           type of a vector T; confirm).
 *
 * NOTE(review): this rendered listing omits source lines 50 and 52 and source
 * lines 62 and 77.
 *   - Lines 50 and 52 are presumably the `void _impl_evaluate_internal()`
 *     signature and `__shamrock_stack_entry();`. Both names appear in the
 *     tooltip index at the end of the page.
 *   - Lines 62 and 77 are presumably the `distributed_data_kernel_call(` name
 *     that opens each kernel launch.
 */
38 template<class T>
39 class ForwardEuler : public shamrock::solvergraph::INode {
40
// Timestep type: the component type of T.
41 using Tscal = shambase::VecComponent<T>;
42
// Number of T values per counted entry. When this is not 1, the else-branch
// below launches count * nvar threads per id.
43 u32 nvar;
44
45 public:
// @param nvar number of values per counted entry (defaults to 1).
// NOTE(review): this single-argument constructor is not `explicit`, so a u32
// converts implicitly to ForwardEuler<T>.
46 ForwardEuler(u32 nvar = 1) : nvar(nvar) {}
47
// Expands NODE_EDGES into the node's edge declarations (dt, time_derivative,
// sizes, field). It presumably also provides the get_edges() used below.
48 EXPAND_NODE_EDGES(NODE_EDGES)
49
51
53
// Body of _impl_evaluate_internal (its signature is missing from this listing):
// performs one forward Euler step on every id.
54 auto edges = get_edges();
55
// Presumably resizes or validates the field spans against the per-id counts.
// NOTE(review): with nvar != 1 the kernel touches count * nvar entries per id,
// but ensure_sizes only receives the unscaled counts. Confirm that IFieldSpan
// accounts for nvar itself.
56 edges.field.ensure_sizes(edges.sizes.indexes);
57
// Read the timestep once, outside the kernels; each kernel lambda captures it by value.
58 Tscal dt = edges.dt.data;
59
// NOTE(review): both branches launch the same kernel and differ only in their
// thread counts. The nvar == 1 branch equals the var_count path with a
// multiplier of 1, so the two could be merged. The split avoids building
// var_count when nvar == 1.
60 if (nvar == 1) {
// Single-value case: one thread per counted entry (the sizes edge as-is).
61
63 shamsys::instance::get_compute_scheduler_ptr(),
64 sham::DDMultiRef{edges.time_derivative.get_spans()},
65 sham::DDMultiRef{edges.field.get_spans()},
66 edges.sizes.indexes,
// Kernel: in-place explicit Euler update of entry gid.
67 [dt](u32 gid, const T *time_derivative, T *field) {
68 field[gid] = field[gid] + dt * time_derivative[gid];
69 });
70
71 } else {
72
// Multi-value case: scale each per-id count by nvar so that every stored value
// gets its own thread. NOTE(review): the product is computed in u32. `id` is
// unused and could be left unnamed to silence unused-parameter warnings.
73 auto var_count = edges.sizes.indexes.template map<u32>([&](u64 id, u32 count) {
74 return count * nvar;
75 });
// Launch over count * nvar entries per id.
76
78 shamsys::instance::get_compute_scheduler_ptr(),
79 sham::DDMultiRef{edges.time_derivative.get_spans()},
80 sham::DDMultiRef{edges.field.get_spans()},
81 var_count,
// Kernel: in-place explicit Euler update of entry gid (same body as above).
82 [dt](u32 gid, const T *time_derivative, T *field) {
83 field[gid] = field[gid] + dt * time_derivative[gid];
84 });
85 }
86 }
87
// Label of the node. This is the INode hook "get the label of the node"; the
// tooltip index shows it as virtual in the base.
// NOTE(review): could be marked `override`.
88 inline virtual std::string _impl_get_label() const { return "ForwardEuler"; }
89
// TeX representation of the node; still a placeholder ("TODO").
// NOTE(review): could be marked `override`.
90 inline virtual std::string _impl_get_tex() const { return "TODO"; }
91 };
92} // namespace shammodels::common::modules
93
94#undef NODE_EDGES
Header file describing a Node Instance.
Source location utility.
std::uint32_t u32
32-bit unsigned integer
std::uint64_t u64
64-bit unsigned integer
void _impl_evaluate_internal()
evaluate the node
virtual std::string _impl_get_label() const
get the label of the node
virtual std::string _impl_get_tex() const
get the tex of the node
INode is a node between data edges; it takes multiple inputs and produces multiple outputs.
Definition INode.hpp:30
void distributed_data_kernel_call(sham::DeviceScheduler_ptr dev_sched, RefIn in, RefOut in_out, const shambase::DistributedData< index_t > &thread_counts, Functor &&func)
A variant of sham::kernel_call for distributed data.
#define __shamrock_stack_entry()
Macro to create a stack entry.
A variant of sham::MultiRef for distributed data.