Skip to content

Commit 3e288b1

Browse files
committed
feat: Moving TYPE_TO_CLASS_MAP to Tidy3dBaseModel and allow Tidy3dBaseModel.from_file for various components
1 parent f84536b commit 3e288b1

File tree

9 files changed

+209
-182
lines changed

9 files changed

+209
-182
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3737
- Introduced `BroadbandPulse` for exciting simulations across a wide frequency spectrum.
3838
- Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency.
3939
- Added warning if port mesh refinement is incompatible with the `GridSpec` in the `TerminalComponentModeler`.
40+
- Various types, e.g. different `Simulation` or `SimulationData` sub-classes, can be loaded from file directly with `Tidy3dBaseModel.from_file()`.
4041

4142
### Breaking Changes
4243
- Edge singularity correction at PEC and lossy metal edges defaults to `True`.

tests/test_components/test_base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
from __future__ import annotations
44

5+
from typing import Literal
6+
57
import numpy as np
68
import pytest
9+
from pydantic.v1 import ValidationError
710

811
import tidy3d as td
912
from tidy3d.components.base import Tidy3dBaseModel
@@ -275,3 +278,19 @@ def test_updated_hash_and_json_with_changed_attr():
275278

276279
assert new_hash != old_hash
277280
assert json_old != json_new
281+
282+
283+
def test_parse_obj_respects_subclasses():
284+
class DispatchBase(Tidy3dBaseModel):
285+
type: Literal["DispatchBase"] = "DispatchBase"
286+
value: int
287+
288+
class DispatchChild(DispatchBase):
289+
type: Literal["DispatchChild"] = "DispatchChild"
290+
291+
data = {"type": "DispatchChild", "value": 1}
292+
parsed = Tidy3dBaseModel._parse_model_dict(data)
293+
assert isinstance(parsed, DispatchChild)
294+
295+
with pytest.raises(ValidationError):
296+
DispatchChild.parse_obj({"type": "DispatchBase", "value": 2})
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from tidy3d.plugins.expressions.base import Expression
6+
from tidy3d.plugins.expressions.variables import Constant
7+
8+
9+
def test_expression_parse_obj_round_trip():
10+
expr = Constant(3.14)
11+
parsed = Expression.parse_obj(expr.dict())
12+
assert isinstance(parsed, Constant)
13+
assert parsed.value == pytest.approx(3.14)
14+
15+
16+
def test_expression_parse_obj_rejects_unrelated_types():
17+
# Simulation registers a distinct type in the global map; parsing via Expression should fail.
18+
with pytest.raises(ValueError, match="Cannot parse type"):
19+
Expression.parse_obj({"type": "Simulation"})

tests/test_web/test_local_cache.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from tidy3d.web.api import webapi as web
2727
from tidy3d.web.api.autograd import autograd, engine, io_utils
2828
from tidy3d.web.api.autograd.autograd import run as run_autograd
29-
from tidy3d.web.api.autograd.constants import SIM_VJP_FILE
29+
from tidy3d.web.api.autograd.constants import (
30+
AUX_KEY_SIM_DATA_FWD,
31+
AUX_KEY_SIM_DATA_ORIGINAL,
32+
SIM_VJP_FILE,
33+
)
3034
from tidy3d.web.api.container import Batch, WebContainer
3135
from tidy3d.web.api.webapi import load_simulation_if_cached
3236
from tidy3d.web.cache import (
@@ -181,13 +185,37 @@ def _fake_download_file(resource_id, remote_filename, to_file=None, **kwargs):
181185
if str(remote_filename) == SIM_VJP_FILE:
182186
counters["download"] += 1
183187

184-
def _fake_from_file(*args, **kwargs):
185-
field_map = FieldMap(tracers=())
186-
return field_map
188+
def _fake_postprocess_fwd(*, sim_data_combined=None, sim_original=None, aux_data=None, **_):
189+
"""Mimic ``autograd.postprocess_fwd`` side effects for tests."""
190+
if sim_original is None:
191+
sim_original = next(iter(PATH_TO_SIM.values()), None)
192+
if sim_original is None:
193+
sim_original = td.Simulation(
194+
size=(1, 1, 1),
195+
grid_spec=td.GridSpec.auto(wavelength=1.0),
196+
run_time=1e-12,
197+
)
198+
stub_data = _FakeStubData(sim_original)
199+
if aux_data is not None:
200+
aux_data[AUX_KEY_SIM_DATA_ORIGINAL] = stub_data
201+
aux_data[AUX_KEY_SIM_DATA_FWD] = stub_data
202+
return stub_data._strip_traced_fields()
203+
204+
def _fake_postprocess_adj(
205+
sim_data_adj=None, sim_data_orig=None, sim_data_fwd=None, sim_fields_keys=None, **_
206+
):
207+
"""Return zeros for every requested field key."""
208+
counters["download"] += 1 # mimic VJP file download per autograd run
209+
sim_fields_keys = sim_fields_keys or []
210+
return dict.fromkeys(sim_fields_keys, 0.0)
211+
212+
def _fake_field_map_from_file(*args, **kwargs):
213+
return FieldMap(tracers=())
187214

188215
monkeypatch.setattr(io_utils, "download_file", _fake_download_file)
189-
monkeypatch.setattr(autograd, "postprocess_fwd", _fake_from_file)
190-
monkeypatch.setattr(FieldMap, "from_file", _fake_from_file)
216+
monkeypatch.setattr(autograd, "postprocess_fwd", _fake_postprocess_fwd)
217+
monkeypatch.setattr(autograd, "postprocess_adj", _fake_postprocess_adj)
218+
monkeypatch.setattr(FieldMap, "from_file", _fake_field_map_from_file)
191219
monkeypatch.setattr(WebContainer, "_check_folder", _fake__check_folder)
192220
monkeypatch.setattr(web, "upload", _fake_upload)
193221
monkeypatch.setattr(web, "start", _fake_start)

tests/test_web/test_tidy3d_stub.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def make_sim():
4242
)
4343

