from typing import Any, Callable
import numpy as np
import numpy.typing as npt
from epoch_cheats.deck import Deck
from ..arrays import create_orthonormal_basis_from_vec, rotate_arr_to_new_basis
from ..sdf_files import SDF
class NoTimestampError(Exception):
    """Raised when an SDF has no timestamp (``tstamp`` is None)."""
class NChunksNotSetError(Exception):
"""Raised in a CentredShock when n_chunks is not set."""
message = "n_chunks must be set before calling this method."
def __init__(self, message: str = message) -> None:
self.message = message
super().__init__(message)
class InvalidChunkError(Exception):
"""Raised when a chunk is not valid."""
def __init__(self, chunk_idx: int, t_idx: int) -> None:
self.message = f"Chunk {chunk_idx} at timestep {t_idx} is not valid."
super().__init__(self.message)
class CenteredShock:
    """Aligns a run of SDFs on the shock front (found from the y-median number
    density) so that quantities can be compared in equal-width chunks across
    timesteps. ``n_chunks`` must be set before using the chunk-based methods.
    """
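    # A minimal construction sketch (hedged; ``load_run`` is a hypothetical
    # loader, not part of this module):
    #
    #     sdfs, deck = load_run("path/to/run")  # list[SDF], Deck
    #     cs = CenteredShock(sdfs, deck)
    #     cs.n_chunks = 10  # required before any chunk-based method
    #     print(cs.shock_x)  # shock position at each timestep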
def __init__(self, sdfs: list[SDF], deck: Deck) -> None:
self.sdfs = sdfs
self.deck = deck
self.time = self._get_time()
self.grid = self.sdfs[0].mid_grid
        # Median number density over y, converted from m^-3 to cm^-3.
        self._nd_median_y_cm = np.asarray(
            [np.median(sdf.numberdensity, axis=1) for sdf in sdfs]
        ).T / (100**3)
self.shock_i = self.shock_index_from_nd()
self.shock_x = self.grid.x[self.shock_i]
self.dx = self.grid.x[1] - self.grid.x[0]
self.dy = self.grid.y[1] - self.grid.y[0]
self.dist_either_side = self._get_dist_either_side()
self.max_widths = self.dist_either_side.max(axis=0)
self.full_width = sum(self.max_widths)
self.n_chunks: int | None = None
# Can be set by set_chunk_i_missing
self._chunks_info: dict[str, Any] = {}
self._chunk_i: int | None = None
self._missing: npt.NDArray[np.int_] | None = None
self._downstream_start_chunk: int | None = None
# Can be set by set_valid_chunks
self._valid_chunks: npt.NDArray[np.bool_] | None = None
# Can be set by set_start_offset
self._start_offset: npt.NDArray[np.int64] | None = None
def _get_chunks_info(self) -> dict[str, Any]:
if not self._chunks_info:
self._chunks_info = self.set_chunk_i_missing()
return self._chunks_info
    @property
    def chunk_i(self) -> int:
        """Width of a single chunk in grid cells."""
if self._chunk_i is None:
ci = self._get_chunks_info()
self._chunk_i = ci["chunk_i"]
return self._chunk_i
    @property
    def missing(self) -> npt.NDArray[np.int_]:
        """Cells on each side of the shock that do not fit into a whole chunk."""
if self._missing is None:
ci = self._get_chunks_info()
self._missing = ci["missing"]
return self._missing
    @property
    def downstream_start_chunk(self) -> int:
        """Chunk index at which the downstream region begins."""
if self._downstream_start_chunk is None:
ci = self._get_chunks_info()
self._downstream_start_chunk = ci["downstream_start_chunk"]
return self._downstream_start_chunk
    @property
    def valid_chunks(self) -> npt.NDArray[np.bool_]:
        """Boolean array (n_chunks, time.size); True where a chunk has complete data."""
if self._valid_chunks is None:
self._valid_chunks = self.set_valid_chunks()
return self._valid_chunks
    @property
    def start_offset(self) -> npt.NDArray[np.int64]:
        """Index into the original x grid, per timestep, of the first valid chunk."""
if self._start_offset is None:
self._start_offset = self.set_start_offset()
return self._start_offset
def _get_time(self):
times = [sdf.tstamp for sdf in self.sdfs]
if any(t is None for t in times):
raise NoTimestampError("All SDFs require a tstamp that is not None.")
return np.asarray(times)
    def shock_index_from_nd(self, threshold: float = 0.37) -> npt.NDArray[np.int_]:
        """Shock index per timestep: first x cell where the y-median number
        density exceeds ``threshold`` times its maximum."""
        mask = self._nd_median_y_cm > (self._nd_median_y_cm.max() * threshold)
        shock_i = np.asarray(np.argmax(mask, axis=0))
        return shock_i
    def _get_dist_either_side(self) -> npt.NDArray[np.int_]:
        """Number of grid cells either side of the shock at each timestep.

        Returns:
            dist_either_side: Array of shape (time.size, 2) giving the number
                of cells between the shock and the left and right boundaries.
        """
        dist_either_side = np.asarray(
            [[s, abs(s - self.grid.x.size)] for s in self.shock_i]
        )
        return dist_either_side
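    # For example (illustrative numbers only): with a 1000-cell x grid and the
    # shock at cell 300 for some timestep, that row of ``dist_either_side`` is
    # [300, 700]. ``max_widths`` is the per-side maximum over all timesteps and
    # ``full_width`` is their sum: the width needed to hold every timestep once
    # the shocks are aligned.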
    def reshape_qty_to_shock_arr(
        self, qty: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        """Reshape ``qty`` so that the shock is aligned across timesteps.

        Note:
            - ``qty`` must have shape (grid.x.size, time.size), i.e. the first
              dimension is the width of the grid in x and the second is the
              number of timesteps.

        Args:
            qty (npt.NDArray[np.float64]): The quantity to be reshaped.

        Returns:
            Array of shape (full_width, time.size), padded with NaN where a
            timestep has no data.
        """
        required_shape = np.asarray([self.grid.x.size, self.time.size])
        if not np.array_equal(qty.shape, required_shape):
            raise ValueError(
                f"qty has shape {qty.shape} but required "
                f"shape is {tuple(required_shape)}"
            )
arr = np.empty((self.full_width, self.time.size))
arr[:] = np.nan
for i in range(self.time.size):
insert_i = self.max_widths[0] - self.dist_either_side[i][0]
arr[insert_i : insert_i + self.grid.x.size, i] = qty[:, i]
return arr
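    # Usage sketch (hedged; ``cs`` is an existing CenteredShock instance):
    #
    #     nd = np.stack(
    #         [np.median(s.numberdensity, axis=1) for s in cs.sdfs], axis=1
    #     )  # shape (grid.x.size, time.size)
    #     aligned = cs.reshape_qty_to_shock_arr(nd)  # (full_width, time.size)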
    def set_chunk_i_missing(self) -> dict[str, Any]:
        """Compute the chunk width and the cells left over on each side of the shock.

        Raises:
            NChunksNotSetError: If ``n_chunks`` has not been set.
        """
        if self.n_chunks is None:
            raise NChunksNotSetError()
        # chunk_grow needs to be 1 larger than the number of chunks.
        chunk_grow = self.n_chunks + 1
        chunk_i: int = np.floor(self.full_width / chunk_grow).astype(int)
ratio = self.max_widths / self.max_widths[1]
fraction = ratio * (chunk_grow / ratio.sum())
fraction = np.floor(fraction).astype(int)
missing: npt.NDArray[np.int_] = self.max_widths - fraction * chunk_i
self._chunk_i = chunk_i
self._missing = missing
self._downstream_start_chunk = fraction[0]
return dict(
chunk_i=chunk_i,
missing=missing,
downstream_start_chunk=fraction[0],
)
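    # Worked example (illustrative numbers only): with max_widths = [350, 650]
    # (full_width = 1000) and n_chunks = 9, chunk_grow = 10 and chunk_i = 100.
    # ratio = [350/650, 1] gives fraction = floor([3.5, 6.5]) = [3, 6]: three
    # chunks left of the shock and six to the right. The leftover cells are
    # missing = [350 - 300, 650 - 600] = [50, 50] and downstream_start_chunk = 3.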
    def set_valid_chunks(self) -> npt.NDArray[np.bool_]:
        """Boolean array (n_chunks, time.size); True where a chunk has complete data.

        Raises:
            NChunksNotSetError: If ``n_chunks`` has not been set.
        """
        if self.n_chunks is None:
            raise NChunksNotSetError()
        # Reshape an array of ones into the shock-aligned frame.
        arr = self.reshape_qty_to_shock_arr(np.ones((self.grid.x.size, self.time.size)))
        arr[np.isnan(arr)] = 0  # Areas with no data become 0 instead of NaN.
        arr = arr.astype(bool)  # True where there is data, False elsewhere.
valid_chunks = np.zeros((self.n_chunks, self.time.size), dtype=bool)
for i in range(self.time.size):
valid = np.empty(self.n_chunks, dtype=bool)
for j in range(self.n_chunks):
cstart = self.chunk_i * j + self.missing[0]
cend = cstart + self.chunk_i
if cend > arr.shape[0]:
raise Exception(
f"End index for chunk {j} at timestep {i} is {cend} which is "
f"larger than the maximum size of {arr.shape[0]}"
)
valid[j] = arr[cstart:cend, i].all()
valid_chunks[:, i] = valid
self._valid_chunks = valid_chunks
return valid_chunks
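    # For example, ``cs.valid_chunks[:, 5]`` is the boolean mask over chunks at
    # timestep 5, and ``cs.valid_chunks.sum(axis=1)`` counts, for each chunk,
    # the number of timesteps in which it holds complete data.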
    def set_start_offset(self) -> npt.NDArray[np.int64]:
        """Index into the original x grid, per timestep, of the first valid chunk."""
data_start = self.max_widths[0] - self.dist_either_side[:, 0]
first_chunk = (
np.argmax(self.valid_chunks, axis=0) * self.chunk_i + self.missing[0]
)
start_offset = (first_chunk - data_start).astype(np.int64)
self._start_offset = start_offset
return start_offset
    def get_x_offset_for_frame(self, chunk: int, t_idx: int) -> int:
        """Start index in the original x grid for ``chunk`` at timestep ``t_idx``."""
if not self.valid_chunks[chunk, t_idx]:
raise InvalidChunkError(chunk, t_idx)
        # Number the valid chunks at each timestep 0, 1, 2, ... from the left.
        data_chunk = (np.cumsum(self.valid_chunks, axis=0) - 1) * self.valid_chunks
return data_chunk[chunk, t_idx] * self.chunk_i + self.start_offset[t_idx]
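    # Illustration (one timestep's column): for valid_chunks = [F, T, T, F, T],
    # np.cumsum(...) - 1 gives [-1, 0, 1, 1, 2]; multiplying by the mask gives
    # [0, 0, 1, 0, 2], i.e. each valid chunk's rank among the valid chunks at
    # that timestep.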
    def get_qty_in_frame(
self,
qty_func: Callable[[SDF], npt.NDArray[np.float64]],
chunk: int,
t_idx: int,
) -> tuple[npt.NDArray[np.float64], tuple[int, int]]:
"""Extract from an SDF using qty_func at a chosen chunk and timestep.
Example:
>>> def get_nd(sdf: SDF) -> npt.NDArray[np.float64]:
... return sdf.numberdensity
>>> qty, start_stop = cs.get_qty_in_frame(get_qty, 0, 0)
>>> def get_median_bx(sdf: SDF) -> npt.NDArray[np.float64]:
... return np.median(sdf.mag.bx, axis=1)
>>> qty, start_stop = cs.get_qty_in_frame(get_median_bx, 0, 0)
>>> plt.plot(x=cs.grid.x[slice(*start_stop)], y=qty)
"""
if self.n_chunks is None:
raise NChunksNotSetError()
        if not 0 <= chunk < self.n_chunks:
            raise ValueError(f"chunk must be in range [0, {self.n_chunks}).")
chunk_offset = self.get_x_offset_for_frame(chunk, t_idx)
start_stop = (chunk_offset, chunk_offset + self.chunk_i)
return qty_func(self.sdfs[t_idx])[slice(*start_stop)], start_stop