Source code for hybrid_jp.config
import tomllib
from pathlib import Path
from typing import Any, Dict
from pydantic import BaseModel, DirectoryPath, NonNegativeInt, RootModel
class ShockParams(BaseModel):
    """Validated parameter set describing a single shock entry.

    Every field is required; pydantic validates the types on construction.
    """

    # Identifier of the shock.
    name: str
    # Must point at a directory that already exists (pydantic DirectoryPath).
    test_data_dir: DirectoryPath
    # Plain path; unlike test_data_dir, its existence is not validated.
    data_dir: Path
    # Integer >= 0. Presumably the number of chunks the work is split into
    # -- TODO confirm against callers.
    n_chunks: NonNegativeInt
    # Integer >= 0. Presumably the worker thread count -- TODO confirm.
    n_threads: NonNegativeInt
    # Integers >= 0. Presumably the first/last SDF file indices to process
    # -- TODO confirm (including whether stop_sdf is inclusive).
    start_sdf: NonNegativeInt
    stop_sdf: NonNegativeInt
class Shocks(RootModel):
    """Mapping of shock name to :class:`ShockParams`.

    The mapping is stored in ``root`` and exposed three ways: iteration
    (yields the shock names), item access (``shocks["name"]``) and attribute
    access (``shocks.name``).
    """

    root: Dict[str, ShockParams]

    def __iter__(self):
        """Iterate over the shock names (the keys of ``root``)."""
        return iter(self.root)

    def __contains__(self, item):
        """Return True if ``item`` is the name of a configured shock."""
        return item in self.root

    def __getitem__(self, item):
        """Return the parameters of the shock called ``item``.

        Raises:
            KeyError: If no shock with that name exists.
        """
        return self.root[item]

    def __getattr__(self, attr):
        """Resolve ``shocks.<name>`` to the parameters of that shock.

        Only called when normal attribute lookup fails. Names that are not
        shocks are delegated to pydantic's own ``__getattr__`` so that private
        attributes keep working and unknown names raise ``AttributeError``
        (as ``hasattr``/``getattr`` with a default expect) rather than
        ``KeyError``.
        """
        # Go through __dict__ instead of self.root: if root has not been set
        # yet (e.g. while an instance is being copied or unpickled),
        # self.root would re-enter __getattr__ and recurse forever.
        shocks = self.__dict__.get("root")
        if shocks is not None and attr in shocks:
            return shocks[attr]
        return super().__getattr__(attr)
class Config(BaseModel):
    """Top-level configuration: the available shocks and the selected one."""

    # All configured shocks, addressable by name.
    shocks: Shocks
    # Name of the shock that ``shk`` resolves to.
    use_shock: str

    @property
    def shk(self) -> ShockParams:
        """Return the parameters of the shock named by ``use_shock``.

        Raises:
            KeyError: If ``use_shock`` is not one of the names in ``shocks``.
        """
        selected = self.use_shock
        if selected in self.shocks:
            return self.shocks[selected]
        raise KeyError(f"use_shock={selected} is not a valid shock.")
[docs]class TOMLConfigLoader:
def __init__(self, config_toml_path: str | Path) -> None:
if not isinstance(config_toml_path, Path):
self.config_toml_path = Path(config_toml_path)
else:
self.config_toml_path = config_toml_path
@staticmethod
def _broadcast_defaults(attrs: dict[str, Any]) -> dict[str, Any]:
data = attrs.copy()
defaults = data.pop("default")
for shock_name, shock_params in data.copy().items():
data[shock_name]["name"] = shock_name
for parameter_name, parameter_value in defaults.items():
if parameter_name not in shock_params:
data[shock_name][parameter_name] = parameter_value
return data
[docs] def load(self) -> dict[str, Any]:
with open(self.config_toml_path, "rb") as file:
data = tomllib.load(file)
data["shocks"] = self._broadcast_defaults(data["shocks"])
return data
def config_from_toml(toml_path: Path | str) -> Config:
    """Build a validated :class:`Config` from the TOML file at ``toml_path``.

    The file is read through :class:`TOMLConfigLoader` and the resulting
    dictionary is passed to ``Config`` as keyword arguments.
    """
    raw_settings = TOMLConfigLoader(toml_path).load()
    return Config(**raw_settings)