Shamrock 2025.10.0
Astrophysical Code
Loading...
Searching...
No Matches
disc_setup.py
1"""
2Standard SPH disc setup helpers.
3"""
4
5# pylint: disable=invalid-name
6
from dataclasses import dataclass
from typing import Any, Literal, get_args

import shamrock
from shamrock.utils.numba_helper import maybe_njit
12
RotationMode = Literal["keplerian", "subkeplerian", "subkeplerian_3d"]
# Runtime tuple of the accepted rotation modes. It is derived from the
# RotationMode alias so that the runtime check on ``rotation`` can never drift
# from the static type.
_VALID_ROTATIONS = get_args(RotationMode)
15
16
@dataclass
class DiscProfiles:
    """
    Bundle of the profile callables describing a disc, as assembled by
    ``get_profiles``.

    The first five entries are always required; ``vtheta``, ``velocity`` and
    ``cs_field`` are optional and default to ``None``.
    """

    # Surface density as a function of radius r.
    sigma: Any
    # Scale height as a function of r.
    H: Any
    # Keplerian azimuthal speed as a function of r.
    vtheta_kepler: Any
    # Keplerian angular frequency as a function of r.
    omega_k: Any
    # Sound speed as a function of r.
    cs: Any
    # Azimuthal speed as a function of r; None when ``velocity`` is used instead.
    vtheta: Any = None
    # Velocity as a function of a 3D position; None when ``vtheta`` is used instead.
    velocity: Any = None
    # Optional sound-speed field; when not None, make_generator forwards it in
    # place of ``cs``.
    cs_field: Any = None
31
32
33@dataclass
# NOTE(review): listing line 34 (the `class ...:` header of this dataclass) is
# missing from this extraction; the class name is not recoverable from here.
35 """
36 Locally isothermal LP07 disc profiles and Monte Carlo generator helper.
37
38 All radii and masses are expressed in the provided code unit system.
39 """
40
# Unit system used for every length/mass and for the gravitational constant G.
41 units: shamrock.UnitSystem
# Mass of the central object (validated > 0); v_K(r) = sqrt(G * center_mass / r).
42 center_mass: float
# Total disc mass (validated > 0); split evenly over particles by part_mass().
43 disc_mass: float
# Inner and outer disc radii; rin must be smaller than rout.
44 rin: float
45 rout: float
# Aspect-ratio normalisation: cs(r0) = H_r_0 * v_K(r0), hence
# H(r0) / r0 = H_factor * H_r_0.
46 H_r_0: float = 0.05
# Sound-speed power-law index: cs(r) = cs(r0) * (r / r0) ** (-q).
47 q: float = 0.5
# Surface-density power-law index: sigma(r) = (r / r0) ** (-p), unit amplitude
# at r0 (disc_mass is passed to the generator separately).
48 p: float = 1.5
# Reference radius of both power laws.
49 r0: float = 1.0
# Multiplier on the scale height: H(r) = H_factor * cs(r) / Omega_K(r).
50 H_factor: float = 1.0
# Rotation model: "keplerian" (v_K), "subkeplerian" (midplane pressure
# correction) or "subkeplerian_3d" (height-dependent velocity field).
51 rotation: RotationMode = "subkeplerian"
# When True, sigma is multiplied by (1 - sqrt(rin / r)), vanishing at rin.
52 inner_tapering: bool = False
# Not implemented: building sigma raises NotImplementedError when True.
53 outer_tapering: bool = False
54
def __post_init__(self) -> None:
    """
    Validate the disc parameters right after dataclass initialisation.

    Raises:
        ValueError: if ``rin >= rout``, if ``rin`` or ``r0`` is not strictly
            positive, if either mass is not strictly positive, or if
            ``rotation`` is not one of the supported modes.
    """
    if self.rin >= self.rout:
        raise ValueError("rin must be smaller than rout")
    # The profiles divide by r and by r0 and raise them to fractional powers
    # (e.g. (G * M / r) ** 0.5 and (r / r0) ** (-p)). A radius of 0 therefore
    # fails with a division by zero, and a negative one yields non-real values.
    if self.rin <= 0:
        raise ValueError("rin must be positive")
    if self.r0 <= 0:
        raise ValueError("r0 must be positive")
    if self.center_mass <= 0:
        raise ValueError("center_mass must be positive")
    if self.disc_mass <= 0:
        raise ValueError("disc_mass must be positive")
    if self.rotation not in _VALID_ROTATIONS:
        raise ValueError(f"rotation must be one of {_VALID_ROTATIONS}, got {self.rotation!r}")