4444

45+
def is_lazy_object(data):
46+
assert set(data.__dict__.keys()) == {
47+
"_lazy_fname",
48+
"_lazy_group_path",
49+
"_lazy_parse_obj_kwargs",
50+
}
51+
return True
52+
53+
4554
def make_sim_data(file_size_gb=0.001):
4655
"""Makes a simulation data."""
4756
N = int(2.528e8 / 4 * file_size_gb)
@@ -146,22 +155,16 @@ def test_stub_data_lazy_loading(tmp_path):
146155
sim_data = Tidy3dStubData.postprocess(file_path, lazy=True)
147156

148157
sim_data_copy = sim_data.copy()
149-
assert type(sim_data).__name__ == "SimulationDataProxy"
150-
assert type(sim_data_copy).__name__ == "SimulationDataProxy"
158+
159+
# variable dict should only contain metadata to load the data, not the data itself
160+
assert is_lazy_object(sim_data)
151161

152162
# the type should be still SimulationData despite being lazy
153163
assert isinstance(sim_data, SimulationData)
154164

155-
# variable dict should only contain metadata to load the data, not the data itself
156-
assert set(sim_data.__dict__.keys()) == {
157-
"_lazy_fname",
158-
"_lazy_group_path",
159-
"_lazy_parse_obj_kwargs",
160-
}
161-
162165
# we expect a warning from the lazy object if some field is accessed
163166
with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
164-
_ = sim_data.monitor_data
167+
_ = sim_data_copy.monitor_data
165168

166169

167170
@pytest.mark.parametrize(
@@ -185,6 +188,7 @@ def test_stub_pathlike_roundtrip(tmp_path, path_builder):
185188

186189
# Simulation data stub roundtrip
187190
sim_data = make_sim_data()
191+
sim_data = sim_data.updated_copy(log="log")
188192
stub_data = Tidy3dStubData(data=sim_data)
189193
data_path = path_builder(tmp_path, "pathlike_data.hdf5")
190194
stub_data.to_file(data_path)

tests/test_web/test_webapi.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from responses import matchers
1616

1717
import tidy3d as td
18+
from tests.test_web.test_tidy3d_stub import is_lazy_object
1819
from tidy3d import Simulation
1920
from tidy3d.__main__ import main
2021
from tidy3d.components.data.data_array import ScalarFieldDataArray
@@ -958,21 +959,18 @@ def test_run_with_flexible_containers_offline_lazy(monkeypatch, tmp_path):
958959
apply_common_patches(monkeypatch, tmp_path, taskid_to_sim=taskid_to_sim)
959960

960961
data = run(sim_container, task_name=task_name, folder_name="PROJECT", path=str(out_dir))
961-
962+
assert is_lazy_object(data[0])
962963
assert isinstance(data, list) and len(data) == 3
963964

964-
assert isinstance(data[0], SimulationData)
965-
assert data[0].__class__.__name__ == "SimulationDataProxy"
966-
967965
assert isinstance(data[1], dict)
968966
assert "sim2" in data[1]
967+
assert is_lazy_object(data[1]["sim2"])
969968
assert isinstance(data[1]["sim2"], SimulationData)
970-
assert data[1]["sim2"].__class__.__name__ == "SimulationDataProxy"
971969

970+
assert is_lazy_object(data[2][0])
972971
assert isinstance(data[2], tuple)
973-
assert data[2][0].__class__.__name__ == "SimulationDataProxy"
972+
assert is_lazy_object(data[2][1][0])
974973
assert isinstance(data[2][1], list)
975-
assert data[2][1][0].__class__.__name__ == "SimulationDataProxy"
976974

977975
assert data[0].simulation == sim1
978976
assert data[1]["sim2"].simulation == sim2

0 commit comments

Comments
 (0)