Skip to content

Commit ed68d39

Browse files
committed
differentiable s-matrix calculation
1 parent 7c093b1 commit ed68d39

File tree

5 files changed

+71
-24
lines changed

5 files changed

+71
-24
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Objective functions that involve running `tidy3d.plugins.smatrix.ComponentModeler` can be differentiated with autograd.
12+
13+
1014
## [2.9.0rc1] - 2025-06-10
1115

1216
### Added

tests/test_plugins/smatrix/test_component_modeler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,13 @@ def _test_mappings(element_mappings, s_matrix):
322322
"mode_index_out": mode_index_out_from,
323323
}
324324

325-
coords_to = {
326-
"port_in": port_in_to,
327-
"port_out": port_out_to,
328-
"mode_index_in": mode_index_in_to,
329-
"mode_index_out": mode_index_out_to,
330-
}
325+
326+
coords_to = dict(
327+
port_in=port_in_to,
328+
port_out=port_out_to,
329+
mode_index_in=mode_index_in_to,
330+
mode_index_out=mode_index_out_to,
331+
)
331332

332333
assert np.all(
333334
s_matrix.sel(**coords_to).values == mult_by * s_matrix.sel(**coords_from).values

tidy3d/plugins/autograd/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ We also support the following high-level features:
217217
- We automatically determine the number of adjoint simulations to run from a given forward simulation to maintain gradient accuracy.
218218
Adjoint sources are automatically grouped by either frequency or spatial port (whichever yields fewer adjoint simulations), and all adjoint simulations are run in a single batch (applies to both `run` and `run_async`).
219219
The parameter `max_num_adjoint_per_fwd` (default `10`) prevents launching unexpectedly large numbers of adjoint simulations automatically.
220+
- Differentiation of objective functions involving the scattering matrix produced by `tidy3d.plugins.smatrix.ComponentModeler`.
220221

221222
We currently have the following restrictions:
222223

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from abc import ABC, abstractmethod
77
from typing import Optional, Union, get_args
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import Tidy3dBaseModel, cached_property
@@ -24,6 +24,7 @@
2424
from tidy3d.plugins.smatrix.ports.wave import WavePort
2525
from tidy3d.web.api.container import Batch, BatchData
2626

27+
2728
# fwidth of gaussian pulse in units of central frequency
2829
FWIDTH_FRAC = 1.0 / 10
2930
DEFAULT_DATA_DIR = "."
@@ -196,7 +197,21 @@ def batch_path(self) -> str:
196197
@cached_property
197198
def batch_data(self) -> BatchData:
198199
"""The :class:`.BatchData` associated with the simulations run for this component modeler."""
199-
return self.batch.run(path_dir=self.path_dir)
200+
201+
# NOTE: uses run_async because Batch is not differentiable.
202+
batch = self.batch
203+
run_async_kwargs = batch.dict(
204+
exclude={
205+
"type",
206+
"path_dir",
207+
"attrs",
208+
"solver_version",
209+
"jobs_cached",
210+
"num_workers",
211+
"simulations",
212+
}
213+
)
214+
return run_async(batch.simulations, **run_async_kwargs, path_dir=self.path_dir)
200215

201216
def get_path_dir(self, path_dir: str) -> None:
202217
"""Check whether the supplied 'path_dir' matches the internal field value."""

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Optional
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import cached_property
@@ -298,6 +298,23 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
298298
}
299299
s_matrix = ModalPortDataArray(values, coords=coords)
300300

301+
def set_new_values(
302+
values: np.ndarray,
303+
new_values: np.ndarray,
304+
port_name_in: str,
305+
mode_index_in: int,
306+
port_name_out: str,
307+
mode_index_out: int,
308+
) -> np.ndarray:
309+
"""Replace ``values`` with ``new_values`` at indices given by dims in ``s_matrix```"""
310+
port_in_pos = int(s_matrix.get_index("port_in").get_loc(port_name_in))
311+
mode_in_pos = int(s_matrix.get_index("mode_index_in").get_loc(mode_index_in))
312+
port_out_pos = int(s_matrix.get_index("port_out").get_loc(port_name_out))
313+
mode_out_pos = int(s_matrix.get_index("mode_index_out").get_loc(mode_index_out))
314+
mask = np.zeros_like(values)
315+
mask[port_out_pos, port_in_pos, mode_out_pos, mode_in_pos, :] = 1
316+
return np.where(mask, new_values, values)
317+
301318
# loop through source ports
302319
for col_index in self.matrix_indices_run_sim:
303320
port_name_in, mode_index_in = col_index
@@ -317,14 +334,17 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
317334
)
318335
source_norm = self._normalization_factor(port_in, sim_data)
319336
s_matrix_elements = np.array(amp.data) / np.array(source_norm)
320-
s_matrix.loc[
321-
{
322-
"port_in": port_name_in,
323-
"mode_index_in": mode_index_in,
324-
"port_out": port_name_out,
325-
"mode_index_out": mode_index_out,
326-
}
327-
] = s_matrix_elements
337+
338+
values = set_new_values(
339+
values=values,
340+
new_values=s_matrix_elements,
341+
port_name_in=port_name_in,
342+
mode_index_in=mode_index_in,
343+
port_name_out=port_name_out,
344+
mode_index_out=mode_index_out,
345+
)
346+
347+
s_matrix = ModalPortDataArray(values, coords=coords)
328348

329349
# element can be determined by user-defined mapping
330350
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -339,12 +359,18 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
339359

340360
port_out_to, mode_index_out_to = row_out
341361
port_in_to, mode_index_in_to = col_out
342-
coords_to = {
343-
"port_in": port_in_to,
344-
"mode_index_in": mode_index_in_to,
345-
"port_out": port_out_to,
346-
"mode_index_out": mode_index_out_to,
347-
}
348-
s_matrix.loc[coords_to] = mult_by * s_matrix.loc[coords_from].values
362+
363+
elements_from = mult_by * s_matrix.loc[coords_from].values
364+
365+
values = set_new_values(
366+
values=values,
367+
new_values=elements_from,
368+
port_name_in=port_in_to,
369+
mode_index_in=mode_index_in_to,
370+
port_name_out=port_out_to,
371+
mode_index_out=mode_index_out_to,
372+
)
373+
374+
s_matrix = ModalPortDataArray(values, coords=coords)
349375

350376
return s_matrix

0 commit comments

Comments
 (0)