64
def _G(self) -> float:
    """
    Return the gravitational constant expressed in this disc's unit system.
    """
    constants = shamrock.Constants(self.units)
    return constants.G()
67
def _get_sigma(self) -> Any:
    """
    Build the surface-density profile sigma(r) = (r / r0) ** (-p).

    The profile has unit amplitude at r0. With ``inner_tapering`` it is
    additionally multiplied by ``1 - sqrt(rin / r)``, which vanishes at rin.
    The returned callable is wrapped by ``maybe_njit``.

    Raises:
        NotImplementedError: if ``outer_tapering`` is requested.
    """
    if self.outer_tapering:
        raise NotImplementedError("outer_tapering is not implemented yet")

    r_in, r_ref, p_index = self.rin, self.r0, self.p

    # The tapering choice is fixed when the profile is built, so pick the
    # matching closure once instead of branching on every evaluation.
    if self.inner_tapering:

        def sigma(r: float) -> float:
            taper = 1 - (r_in / r) ** 0.5
            return taper * (r / r_ref) ** (-p_index)

    else:

        def sigma(r: float) -> float:
            return (r / r_ref) ** (-p_index)

    return maybe_njit(sigma)
84
def _get_vtheta_kepler(self) -> Any:
    """
    Build the Keplerian speed profile v_K(r) = sqrt(G * center_mass / r).

    The returned callable is wrapped by ``maybe_njit``.
    """
    # G * M is constant for the whole disc; fold it once.
    gm = self._G() * self.center_mass

    def vtheta_kepler(r: float) -> float:
        return (gm / r) ** 0.5

    return maybe_njit(vtheta_kepler)
93
def _get_omega_k(self, vtheta_kepler: Any) -> Any:
    """
    Build the Keplerian angular frequency Omega_K(r) = v_K(r) / r from the
    given speed profile ``vtheta_kepler``.

    The returned callable is wrapped by ``maybe_njit``.
    """

    def omega_k(r: float) -> float:
        speed = vtheta_kepler(r)
        return speed / r

    return maybe_njit(omega_k)
99
def _get_cs(self, vtheta_kepler: Any) -> Any:
    """
    Build the locally isothermal sound-speed profile
    cs(r) = cs(r0) * (r / r0) ** (-q).

    The normalisation is cs(r0) = H_r_0 * v_K(r0), i.e.
    (H_r_0 * r0) * Omega_K(r0). The returned callable is wrapped by
    ``maybe_njit``.
    """
    r_ref = self.r0
    exponent = -self.q
    cs_ref = self.H_r_0 * vtheta_kepler(r_ref)

    def cs(r: float) -> float:
        return cs_ref * (r / r_ref) ** exponent

    return maybe_njit(cs)
110
def _get_H(self, cs: Any, omega_k: Any) -> Any:
    """
    Build the scale-height profile H(r) = H_factor * cs(r) / Omega_K(r).

    The returned callable is wrapped by ``maybe_njit``.
    """
    factor = self.H_factor

    def H(r: float) -> float:
        scaled_cs = factor * cs(r)
        return scaled_cs / omega_k(r)

    return maybe_njit(H)
