Source code for hybrid_jp.dtypes
"""Base data types for the hybrid_jp."""
from dataclasses import dataclass
from typing import Iterator, Protocol, TypeVar
import numpy as np
import numpy.typing as npt
BaseChild = TypeVar("BaseChild", bound="BaseContainer")
arrfloat = npt.NDArray[np.float64]
arrint = npt.NDArray[np.int_]
[docs]@dataclass
class BaseContainer(Protocol):
"""Base container for parameters."""
@property
def all(self) -> dict[str, arrfloat]:
"""Return all parameters as a dict."""
...
def __mul__(self: BaseChild, value: float | int) -> BaseChild:
"""Multiply."""
return type(self)(**{k: v * value for k, v in self.all.items()})
def __rmul__(self: BaseChild, value: float | int) -> BaseChild:
"""Multiply."""
return self.__mul__(value)
def __imul__(self: BaseChild, value: float | int) -> BaseChild:
"""Multiply."""
return self.__mul__(value)
def __iter__(self) -> Iterator[arrfloat]:
"""Iterate over all values."""
for v in self.all.values():
yield v
[docs] def is_2d(self) -> bool:
"""Check if all components are 2D arrays.
Returns:
bool: True if all components are 2D arrays.
"""
return all([len(b.shape) == 2 for b in self.all.values()])
[docs] def mean_over_axis(self: BaseChild, axis: int) -> BaseChild:
"""Take the mean over the specified axis in a 2d array.
axis=0: mean over x
axis=1: mean over y
Args:
axis (int): Axis to take the mean over.
Raises:
ValueError: All components must be 2D arrays.
"""
if not self.is_2d():
raise ValueError("All components must be 2D arrays.")
return type(self)(**{k: np.mean(v, axis=axis) for k, v in self.all.items()})
[docs] def mean_over_y(self: BaseChild) -> BaseChild:
"""Take the mean over the y axis.
Returns:
Mag: Magnetic field components with the mean taken over the y axis.
Raises:
ValueError: All components must be 2D arrays.
"""
return self.mean_over_axis(axis=1)
[docs] def slice_x(self: BaseChild, start: int, stop: int) -> BaseChild:
"""Slice in the x direction."""
if self.is_2d():
return type(self)(**{k: v[start:stop, :] for k, v in self.all.items()})
else:
items = list(self.all.items())
items[0] = (items[0][0], items[0][1][start:stop])
return type(self)(**dict(items))
[docs]@dataclass
class Grid(BaseContainer):
"""Simulation grid.
Note:
Can be either edges or midpoints.
Parameters:
x (np.ndarray): x grid.
y (np.ndarray): y grid.
"""
x: np.ndarray
y: np.ndarray
@property
def all(self) -> dict[str, arrfloat]:
"""All parameters as a dict."""
return dict(x=self.x, y=self.y)
@property
def shape(self) -> tuple[int, int]:
"""Shape of the grid (nx, ny)."""
return self.x.shape[0], self.y.shape[0]
[docs]@dataclass
class Mag(BaseContainer):
"""Magnetic field components.
Parameters:
bx (np.ndarray): x component of magnetic field.
by (np.ndarray): y component of magnetic field.
bz (np.ndarray): z component of magnetic field.
"""
bx: np.ndarray
by: np.ndarray
bz: np.ndarray
@property
def all(self) -> dict[str, arrfloat]:
"""All parameters as a dict."""
return dict(bx=self.bx, by=self.by, bz=self.bz)
[docs]@dataclass
class Elec(BaseContainer):
"""Electric field components.
Parameters:
ex (np.ndarray): x component of electric field.
ey (np.ndarray): y component of electric field.
ez (np.ndarray): z component of electric field.
"""
ex: np.ndarray
ey: np.ndarray
ez: np.ndarray
@property
def all(self) -> dict[str, arrfloat]:
"""All parameters as a dict."""
return dict(ex=self.ex, ey=self.ey, ez=self.ez)
[docs]@dataclass
class Current(BaseContainer):
"""Current components."""
jx: arrfloat
jy: arrfloat
jz: arrfloat
@property
def all(self) -> dict[str, arrfloat]:
"""All currents.
Returns:
dict[str, arrf]: 'jx'|'jy'|'jz', (x, y)
"""
return dict(jx=self.jx, jy=self.jy, jz=self.jz)