From 3e288b1023e08ccf1b833111034abb3fad5a4abf Mon Sep 17 00:00:00 2001 From: Momchil Minkov Date: Wed, 12 Nov 2025 11:36:26 +0100 Subject: [PATCH] feat: Moving TYPE_TO_CLASS_MAP to Tidy3dBaseModel and allow Tidy3dBaseModel.from_file for various components --- CHANGELOG.md | 1 + tests/test_components/test_base.py | 19 +++ .../test_plugins/expressions/test_dispatch.py | 19 +++ tests/test_web/test_local_cache.py | 40 ++++- tests/test_web/test_tidy3d_stub.py | 24 +-- tests/test_web/test_webapi.py | 12 +- tidy3d/components/base.py | 106 ++++++++++-- tidy3d/plugins/expressions/base.py | 19 +-- tidy3d/web/api/tidy3d_stub.py | 151 +++--------------- 9 files changed, 209 insertions(+), 182 deletions(-) create mode 100644 tests/test_plugins/expressions/test_dispatch.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e9e7cd2363..44f1a125d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Introduced `BroadbandPulse` for exciting simulations across a wide frequency spectrum. - Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency. - Added warning if port mesh refinement is incompatible with the `GridSpec` in the `TerminalComponentModeler`. +- Various types, e.g. different `Simulation` or `SimulationData` sub-classes, can be loaded from file directly with `Tidy3dBaseModel.from_file()`. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. diff --git a/tests/test_components/test_base.py b/tests/test_components/test_base.py index 4118359044..b996025aa6 100644 --- a/tests/test_components/test_base.py +++ b/tests/test_components/test_base.py @@ -2,8 +2,11 @@ from __future__ import annotations +from typing import Literal + import numpy as np import pytest +from pydantic.v1 import ValidationError import tidy3d as td from tidy3d.components.base import Tidy3dBaseModel @@ -275,3 +278,19 @@ def test_updated_hash_and_json_with_changed_attr(): assert new_hash != old_hash assert json_old != json_new + + +def test_parse_obj_respects_subclasses(): + class DispatchBase(Tidy3dBaseModel): + type: Literal["DispatchBase"] = "DispatchBase" + value: int + + class DispatchChild(DispatchBase): + type: Literal["DispatchChild"] = "DispatchChild" + + data = {"type": "DispatchChild", "value": 1} + parsed = Tidy3dBaseModel._parse_model_dict(data) + assert isinstance(parsed, DispatchChild) + + with pytest.raises(ValidationError): + DispatchChild.parse_obj({"type": "DispatchBase", "value": 2}) diff --git a/tests/test_plugins/expressions/test_dispatch.py b/tests/test_plugins/expressions/test_dispatch.py new file mode 100644 index 0000000000..ee7fbec31c --- /dev/null +++ b/tests/test_plugins/expressions/test_dispatch.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import pytest + +from tidy3d.plugins.expressions.base import Expression +from tidy3d.plugins.expressions.variables import Constant + + +def test_expression_parse_obj_round_trip(): + expr = Constant(3.14) + parsed = Expression.parse_obj(expr.dict()) + assert isinstance(parsed, Constant) + assert parsed.value == pytest.approx(3.14) + + +def test_expression_parse_obj_rejects_unrelated_types(): + # Simulation registers a distinct type in the global map; parsing via Expression should fail. 
+ with pytest.raises(ValueError, match="Cannot parse type"): + Expression.parse_obj({"type": "Simulation"}) diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py index a3985000fd..0762fe08c7 100644 --- a/tests/test_web/test_local_cache.py +++ b/tests/test_web/test_local_cache.py @@ -26,7 +26,11 @@ from tidy3d.web.api import webapi as web from tidy3d.web.api.autograd import autograd, engine, io_utils from tidy3d.web.api.autograd.autograd import run as run_autograd -from tidy3d.web.api.autograd.constants import SIM_VJP_FILE +from tidy3d.web.api.autograd.constants import ( + AUX_KEY_SIM_DATA_FWD, + AUX_KEY_SIM_DATA_ORIGINAL, + SIM_VJP_FILE, +) from tidy3d.web.api.container import Batch, WebContainer from tidy3d.web.api.webapi import load_simulation_if_cached from tidy3d.web.cache import ( @@ -181,13 +185,37 @@ def _fake_download_file(resource_id, remote_filename, to_file=None, **kwargs): if str(remote_filename) == SIM_VJP_FILE: counters["download"] += 1 - def _fake_from_file(*args, **kwargs): - field_map = FieldMap(tracers=()) - return field_map + def _fake_postprocess_fwd(*, sim_data_combined=None, sim_original=None, aux_data=None, **_): + """Mimic ``autograd.postprocess_fwd`` side effects for tests.""" + if sim_original is None: + sim_original = next(iter(PATH_TO_SIM.values()), None) + if sim_original is None: + sim_original = td.Simulation( + size=(1, 1, 1), + grid_spec=td.GridSpec.auto(wavelength=1.0), + run_time=1e-12, + ) + stub_data = _FakeStubData(sim_original) + if aux_data is not None: + aux_data[AUX_KEY_SIM_DATA_ORIGINAL] = stub_data + aux_data[AUX_KEY_SIM_DATA_FWD] = stub_data + return stub_data._strip_traced_fields() + + def _fake_postprocess_adj( + sim_data_adj=None, sim_data_orig=None, sim_data_fwd=None, sim_fields_keys=None, **_ + ): + """Return zeros for every requested field key.""" + counters["download"] += 1 # mimic VJP file download per autograd run + sim_fields_keys = sim_fields_keys or [] + return dict.fromkeys(sim_fields_keys, 0.0) + + def _fake_field_map_from_file(*args, **kwargs): + return FieldMap(tracers=()) monkeypatch.setattr(io_utils, "download_file", _fake_download_file) - monkeypatch.setattr(autograd, "postprocess_fwd", _fake_from_file) - monkeypatch.setattr(FieldMap, "from_file", _fake_from_file) + monkeypatch.setattr(autograd, "postprocess_fwd", _fake_postprocess_fwd) + monkeypatch.setattr(autograd, "postprocess_adj", _fake_postprocess_adj) + monkeypatch.setattr(FieldMap, "from_file", _fake_field_map_from_file) monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder) monkeypatch.setattr(web, "upload", _fake_upload) monkeypatch.setattr(web, "start", _fake_start) diff --git a/tests/test_web/test_tidy3d_stub.py b/tests/test_web/test_tidy3d_stub.py index b51d5d519e..9d0f283ab7 100644 --- a/tests/test_web/test_tidy3d_stub.py +++ b/tests/test_web/test_tidy3d_stub.py @@ -42,6 +42,15 @@ def make_sim(): ) +def is_lazy_object(data): + assert set(data.__dict__.keys()) == { + "_lazy_fname", + "_lazy_group_path", + "_lazy_parse_obj_kwargs", + } + return True + + def make_sim_data(file_size_gb=0.001): """Makes a simulation data.""" N = int(2.528e8 / 4 * file_size_gb) @@ -146,22 +155,16 @@ def test_stub_data_lazy_loading(tmp_path): sim_data = Tidy3dStubData.postprocess(file_path, lazy=True) sim_data_copy = sim_data.copy() - assert type(sim_data).__name__ == "SimulationDataProxy" - assert type(sim_data_copy).__name__ == "SimulationDataProxy" + + # variable dict should only contain metadata to load the data, not the data itself + 
assert is_lazy_object(sim_data) # the type should be still SimulationData despite being lazy assert isinstance(sim_data, SimulationData) - # variable dict should only contain metadata to load the data, not the data itself - assert set(sim_data.__dict__.keys()) == { - "_lazy_fname", - "_lazy_group_path", - "_lazy_parse_obj_kwargs", - } - # we expect a warning from the lazy object if some field is accessed with AssertLogLevel("WARNING", contains_str=sim_diverged_log): - _ = sim_data.monitor_data + _ = sim_data_copy.monitor_data @pytest.mark.parametrize( @@ -185,6 +188,7 @@ def test_stub_pathlike_roundtrip(tmp_path, path_builder): # Simulation data stub roundtrip sim_data = make_sim_data() + sim_data = sim_data.updated_copy(log="log") stub_data = Tidy3dStubData(data=sim_data) data_path = path_builder(tmp_path, "pathlike_data.hdf5") stub_data.to_file(data_path) diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index 48ee90f4bf..074c44b069 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -15,6 +15,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_tidy3d_stub import is_lazy_object from tidy3d import Simulation from tidy3d.__main__ import main from tidy3d.components.data.data_array import ScalarFieldDataArray @@ -958,21 +959,18 @@ def test_run_with_flexible_containers_offline_lazy(monkeypatch, tmp_path): apply_common_patches(monkeypatch, tmp_path, taskid_to_sim=taskid_to_sim) data = run(sim_container, task_name=task_name, folder_name="PROJECT", path=str(out_dir)) - + assert is_lazy_object(data[0]) assert isinstance(data, list) and len(data) == 3 - assert isinstance(data[0], SimulationData) - assert data[0].__class__.__name__ == "SimulationDataProxy" - assert isinstance(data[1], dict) assert "sim2" in data[1] + assert is_lazy_object(data[1]["sim2"]) assert isinstance(data[1]["sim2"], SimulationData) - assert data[1]["sim2"].__class__.__name__ == "SimulationDataProxy" + assert is_lazy_object(data[2][0]) assert isinstance(data[2], tuple) - assert data[2][0].__class__.__name__ == "SimulationDataProxy" + assert is_lazy_object(data[2][1][0]) assert isinstance(data[2][1], list) - assert data[2][1][0].__class__.__name__ == "SimulationDataProxy" assert data[0].simulation == sim1 assert data[1]["sim2"].simulation == sim2 diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index 2ad302040a..cc7a8d91c5 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -42,6 +42,7 @@ MAX_STRING_LENGTH = 1_000_000_000 FORBID_SPECIAL_CHARACTERS = ["/"] TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__" +TYPE_TO_CLASS_MAP: dict[str, type[Tidy3dBaseModel]] = {} def cache(prop): @@ -195,6 +196,76 @@ def __init_subclass__(cls) -> None: cls.add_type_field() cls.generate_docstring() + type_value = cls.__fields__.get(TYPE_TAG_STR) + if type_value and type_value.default: + TYPE_TO_CLASS_MAP[type_value.default] = cls + + @classmethod + def _get_type_value(cls, obj: dict[str, Any]) -> str: + """Return the type tag from a raw dictionary.""" + if not isinstance(obj, dict): + raise TypeError("Input must be a dict") + try: + type_value = obj[TYPE_TAG_STR] + except KeyError as exc: + raise ValueError(f'Missing "{TYPE_TAG_STR}" in data') from exc + if not isinstance(type_value, str) or not type_value: + raise ValueError(f'Invalid "{TYPE_TAG_STR}" value: {type_value!r}') + return type_value + + @classmethod + def _get_registered_class(cls, type_value: str) -> type[Tidy3dBaseModel]: + try: + return 
TYPE_TO_CLASS_MAP[type_value] + except KeyError as exc: + raise ValueError(f"Unknown type: {type_value}") from exc + + @classmethod + def _should_dispatch_to(cls, target_cls: type[Tidy3dBaseModel]) -> bool: + """Return True if ``cls`` allows auto-dispatch to ``target_cls``.""" + return issubclass(target_cls, cls) + + @classmethod + def _resolve_dispatch_target(cls, obj: dict[str, Any]) -> type[Tidy3dBaseModel]: + """Determine which subclass should receive ``obj``.""" + type_value = cls._get_type_value(obj) + target_cls = cls._get_registered_class(type_value) + if cls._should_dispatch_to(target_cls): + return target_cls + if target_cls is cls: + return cls + raise ValueError( + f'Cannot parse type "{type_value}" using {cls.__name__}; expected subclass of {cls.__name__}.' + ) + + @classmethod + def _target_cls_from_file( + cls, fname: PathLike, group_path: Optional[str] = None + ) -> type[Tidy3dBaseModel]: + """Peek the file metadata to determine the subclass to instantiate.""" + model_dict = cls.dict_from_file( + fname=fname, + group_path=group_path, + load_data_arrays=False, + ) + return cls._resolve_dispatch_target(model_dict) + + @classmethod + def _parse_obj(cls, obj: dict[str, Any], **parse_obj_kwargs: Any) -> Tidy3dBaseModel: + """Dispatch ``obj`` to the correct subclass registered in the type map.""" + target_cls = cls._resolve_dispatch_target(obj) + if target_cls is cls: + return super().parse_obj(obj, **parse_obj_kwargs) + return target_cls.parse_obj(obj, **parse_obj_kwargs) + + @classmethod + def _parse_model_dict( + cls, model_dict: dict[str, Any], **parse_obj_kwargs: Any + ) -> Tidy3dBaseModel: + """Parse ``model_dict`` while optionally auto-dispatching when called on the base class.""" + if cls is Tidy3dBaseModel: + return cls._parse_obj(model_dict, **parse_obj_kwargs) + return cls.parse_obj(model_dict, **parse_obj_kwargs) class Config: """Sets config for all :class:`Tidy3dBaseModel` objects. @@ -404,16 +475,19 @@ def from_file( >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP """ if lazy: - Proxy = _make_lazy_proxy(cls, on_load=on_load) # staticmethod usage + target_cls = cls._target_cls_from_file(fname=fname, group_path=group_path) + Proxy = _make_lazy_proxy(target_cls, on_load=on_load) return Proxy(fname, group_path, parse_obj_kwargs) model_dict = cls.dict_from_file(fname=fname, group_path=group_path) - obj = cls.parse_obj(model_dict, **parse_obj_kwargs) + obj = cls._parse_model_dict(model_dict, **parse_obj_kwargs) if not lazy and on_load is not None: on_load(obj) return obj @classmethod - def dict_from_file(cls, fname: PathLike, group_path: Optional[str] = None) -> dict: + def dict_from_file( + cls, fname: PathLike, group_path: Optional[str] = None, *, load_data_arrays: bool = True + ) -> dict: """Loads a dictionary containing the model from a .yaml, .json, .hdf5, or .hdf5.gz file. 
Parameters @@ -437,11 +511,14 @@ def dict_from_file(cls, fname: PathLike, group_path: Optional[str] = None) -> di kwargs = {"fname": fname_path} if group_path is not None: - if extension == ".hdf5" or extension == ".hdf5.gz": + if extension in {".hdf5", ".hdf5.gz", ".h5"}: kwargs["group_path"] = group_path else: log.warning("'group_path' provided, but this feature only works with hdf5 files.") + if extension in {".hdf5", ".hdf5.gz", ".h5"}: + kwargs["load_data_arrays"] = load_data_arrays + converter = { ".json": cls.dict_from_json, ".yaml": cls.dict_from_yaml, @@ -493,7 +570,7 @@ def from_json(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self: >>> simulation = Simulation.from_json(fname='folder/sim.json') # doctest: +SKIP """ model_dict = cls.dict_from_json(fname=fname) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls._parse_model_dict(model_dict, **parse_obj_kwargs) @classmethod def dict_from_json(cls, fname: PathLike) -> dict: @@ -558,7 +635,7 @@ def from_yaml(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self: >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP """ model_dict = cls.dict_from_yaml(fname=fname) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls._parse_model_dict(model_dict, **parse_obj_kwargs) @classmethod def dict_from_yaml(cls, fname: PathLike) -> dict: @@ -679,6 +756,7 @@ def dict_from_hdf5( fname: PathLike, group_path: str = "", custom_decoders: Optional[list[Callable]] = None, + load_data_arrays: bool = True, ) -> dict: """Loads a dictionary containing the model contents from a .hdf5 file. @@ -752,7 +830,8 @@ def load_data_from_file(model_dict: dict, group_path: str = "") -> None: model_dict = json.loads(cls._json_string_from_hdf5(fname=fname_path)) group_path = cls._construct_group_path(group_path) model_dict = cls.get_sub_model(group_path=group_path, model_dict=model_dict) - load_data_from_file(model_dict=model_dict, group_path=group_path) + if load_data_arrays: + load_data_from_file(model_dict=model_dict, group_path=group_path) return model_dict @classmethod @@ -790,7 +869,7 @@ def from_hdf5( group_path=group_path, custom_decoders=custom_decoders, ) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls._parse_model_dict(model_dict, **parse_obj_kwargs) def to_hdf5( self, @@ -861,6 +940,7 @@ def dict_from_hdf5_gz( fname: PathLike, group_path: str = "", custom_decoders: Optional[list[Callable]] = None, + load_data_arrays: bool = True, ) -> dict: """Loads a dictionary containing the model contents from a .hdf5.gz file. 
@@ -893,6 +973,7 @@ def dict_from_hdf5_gz( extracted_path, group_path=group_path, custom_decoders=custom_decoders, + load_data_arrays=load_data_arrays, ) finally: extracted_path.unlink(missing_ok=True) @@ -934,7 +1015,7 @@ def from_hdf5_gz( group_path=group_path, custom_decoders=custom_decoders, ) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls._parse_model_dict(model_dict, **parse_obj_kwargs) def to_hdf5_gz( self, fname: PathLike | io.BytesIO, custom_encoders: Optional[list[Callable]] = None @@ -1148,7 +1229,7 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict) -> None: for path, value in field_mapping.items(): insert_value(value, path=path, sub_dict=self_dict) - return self.parse_obj(self_dict) + return type(self)._parse_model_dict(self_dict) def _serialized_traced_field_keys( self, field_mapping: AutogradFieldMap | None = None @@ -1360,6 +1441,7 @@ def _make_lazy_proxy( A class named ``Proxy`` with init args: ``(fname, group_path, parse_obj_kwargs)``. """ + proxy_name = f"{target_cls.__name__}Proxy" class _LazyProxy(target_cls): @@ -1398,11 +1480,11 @@ def __getattribute__(self, name: str): kwargs = d["_lazy_parse_obj_kwargs"] model_dict = target_cls.dict_from_file(fname=fname, group_path=group_path) - target = target_cls.parse_obj(model_dict, **kwargs) + target = target_cls._parse_model_dict(model_dict, **kwargs) d.clear() d.update(target.__dict__) - object.__setattr__(self, "__class__", target_cls) + object.__setattr__(self, "__class__", target.__class__) object.__setattr__(self, "__fields_set__", set(target.__fields_set__)) private_attrs = getattr(target, "__private_attributes__", {}) or {} for attr_name in private_attrs: diff --git a/tidy3d/plugins/expressions/base.py b/tidy3d/plugins/expressions/base.py index ff52b648ef..a3c1ed4b3d 100644 --- a/tidy3d/plugins/expressions/base.py +++ b/tidy3d/plugins/expressions/base.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import TYPE_TAG_STR from .types import ExpressionType, NumberOrExpression, NumberType @@ -23,8 +22,6 @@ Subtract, ) -TYPE_TO_CLASS_MAP: dict[str, Any] = {} - class Expression(Tidy3dBaseModel, ABC): """ @@ -44,23 +41,9 @@ def evaluate(self, *args: Any, **kwargs: Any) -> NumberType: def __call__(self, *args: Any, **kwargs: Any) -> NumberType: return self.evaluate(*args, **kwargs) - def __init_subclass__(cls, **kwargs: dict[str, Any]) -> None: - super().__init_subclass__(**kwargs) - type_value = cls.__fields__.get(TYPE_TAG_STR) - if type_value and type_value.default: - TYPE_TO_CLASS_MAP[type_value.default] = cls - @classmethod def parse_obj(cls, obj: dict[str, Any]) -> ExpressionType: - if not isinstance(obj, dict): - raise TypeError("Input must be a dict") - type_value = obj.get(TYPE_TAG_STR) - if type_value is None: - raise ValueError('Missing "type" in data') - subclass = TYPE_TO_CLASS_MAP.get(type_value) - if subclass is None: - raise ValueError(f"Unknown type: {type_value}") - return subclass(**obj) + return super()._parse_obj(obj) def filter( self, target_type: type[Expression], target_field: Optional[str] = None diff --git a/tidy3d/web/api/tidy3d_stub.py b/tidy3d/web/api/tidy3d_stub.py index fc979a1e2d..a55bafe6f9 100644 --- a/tidy3d/web/api/tidy3d_stub.py +++ b/tidy3d/web/api/tidy3d_stub.py @@ -2,30 +2,22 @@ from __future__ import annotations -import json from datetime import datetime from os import PathLike -from pathlib import Path from typing import Callable, Optional import pydantic.v1 
as pd from pydantic.v1 import BaseModel from tidy3d import log -from tidy3d.components.base import _get_valid_extension +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.monitor_data import ModeSolverData from tidy3d.components.data.sim_data import SimulationData -from tidy3d.components.eme.data.sim_data import EMESimulationData from tidy3d.components.eme.simulation import EMESimulation from tidy3d.components.microwave.data.monitor_data import MicrowaveModeSolverData from tidy3d.components.mode.data.sim_data import ModeSimulationData from tidy3d.components.mode.simulation import ModeSimulation from tidy3d.components.simulation import Simulation -from tidy3d.components.tcad.data.sim_data import ( - HeatChargeSimulationData, - HeatSimulationData, - VolumeMesherData, -) from tidy3d.components.tcad.mesher import VolumeMesher from tidy3d.components.tcad.simulation.heat import HeatSimulation from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation @@ -43,11 +35,6 @@ from tidy3d.plugins.smatrix.data.terminal import ( TerminalComponentModelerData, ) -from tidy3d.web.core.file_util import ( - read_simulation_from_hdf5, - read_simulation_from_hdf5_gz, - read_simulation_from_json, -) from tidy3d.web.core.stub import TaskStub, TaskStubData from tidy3d.web.core.types import TaskType @@ -76,93 +63,41 @@ class Tidy3dStub(BaseModel, TaskStub): @classmethod def from_file(cls, file_path: PathLike) -> WorkflowType: - """Loads a Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] - from .yaml, .json, or .hdf5 file. + """Loads a ``WorkflowType`` instance from .yaml, .json, or .hdf5 file. Parameters ---------- file_path : PathLike Full path to the .yaml or .json or .hdf5 file to load the - Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] from. + ``WorkflowType`` from. Returns ------- - Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] + WorkflowType An instance of the component class calling ``load``. - - Example - ------- - >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP """ - path = Path(file_path) - extension = _get_valid_extension(path) - if extension == ".json": - json_str = read_simulation_from_json(path) - elif extension == ".hdf5": - json_str = read_simulation_from_hdf5(path) - elif extension == ".hdf5.gz": - json_str = read_simulation_from_hdf5_gz(path) - - data = json.loads(json_str) - type_ = data["type"] - - supported_classes = [ - Simulation, - ModeSolver, - HeatSimulation, - HeatChargeSimulation, - EMESimulation, - ModeSimulation, - VolumeMesher, - ModalComponentModeler, - TerminalComponentModeler, - ] - - class_map = {cls.__name__: cls for cls in supported_classes} - - if type_ not in class_map: - raise ValueError( - f"Unsupported type '{type_}'. Supported types: {list(class_map.keys())}" - ) - - sim_class = class_map[type_] - sim = sim_class.from_file(path) - - return sim + return Tidy3dBaseModel.from_file(file_path) def to_file( self, file_path: PathLike, ) -> None: - """Exports Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] instance to .yaml, .json, - or .hdf5 file + """Exports ``WorkflowType`` instance to .yaml, .json, or .hdf5 file Parameters ---------- file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to save the :class:`Stub` to. 
- - Example - ------- - >>> simulation.to_file(fname='folder/sim.json') # doctest: +SKIP + Full path to the .yaml or .json or .hdf5 file to save the ``WorkflowType`` to. """ self.simulation.to_file(file_path) - def to_hdf5_gz(self, fname: PathLike, custom_encoders: Optional[list[Callable]] = None) -> None: - """Exports Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] instance to .hdf5.gz file. + def to_hdf5_gz(self, fname: PathLike) -> None: + """Exports ``WorkflowType`` instance to .hdf5.gz file. Parameters ---------- fname : PathLike - Full path to the .hdf5.gz file to save - the Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] to. - custom_encoders : List[Callable] - List of functions accepting (fname: PathLike, group_path: str, value: Any) that take - the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. - - Example - ------- - >>> simulation.to_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP + Full path to the .hdf5.gz file to save the ``WorkflowType`` to. """ self.simulation.to_hdf5_gz(fname) @@ -215,14 +150,13 @@ class Tidy3dStubData(BaseModel, TaskStubData): def from_file( cls, file_path: PathLike, lazy: bool = False, on_load: Optional[Callable] = None ) -> WorkflowDataType: - """Loads a Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] - from .yaml, .json, or .hdf5 file. + """Loads a ``WorkflowDataType`` instance from .yaml, .json, or .hdf5 file. Parameters ---------- file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to load the - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] from. + Full path to the .yaml or .json or .hdf5 file to load the ``WorkflowDataType`` instance + from. lazy : bool = False Whether to load the actual data (``lazy=False``) or return a proxy that loads the data when accessed (``lazy=True``). @@ -234,85 +168,44 @@ def from_file( Returns ------- - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] + ``WorkflowDataType`` instance An instance of the component class calling ``load``. """ - path = Path(file_path) - extension = _get_valid_extension(path) - if extension == ".json": - json_str = read_simulation_from_json(path) - elif extension == ".hdf5": - json_str = read_simulation_from_hdf5(path) - elif extension == ".hdf5.gz": - json_str = read_simulation_from_hdf5_gz(path) - - data = json.loads(json_str) - type_ = data["type"] - - supported_data_classes = [ - SimulationData, - ModeSolverData, - MicrowaveModeSolverData, - HeatSimulationData, - HeatChargeSimulationData, - EMESimulationData, - ModeSimulationData, - VolumeMesherData, - ModalComponentModelerData, - TerminalComponentModelerData, - ] - - data_class_map = {cls.__name__: cls for cls in supported_data_classes} - - if type_ not in data_class_map: - raise ValueError( - f"Unsupported data type '{type_}'. 
Supported types: {list(data_class_map.keys())}" - ) - - data_class = data_class_map[type_] - sim_data = data_class.from_file(path, lazy=lazy, on_load=on_load) - - return sim_data + return Tidy3dBaseModel.from_file(file_path, lazy=lazy, on_load=on_load) def to_file(self, file_path: PathLike) -> None: - """Exports Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] instance + """Exports ``WorkflowDataType`` instance to .yaml, .json, or .hdf5 file Parameters ---------- file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to save the - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] to. - - Example - ------- - >>> simulation.to_file(fname='folder/sim.json') # doctest: +SKIP + Full path to the .yaml or .json or .hdf5 file to save the ``WorkflowDataType`` instance to. """ self.data.to_file(file_path) @classmethod def postprocess(cls, file_path: PathLike, lazy: bool = True) -> WorkflowDataType: """Load .yaml, .json, or .hdf5 file to - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] instance. + ``WorkflowDataType`` instance. Parameters ---------- file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to save the - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] to. + Full path to the .yaml or .json or .hdf5 file to save the ``WorkflowDataType`` instance to. lazy : bool = False Whether to load the actual data (``lazy=False``) or return a proxy that loads the data when accessed (``lazy=True``). Returns ------- - Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`] + ``WorkflowDataType`` instance An instance of the component class calling ``load``. """ - stub_data = Tidy3dStubData.from_file( + workflow_data = Tidy3dBaseModel.from_file( file_path, lazy=lazy, on_load=cls._check_convergence_and_warnings ) - return stub_data + return workflow_data @staticmethod def _check_convergence_and_warnings(stub_data: WorkflowDataType) -> None:
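
A minimal usage sketch of the dispatch behavior this patch introduces (not part of the diff; the file path and simulation parameters are placeholders, mirroring the small fixtures used in the tests above): loading a saved model through the base class now resolves the concrete subclass from the stored "type" tag via TYPE_TO_CLASS_MAP.

    import tidy3d as td
    from tidy3d.components.base import Tidy3dBaseModel

    # Placeholder simulation, similar to the small fixtures in the tests above.
    sim = td.Simulation(
        size=(1.0, 1.0, 1.0),
        grid_spec=td.GridSpec.auto(wavelength=1.0),
        run_time=1e-12,
    )
    sim.to_file("sim.hdf5")  # "sim.hdf5" is a placeholder path

    # The base-class loader peeks at the stored "type" tag and dispatches to the
    # subclass registered in TYPE_TO_CLASS_MAP, so the concrete type is recovered.
    loaded = Tidy3dBaseModel.from_file("sim.hdf5")
    assert isinstance(loaded, td.Simulation)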