118
def _get_vtheta_subkeplerian(self, vtheta_kepler: Any, cs: Any) -> Any:
    """
    Build the midplane rotation profile corrected for radial pressure support.

    The disc has sigma(r) = (r / r0) ** (-p), cs(r) proportional to
    (r / r0) ** (-q) and H = cs / Omega_K proportional to r ** (3/2 - q).
    Assuming a vertically isothermal (Gaussian) structure, the midplane
    pressure is P = rho * cs**2, proportional to (sigma / H) * cs**2, which
    scales as r ** -(p + q + 3/2). Radial force balance then gives

        vtheta(r)**2 = v_K(r)**2 - (p + q + 3/2) * cs(r)**2

    This is the z = 0 limit of the "subkeplerian_3d" field built by
    ``_get_velocity_vertical``, so both rotation modes agree in the midplane.

    Wherever the pressure term exceeds v_K**2 the square root has no real
    value; pure Python then returns a complex number. The returned callable
    is wrapped by ``maybe_njit``.
    """
    # NOTE: the coefficient is (p + q + 3/2), not (2p + q). The two coincide
    # only for p = 3/2 (the default), so default setups are unchanged.
    pressure_coef = self.p + self.q + 3.0 / 2.0

    def vtheta(r: float) -> float:
        return ((vtheta_kepler(r) ** 2) - pressure_coef * cs(r) ** 2) ** 0.5

    return maybe_njit(vtheta)
126
def _get_velocity_vertical(self, vtheta_kepler: Any, cs: Any) -> Any:
    """
    Build the "subkeplerian_3d" velocity field, whose rotation speed depends
    on height as well as on radius.

    At cylindrical radius r and height z the azimuthal speed v satisfies

        v(r, z)**2 = v_K(r)**2
                     - (p + q + 3/2) * cs(r)**2
                     - 2 q * v_K(r)**2 * r * (1/r - 1/sqrt(r**2 + z**2))

    i.e. radial pressure support plus a height-dependent term that vanishes
    at z = 0. The returned callable takes a position ``pos`` read as
    ``(pos[0], pos[1], pos[2]) = (x, y, z)``. It returns the purely azimuthal
    velocity ``(-v * y / r, v * x / r, 0.0)`` and is wrapped by ``maybe_njit``.
    """
    pressure_coef = self.p + self.q + 3.0 / 2.0
    two_q = 2 * self.q

    def vtheta_rz(r: float, z: float) -> float:
        v2 = vtheta_kepler(r) ** 2
        pressure = -pressure_coef * cs(r) ** 2
        dist = (r**2 + z**2) ** 0.5
        height_term = -v2 * r * two_q * (1 / r - 1 / dist)
        return (v2 + pressure + height_term) ** 0.5

    vtheta_rz = maybe_njit(vtheta_rz)

    def velocity(pos):
        x = pos[0]
        y = pos[1]
        z = pos[2]
        r_cyl = (x**2 + y**2) ** 0.5
        speed = vtheta_rz(r_cyl, z)
        return (speed * (-y / r_cyl), speed * (x / r_cyl), 0.0)

    return maybe_njit(velocity)
146
def _get_rotation(self, vtheta_kepler: Any, cs: Any) -> tuple[Any, Any]:
    """
    Select the rotation description matching ``self.rotation``.

    Returns:
        ``(vtheta, velocity)``. For "keplerian" and "subkeplerian" this is a
        radial speed profile with ``velocity`` None. For any other mode
        ("subkeplerian_3d") it is a position-dependent velocity field with
        ``vtheta`` None.
    """
    mode = self.rotation
    if mode not in ("keplerian", "subkeplerian"):
        return None, self._get_velocity_vertical(vtheta_kepler, cs)
    if mode == "keplerian":
        return vtheta_kepler, None
    return self._get_vtheta_subkeplerian(vtheta_kepler, cs), None
