Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
ForwardEuler.hpp
Go to the documentation of this file.
1// -------------------------------------------------------//
2//
3// SHAMROCK code for hydrodynamics
4// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
5// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
6// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
7//
8// -------------------------------------------------------//
9
10#pragma once
11
21#include "shamcomm/logs.hpp"
27
28namespace shammodels::common::modules {
29 template<class T>
31
/// Scalar component type of T (via shambase::VecComponent); the time step dt is read as this type.
32 using Tscal = shambase::VecComponent<T>;
33
/// Number of variables stored per object in the field.
/// When nvar > 1, the update covers count * nvar entries per patch instead of count.
34 u32 nvar;
35
36 public:
/// Build a forward Euler integration node.
/// @param nvar number of variables per object (defaults to 1).
37 ForwardEuler(u32 nvar = 1) : nvar(nvar) {}
38
45
/// Connect the node to its solver-graph edges.
/// dt, time_derivative and sizes are registered as read-only edges, in that order
/// (slots 0, 1, 2, matching get_edges()). field is the quantity the step updates in place.
/// NOTE(review): the `dt` parameter declaration (source line 47) and source line 52
/// (presumably the read-write registration of `field`) are missing from this rendering.
46 inline void set_edges(
48 std::shared_ptr<shamrock::solvergraph::IFieldSpan<T>> time_derivative,
49 std::shared_ptr<shamrock::solvergraph::Indexes<u32>> sizes,
50 std::shared_ptr<shamrock::solvergraph::IFieldSpan<T>> field) {
51 __internal_set_ro_edges({dt, time_derivative, sizes});
53 }
54
/// Fetch the node's edges, cast to their concrete types and bundled into an Edges struct.
/// Read-only slot 0 is the dt data edge and slot 2 holds the per-patch object counts.
/// NOTE(review): the remaining entries (source lines 58 and 60) are missing from this rendering.
55 inline Edges get_edges() {
56 return Edges{
57 get_ro_edge<shamrock::solvergraph::IDataEdge<Tscal>>(0),
59 get_ro_edge<shamrock::solvergraph::Indexes<u32>>(2),
61 }
62
64
// Forward Euler step, applied element-wise on every patch:
//   field <- field + dt * time_derivative
// NOTE(review): the function header (source line 65) and the heads of both kernel calls
// (source lines 75 and 90) are missing from this rendering. Per the page's symbol list they are
// presumably `_impl_evaluate_internal()` and `sham::distributed_data_kernel_call(`.
66
67 auto edges = get_edges();
68
// Presumably sizes/validates the field against the per-patch object counts. TODO confirm.
// NOTE(review): when nvar > 1, the kernel below indexes count * nvar entries per patch.
// Confirm that ensure_sizes accounts for nvar values per object.
69 edges.field.ensure_sizes(edges.sizes.indexes);
70
// Read the step once so the kernel lambdas can capture it by value.
71 Tscal dt = edges.dt.data;
72
73 if (nvar == 1) {
74
// One variable per object: the per-patch object counts are used directly as thread counts.
76 shamsys::instance::get_compute_scheduler_ptr(),
77 sham::DDMultiRef{edges.time_derivative.get_spans()},
78 sham::DDMultiRef{edges.field.get_spans()},
79 edges.sizes.indexes,
80 [dt](u32 gid, const T *time_derivative, T *field) {
// Each work item updates exactly one entry of the flat field span.
81 field[gid] = field[gid] + dt * time_derivative[gid];
82 });
83
84 } else {
85
// Several variables per object: scale each patch's object count by nvar so that every
// stored value gets its own work item. The first lambda argument (`id`) is unused here.
86 auto var_count = edges.sizes.indexes.template map<u32>([&](u64 id, u32 count) {
87 return count * nvar;
88 });
89
// NOTE(review): this call is identical to the nvar == 1 branch except for the thread counts.
// Both branches could share a single call fed by a conditionally built count map.
91 shamsys::instance::get_compute_scheduler_ptr(),
92 sham::DDMultiRef{edges.time_derivative.get_spans()},
93 sham::DDMultiRef{edges.field.get_spans()},
94 var_count,
95 [dt](u32 gid, const T *time_derivative, T *field) {
96 field[gid] = field[gid] + dt * time_derivative[gid];
97 });
98 }
99 }
100
/// Label of this node in the solver graph.
101 inline virtual std::string _impl_get_label() const { return "ForwardEuler"; };
102
/// TeX description of this node. Not written yet: it returns the placeholder "TODO".
103 virtual std::string _impl_get_tex() const { return "TODO"; }
104 };
105} // namespace shammodels::common::modules
Header file describing a Node Instance.
Source location utility.
std::uint32_t u32
32 bit unsigned integer
std::uint64_t u64
64 bit unsigned integer
void _impl_evaluate_internal()
evaluate the node
virtual std::string _impl_get_label() const
get the label of the node
virtual std::string _impl_get_tex() const
get the tex of the node
Interface for a solver graph edge representing a field as spans.
INode is a node between data edges; it takes multiple inputs and produces multiple outputs.
Definition INode.hpp:30
T & get_rw_edge(int slot)
Get a read-write edge and cast it to type T.
Definition INode.hpp:86
void __internal_set_rw_edges(std::vector< std::shared_ptr< IEdge > > new_rw_edges)
Set the read-write edges.
Definition INode.hpp:181
void __internal_set_ro_edges(std::vector< std::shared_ptr< IEdge > > new_ro_edges)
Set the read-only edges.
Definition INode.hpp:171
const T & get_ro_edge(int slot)
Get a read-only edge and cast it to type T.
Definition INode.hpp:80
void distributed_data_kernel_call(sham::DeviceScheduler_ptr dev_sched, RefIn in, RefOut in_out, const shambase::DistributedData< index_t > &thread_counts, Functor &&func)
A variant of sham::kernel_call for distributed data.
#define __shamrock_stack_entry()
Macro to create a stack entry.
A variant of sham::MultiRef for distributed data.