# Source code for swc.aeon.schema.base
"""Base classes for defining experiment configuration and data models."""
import datetime
import os
from collections.abc import Callable
from functools import cached_property
from pathlib import Path
from typing import Self, TypeVar
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
from pydantic.alias_generators import to_camel, to_pascal
from swc.aeon.io.reader import Reader
[docs]
class BaseSchema(BaseModel):
    """The base class for all experiment configuration and data models."""

    # Shared pydantic settings: camelCase aliases, PascalCase field titles,
    # population by field name, and construction from object attributes.
    model_config = ConfigDict(
        from_attributes=True,
        populate_by_name=True,
        arbitrary_types_allowed=True,
        alias_generator=to_camel,
        field_title_generator=lambda field_name, _info: to_pascal(field_name),
    )

    # Prefix assigned by the containing model's validator: the PascalCase field
    # name for a direct field, or the dict key for a dict entry.
    _container_prefix: str = ""
    # The model holding this one; stays None until a container's validator sets it.
    _container: "BaseSchema | None" = None

    def _join_pattern_prefix(self, pattern_prefix: str) -> str:
        """Return the prefix to use for a model nested inside this one.

        The base implementation discards ``pattern_prefix`` and answers with
        this model's own container prefix.
        """
        return self._container_prefix

    def _resolve_pattern_prefix(self) -> str:
        """Combine this model's prefix with those of all enclosing containers.

        Starts from this model's own container prefix and lets each container,
        innermost first, transform it through ``_join_pattern_prefix``.
        """
        resolved = self._container_prefix
        node = self
        while (node := node._container) is not None:
            resolved = node._join_pattern_prefix(resolved)
        return resolved

    @model_validator(mode="after")
    def _validate_container_prefix(self) -> Self:
        """Register this model as the container of its nested schema models.

        Direct ``BaseSchema`` fields receive the PascalCase field name as their
        prefix; ``BaseSchema`` values inside dict fields receive their dict key.
        """
        for field_name in type(self).model_fields:
            value = getattr(self, field_name)
            if isinstance(value, dict):
                # Dict entries are prefixed with their key, used verbatim.
                for key, entry in value.items():
                    if not isinstance(entry, BaseSchema):
                        continue
                    entry._container_prefix = key
                    entry._container = self
                continue
            if isinstance(value, BaseSchema):
                # Direct fields are prefixed with the PascalCase field name.
                value._container_prefix = to_pascal(field_name)
                value._container = self
        return self
[docs]
class Experiment(BaseSchema):
    """The base class for creating experiment models."""

    # Location of the workflow that runs the experiment.
    workflow: str = Field(
        description="Path to the workflow running the experiment.",
    )

    # Commit hash of the experiment repository.
    commit: str = Field(
        description="Commit hash of the experiment repo.",
    )

    # Git repository holding the versioned experiment source code.
    repository_url: str = Field(
        description="The URL of the git repository used to version experiment source code.",
    )
[docs]
class Dataset(BaseSchema):
    """The base class for creating dataset models."""

    def _join_pattern_prefix(self, pattern_prefix: str) -> str:
        """Nest ``pattern_prefix`` beneath this dataset's own container prefix.

        Unlike the base implementation, the incoming prefix is kept and joined
        onto this model's prefix as a path component.
        """
        own_prefix = self._container_prefix
        return os.path.join(own_prefix, pattern_prefix)
# Generic placeholder for any `BaseSchema` subclass (not referenced in the
# visible part of this module; presumably used by importers — TODO confirm).
ModelT = TypeVar("ModelT", bound=BaseSchema)
# Type of the model instance that a `data_reader`-decorated method is bound to.
_SelfBaseSchema = TypeVar("_SelfBaseSchema", bound=BaseSchema)
# Concrete `Reader` subclass returned by a `data_reader`-decorated method.
_ReaderT = TypeVar("_ReaderT", bound=Reader)
[docs]
def data_reader(func: Callable[[_SelfBaseSchema, str], _ReaderT]) -> cached_property[_ReaderT]:
    """Decorator to include a data reader as `cached_property` in experiment dataset models.

    Args:
        func: Method taking the model instance and its resolved pattern prefix,
            and returning the reader for that prefix.

    Returns:
        A `cached_property` that, on first access, resolves the instance's
        pattern prefix via ``_resolve_pattern_prefix``, passes it to ``func``
        and caches the returned reader on the instance.
    """
    # Imported locally because only `cached_property` is imported at module level.
    from functools import wraps

    # `wraps` keeps the reader method's name and docstring, so the resulting
    # property documents itself instead of exposing an anonymous inner function.
    @wraps(func)
    def reader_getter(self: _SelfBaseSchema) -> _ReaderT:
        pattern_prefix = self._resolve_pattern_prefix()  # pyright: ignore[reportPrivateUsage]
        return func(self, pattern_prefix)

    return cached_property(reader_getter)