153
def get_profiles(self) -> DiscProfiles:
    """
    Build every profile of the disc and bundle them in a ``DiscProfiles``.

    The Keplerian speed is built first because the angular frequency, the
    sound speed and the rotation all derive from it. The scale height then
    combines the sound speed and Omega_K.

    Raises:
        NotImplementedError: propagated from the surface-density builder when
            ``outer_tapering`` is enabled.
    """
    sigma = self._get_sigma()
    v_kep = self._get_vtheta_kepler()
    omega = self._get_omega_k(v_kep)
    sound_speed = self._get_cs(v_kep)
    scale_height = self._get_H(sound_speed, omega)
    vtheta, velocity = self._get_rotation(v_kep, sound_speed)

    return DiscProfiles(
        sigma=sigma,
        H=scale_height,
        vtheta_kepler=v_kep,
        omega_k=omega,
        cs=sound_speed,
        vtheta=vtheta,
        velocity=velocity,
    )
174
def part_mass(self, npart: int) -> float:
    """
    Get the mass of a single particle when the disc mass is split evenly
    over ``npart`` particles.

    Raises:
        ValueError: if ``npart`` is not strictly positive. Otherwise zero would
            divide by zero and a negative count would silently yield a
            negative particle mass.
    """
    if npart <= 0:
        raise ValueError(f"npart must be positive, got {npart!r}")
    return self.disc_mass / npart
180
def cs0(self) -> float:
    """
    Get the sound speed at the reference radius ``r0``.

    Only the two profiles the sound speed depends on (Keplerian speed and
    sound speed) are built. Going through ``get_profiles()`` would also build
    the surface-density, scale-height and rotation profiles. It would also
    raise ``NotImplementedError`` whenever ``outer_tapering`` is set, even
    though the sound speed does not depend on the surface density.
    """
    sound_speed = self._get_cs(self._get_vtheta_kepler())
    return sound_speed(self.r0)
186
def make_generator(
    self,
    setup: Any,
    npart: int,
    *,
    random_seed: int = 666,
    init_h_factor: float = 0.8,
):
    """
    Build a Monte Carlo SPH disc generator from this disc's profiles.

    The profiles from ``get_profiles()`` are forwarded as keyword arguments to
    ``setup.make_generator_disc_mc``. Rotation is passed as ``velocity_field``
    when the profiles carry a velocity field, otherwise as ``rot_profile``.
    The sound speed is passed as ``cs_field`` when available, otherwise as
    ``cs_profile``.

    Args:
        setup: object exposing ``make_generator_disc_mc`` (presumably a
            Shamrock SPH setup - confirm against callers).
        npart: number of particles; sets the per-particle mass via
            ``part_mass``.
        random_seed: forwarded unchanged to the generator.
        init_h_factor: forwarded unchanged to the generator (presumably the
            initial smoothing-length factor).

    Returns:
        Whatever ``setup.make_generator_disc_mc`` returns.
    """
    prof = self.get_profiles()

    gen_kwargs = dict(
        part_mass=self.part_mass(npart),
        disc_mass=self.disc_mass,
        r_in=self.rin,
        r_out=self.rout,
        sigma_profile=prof.sigma,
        H_profile=prof.H,
        random_seed=random_seed,
        init_h_factor=init_h_factor,
    )

    # A full velocity field takes precedence over a radial rotation profile.
    if prof.velocity is None:
        gen_kwargs["rot_profile"] = prof.vtheta
    else:
        gen_kwargs["velocity_field"] = prof.velocity

    # Likewise, a sound-speed field takes precedence over the radial profile.
    if prof.cs_field is None:
        gen_kwargs["cs_profile"] = prof.cs
    else:
        gen_kwargs["cs_field"] = prof.cs_field

    return setup.make_generator_disc_mc(**gen_kwargs)
Any _get_vtheta_subkeplerian(self, Any vtheta_kepler, Any cs)
make_generator(self, Any setup, int npart, *, int random_seed=666, float init_h_factor=0.8)
Any _get_cs(self, Any vtheta_kepler)
Any _get_velocity_vertical(self, Any vtheta_kepler, Any cs)
tuple[Any, Any] _get_rotation(self, Any vtheta_kepler, Any cs)
Any _get_H(self, Any cs, Any omega_k)
Any _get_omega_k(self, Any vtheta_kepler)
Definition disc_setup.py:94