diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f1a125d5..6af6d50b02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency. - Added warning if port mesh refinement is incompatible with the `GridSpec` in the `TerminalComponentModeler`. - Various types, e.g. different `Simulation` or `SimulationData` sub-classes, can be loaded from file directly with `Tidy3dBaseModel.from_file()`. +- Added `interp_spec` in `EMEModeSpec` to enable faster multi-frequency EME simulations. Note that the default is now `ModeInterpSpec.cheb(num_points=3, reduce_data=True)`; previously the computation was repeated at all frequencies. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. @@ -64,6 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion. - Port names in `ModalComponentModeler` and `TerminalComponentModeler` can no longer include the `@` symbol. - Improved speed of convolutions for large inputs. +- Default value of `EMEModeSpec.interp_spec` is `ModeInterpSpec.cheb(num_points=3, reduce_data=True)` for faster multi-frequency EME simulations. ### Fixed - Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL. diff --git a/schemas/EMESimulation.json b/schemas/EMESimulation.json index c70989201a..8ff177f3c4 100644 --- a/schemas/EMESimulation.json +++ b/schemas/EMESimulation.json @@ -5294,7 +5294,18 @@ { "$ref": "#/definitions/ModeInterpSpec" } - ] + ], + "default": { + "attrs": {}, + "method": "poly", + "reduce_data": true, + "sampling_spec": { + "attrs": {}, + "num_points": 3, + "type": "ChebSampling" + }, + "type": "ModeInterpSpec" + } }, "num_modes": { "default": 1, diff --git a/tests/sims/full_fdtd.h5 b/tests/sims/full_fdtd.h5 index b0f831d307..40730d6e30 100644 Binary files a/tests/sims/full_fdtd.h5 and b/tests/sims/full_fdtd.h5 differ diff --git a/tests/sims/full_fdtd.json b/tests/sims/full_fdtd.json index 33fd352732..e048a146d1 100644 --- a/tests/sims/full_fdtd.json +++ b/tests/sims/full_fdtd.json @@ -2075,6 +2075,7 @@ "track_freq": "central", "type": "ModeSortSpec" }, + "interp_spec": null, "type": "ModeSpec" }, "mode_index": 0, @@ -2666,6 +2667,7 @@ "track_freq": "central", "type": "ModeSortSpec" }, + "interp_spec": null, "type": "ModeSpec" }, "store_fields_direction": null, diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py index c2c5d5a4af..6712d921b6 100644 --- a/tests/test_components/test_eme.py +++ b/tests/test_components/test_eme.py @@ -568,7 +568,8 @@ def test_eme_simulation(): sim = sim_no_field.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[1, 2])) assert not sim._sweep_modes assert sim._num_sweep == 2 - sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1, 2])) + with AssertLogLevel("WARNING", contains_str="'EMEFreqSweep' is deprecated"): + sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1, 2])) assert sim._sweep_modes assert sim._num_sweep == 2 assert sim._monitor_num_sweep(sim.monitors[0]) == 1 @@ -911,7 +912,9 @@ def _get_mode_solver_data(modes_out=False, num_modes=3): size=(td.inf, td.inf, 0), center=(0, 0, offset), freqs=[td.C_0], - mode_spec=td.ModeSpec(num_modes=num_modes), + mode_spec=td.ModeSpec( + num_modes=num_modes, interp_spec=td.ModeInterpSpec.cheb(num_points=3, reduce_data=True) + ), name=name, ) eme_mode_data = _get_eme_mode_solver_data() diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py index 6b6e3f7b7a..67a087372c 100644 --- a/tests/test_components/test_mode_interp.py +++ b/tests/test_components/test_mode_interp.py @@ -244,14 +244,14 @@ def test_mode_solver_monitor_valid_with_tracking(): def test_interp_num_points_less_than_freqs(): - """Test that num_points must be less than total freqs.""" + """Test that num_points can be greater than total freqs.""" mode_spec = td.ModeSpec( num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"), interp_spec=td.ModeInterpSpec.uniform(num_points=25, method="linear"), ) - with AssertLogLevel("WARNING", contains_str="num_points"): + with AssertLogLevel(None): td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -262,14 +262,14 @@ def test_interp_num_points_less_than_freqs(): def test_interp_num_points_equal_to_freqs(): - """Test that num_points equal to freqs is rejected.""" + """Test that num_points equal to freqs is not rejected.""" mode_spec = td.ModeSpec( num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"), interp_spec=td.ModeInterpSpec.uniform(num_points=20, method="linear"), ) - with AssertLogLevel("WARNING", contains_str="num_points"): + with AssertLogLevel(None): td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -354,7 +354,7 @@ def test_mode_solver_valid_with_tracking(): @td.packaging.disable_local_subpixel def test_mode_solver_warns_num_points(): - """Test that ModeSolver warns when num_points >= num_freqs.""" + """Test that ModeSolver does not warn when num_points >= num_freqs.""" sim = get_simple_sim() mode_spec = td.ModeSpec( num_modes=2, @@ -363,14 +363,14 @@ def test_mode_solver_warns_num_points(): ) plane = td.Box(center=(0, 0, 0), size=SIZE_2D) - with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + with AssertLogLevel(None): ms = ModeSolver( simulation=sim, plane=plane, freqs=FREQS_DENSE, mode_spec=mode_spec, ) - _ = ms.data_raw + _ = ms.data_raw def test_mode_solver_interp_spec_none(): @@ -1041,7 +1041,7 @@ def test_mode_solver_monitor_with_interp_spec(): def test_mode_monitor_warns_redundant_num_points(): - """Test warning when num_points >= number of frequencies in ModeMonitor.""" + """Test no warning when num_points >= number of frequencies in ModeMonitor.""" freqs = np.linspace(1e14, 2e14, 5) mode_spec = td.ModeSpec( num_modes=2, @@ -1049,7 +1049,7 @@ def test_mode_monitor_warns_redundant_num_points(): interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), ) - with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + with AssertLogLevel(None): td.ModeMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -1060,7 +1060,7 @@ def test_mode_monitor_warns_redundant_num_points(): def test_mode_solver_monitor_warns_redundant_num_points(): - """Test warning when num_points >= number of frequencies in ModeSolverMonitor.""" + """Test no warning when num_points >= number of frequencies in ModeSolverMonitor.""" freqs = np.linspace(1e14, 2e14, 5) mode_spec = td.ModeSpec( num_modes=2, @@ -1068,7 +1068,7 @@ def test_mode_solver_monitor_warns_redundant_num_points(): interp_spec=td.ModeInterpSpec.uniform(num_points=6, method="linear"), ) - with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + with AssertLogLevel(None): td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index e3e4307fc2..aed00f8aa6 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -139,6 +139,10 @@ def _interp_dataarray_in_freq( DataArray Interpolated data array with the same structure but new frequency points. """ + # if dataarray is already stored at the correct frequencies, do nothing + if np.array_equal(freqs, data.f): + return data + # Map 'poly' to xarray's 'barycentric' method xr_method = "barycentric" if method == "poly" else method diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 79cd62b655..e99d8881c7 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -1286,6 +1286,22 @@ def to_zbf( return e_x, e_y + def _interpolated_copies_if_needed( + self, other: ElectromagneticFieldData + ) -> tuple[ElectromagneticFieldData, ElectromagneticFieldData]: + """Return interpolated copies of self, other if needed (different interp_spec).""" + mode_spec1 = self.monitor.mode_spec if isinstance(self, ModeSolverData) else None + mode_spec2 = other.monitor.mode_spec if isinstance(other, ModeSolverData) else None + if ( + mode_spec1 is not None + and mode_spec2 is not None + and self.monitor.mode_spec._same_nontrivial_interp_spec(other=other.monitor.mode_spec) + ): + return self, other + self_copy = self.interpolated_copy if isinstance(self, ModeSolverData) else self + other_copy = other.interpolated_copy if isinstance(other, ModeSolverData) else other + return self_copy, other_copy + class FieldData(FieldDataset, ElectromagneticFieldData): """ @@ -2685,6 +2701,8 @@ def _reduced_data(self) -> bool: @property def interpolated_copy(self) -> ModeSolverData: """Return a copy of the data with interpolated fields.""" + if self.monitor.mode_spec.interp_spec is None: + return self if not self._reduced_data: return self interpolated_data = self.interp_in_freq( diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index 03d16e65d9..05c1e6e739 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -197,11 +197,19 @@ def smatrix_in_basis( modes1 = port_modes1 if not modes2_provided: modes2 = port_modes2 - f1 = list(modes1.field_components.values())[0].f.values - f2 = list(modes2.field_components.values())[0].f.values + f1 = list(modes1.monitor.freqs) + f2 = list(modes2.monitor.freqs) f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs))) + mode_spec1 = modes1.monitor.mode_spec if isinstance(modes1, ModeData) else None + mode_spec2 = modes2.monitor.mode_spec if isinstance(modes2, ModeData) else None + + interp_spec1 = mode_spec1.interp_spec if mode_spec1 is not None else None + interp_spec2 = mode_spec2.interp_spec if mode_spec2 is not None else None + + modes1, modes2 = modes1._interpolated_copies_if_needed(other=modes2) + modes_in_1 = "mode_index" in list(modes1.field_components.values())[0].coords modes_in_2 = "mode_index" in list(modes2.field_components.values())[0].coords @@ -259,6 +267,10 @@ def smatrix_in_basis( overlaps1 = modes1.outer_dot(port_modes1, conjugate=False) if not modes_in_1: overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1) + if interp_spec1 is not None: + overlaps1 = modes1._interp_dataarray_in_freq( + overlaps1, freqs=f, method=interp_spec1.method + ) O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1) O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") @@ -288,6 +300,10 @@ def smatrix_in_basis( overlaps2 = modes2.outer_dot(port_modes2, conjugate=False) if not modes_in_2: overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1) + if interp_spec2 is not None: + overlaps2 = modes2._interp_dataarray_in_freq( + overlaps2, freqs=f, method=interp_spec2.method + ) O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2) O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index 2ed22e26ca..7b1b548784 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -11,9 +11,9 @@ from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords1D -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.structure import Structure -from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq +from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size from tidy3d.constants import RADIAN, fp_eps, inf from tidy3d.exceptions import SetupError, ValidationError @@ -26,13 +26,14 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" - track_freq: Union[TrackFreq, None] = pd.Field( - None, - title="Mode Tracking Frequency", - description="Parameter that turns on/off mode tracking based on their similarity. " - "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " - "mode tracking based on the lowest, central, or highest frequency. " - "If ``None`` no mode tracking is performed, which is the default for best performance.", + interp_spec: Optional[ModeInterpSpec] = pd.Field( + ModeInterpSpec.cheb(num_points=3, reduce_data=True), + title="Mode frequency interpolation specification", + description="Specification for computing modes at a reduced set of frequencies and " + "interpolating to obtain results at all requested frequencies. This can significantly " + "reduce computational cost for broadband simulations where modes vary smoothly with " + "frequency. Requires frequency tracking to be enabled (``sort_spec.track_freq`` must " + "not be ``None``) to ensure consistent mode ordering across frequencies.", ) angle_theta: Literal[0.0] = pd.Field( diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index 84b3fec261..6569d58a73 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -722,6 +722,12 @@ def _validate_sweep_spec(self) -> None: "which is not compatible with 'EMELengthSweep'." ) elif isinstance(self.sweep_spec, EMEFreqSweep): + log.warning( + "'EMEFreqSweep' is deprecated. Instead, it is recommended to use " + "'EMESimulation.freqs' directly, and set " + "'EMEModeSpec.interp_spec' as desired to balance " + "performance and accuracy." + ) for i, scale_factor in enumerate(self.sweep_spec.freq_scale_factors): scaled_freqs = np.array(self.freqs) * scale_factor if np.min(scaled_freqs) < MIN_FREQUENCY: @@ -1007,6 +1013,18 @@ def _monitor_freqs(self, monitor: Monitor) -> list[pd.NonNegativeFloat]: return list(self.freqs) return list(monitor.freqs) + def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[pd.NonNegativeFloat]: + """Monitor frequencies.""" + freqs = set() + cell_inds = self._monitor_eme_cell_indices(monitor=monitor) + for cell_ind in cell_inds: + interp_spec = self.eme_grid.mode_specs[cell_ind].interp_spec + if interp_spec is None: + freqs |= set(self.freqs) + else: + freqs |= set(interp_spec.sampling_points(self.freqs)) + return list(freqs) + def _monitor_num_freqs(self, monitor: Monitor) -> int: """Total number of freqs included in monitor.""" return len(self._monitor_freqs(monitor=monitor)) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 74218e96b8..3d41c2ffb8 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -75,7 +75,6 @@ from tidy3d.components.types.mode_spec import ModeSpecType from tidy3d.components.types.monitor_data import ModeSolverDataType from tidy3d.components.validators import ( - _warn_interp_num_points, validate_freqs_min, validate_freqs_not_empty, ) @@ -515,9 +514,6 @@ def data_raw(self) -> ModeSolverDataType: A mode solver data type object containing the effective index and mode fields. """ - if self.mode_spec.interp_spec is not None: - _warn_interp_num_points(self.mode_spec.interp_spec, self.freqs) - if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: return self.rotated_mode_solver_data diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index e964ce0d87..2e836c55ec 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -26,7 +26,6 @@ from tidy3d.components.source.field import ModeSource from tidy3d.components.types import TYPE_TAG_STR, Ax, Direction, EMField, FreqArray from tidy3d.components.types.mode_spec import ModeSpecType -from tidy3d.components.validators import validate_interp_num_points from tidy3d.constants import C_0 from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log @@ -235,8 +234,6 @@ def plane_in_sim_bounds(cls, val, values): raise SetupError("'ModeSimulation.plane' must intersect 'ModeSimulation.geometry.") return val - _warn_interp_num_points = validate_interp_num_points() - def _post_init_validators(self) -> None: """Call validators taking `self` that get run after init.""" _ = self._mode_solver diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 0ae0256c80..0dfb12b531 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -433,7 +433,7 @@ def sampling_points(self, freqs: FreqArray) -> FreqArray: >>> interp_spec = ModeInterpSpec.cheb(num_points=10) >>> sampling_freqs = interp_spec.sampling_points(freqs) """ - if self.num_points > len(freqs): + if self.num_points >= len(freqs): return freqs return self.sampling_spec.sampling_points(freqs) @@ -738,6 +738,14 @@ def _is_interp_spec_applied(self, freqs: FreqArray) -> bool: """Whether interp_spec is used to compute modes at the given frequencies.""" return self.interp_spec is not None and self.interp_spec.num_points < len(freqs) + def _same_nontrivial_interp_spec(self, other: ModeSpec) -> bool: + """Whether two mode specs have identical nontrivial interp specs.""" + return ( + self.interp_spec is not None + and other.interp_spec is not None + and self.interp_spec == other.interp_spec + ) + class ModeSpec(AbstractModeSpec): """ diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 1739dbe40f..df4ec33049 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -37,7 +37,6 @@ assert_plane, validate_freqs_min, validate_freqs_not_empty, - validate_interp_num_points, ) from .viz import ARROW_ALPHA, ARROW_COLOR_MONITOR @@ -429,16 +428,21 @@ def _warn_num_modes(cls, val, values): ) return val + @property + def _stored_freqs(self) -> list[float]: + """Return actually stored frequencies of the data.""" + return self.mode_spec._sampling_freqs_mode_solver_data(freqs=self.freqs) + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" # Need to store all fields on the mode surface - bytes_single = BYTES_COMPLEX * num_cells * len(self.freqs) * self.mode_spec.num_modes * 6 + bytes_single = ( + BYTES_COMPLEX * num_cells * len(self._stored_freqs) * self.mode_spec.num_modes * 6 + ) if self.mode_spec.precision == "double": return 2 * bytes_single return bytes_single - _warn_interp_num_points = validate_interp_num_points() - class FieldMonitor(AbstractFieldMonitor, FreqMonitor): """:class:`Monitor` that records electromagnetic fields in the frequency domain. diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index d1efc03018..c6a9d9c640 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -494,35 +494,3 @@ def _warn_traced_arg(cls, val, values): return val return _warn_traced_arg - - -def _warn_interp_num_points(interp_spec, freqs) -> None: - """Warn if the number of sampling points for interpolation is greater than or equal to the number of target frequencies.""" - - num_freqs = len(freqs) - - if interp_spec.num_points >= num_freqs: - log.warning( - f"'interp_spec.num_points' ({interp_spec.num_points}) is greater than or equal to " - f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " - f"modes will be computed at all {num_freqs} frequencies.", - custom_loc=["mode_spec", "interp_spec", "num_points"], - ) - - -def validate_interp_num_points(): - @pydantic.root_validator(allow_reuse=True) - @skip_if_fields_missing(["freqs", "mode_spec"], root=True) - def _validate_warn_interp_num_points(cls, values): - """Warn if the number of sampling points for interpolation is greater than or equal to the number of target frequencies.""" - - interp_spec = values.get("mode_spec").interp_spec - if interp_spec is None: - return values - - freqs = values.get("freqs") - _warn_interp_num_points(interp_spec, freqs) - - return values - - return _validate_warn_interp_num_points