Skip to content

Commit a4d1737

Browse files
committed
differentiable s-matrix calculation
1 parent dbe5229 commit a4d1737

File tree

8 files changed

+103
-19
lines changed

8 files changed

+103
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Objective functions that involve running `tidy3d.plugins.smatrix.ComponentModeler` can be differentiated with autograd.
12+
1013
## [2.9.0rc2] - 2025-07-17
1114

1215
### Added

tests/test_data/test_data_arrays.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from typing import Optional
66

7-
import numpy as np
7+
import autograd as ag
8+
import autograd.numpy as np
89
import pytest
910
import xarray.testing as xrt
11+
from autograd.test_util import check_grads
1012

1113
import tidy3d as td
1214
from tidy3d.exceptions import DataError
@@ -468,3 +470,32 @@ def test_interp(method, scalar_index):
468470
xr_interp = data.interp(f=f)
469471
ag_interp = data._ag_interp(f=f)
470472
xrt.assert_allclose(xr_interp, ag_interp)
473+
474+
475+
def test_with_updated_data():
476+
"""Check the ``DataArray.with_updated_data()`` method."""
477+
478+
arr = td.SpatialDataArray(
479+
np.ones((2, 3, 4, 5), dtype=np.complex128),
480+
coords={"x": [0, 1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "w": [0, 1, 2, 3, 4]},
481+
)
482+
483+
data = np.zeros((1, 1, 1, 5))
484+
485+
coords = {"x": 0, "y": 2, "z": 3}
486+
487+
arr2 = arr._with_updated_data(data=data, coords=coords)
488+
489+
data_expected = np.ones(arr.shape) + 0j
490+
data_expected[0, 1, 1, :] = 0.0 + 0j
491+
assert np.all(arr2.data == data_expected), "DataArray.with_updated_copy() failed"
492+
493+
def f(x):
494+
arr2 = arr._with_updated_data(data=x, coords=coords)
495+
return np.abs(np.sum(arr2.data))
496+
497+
# grad should just be all 1s because of sum, so check that this is true
498+
g = ag.grad(f)(data)
499+
assert np.all(g == np.ones_like(data))
500+
501+
check_grads(f)(data)

tidy3d/components/data/data_array.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,33 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489489
result = result.transpose(*out_dims)
490490
return result
491491

492+
def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray:
493+
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible
494+
495+
Constraints / Edge cases:
496+
- `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays
497+
- `data` will be reshaped to try to match `self.shape` except where `coords` present
498+
"""
499+
500+
# make mask
501+
mask = xr.zeros_like(self, dtype=bool)
502+
mask.loc[coords] = True
503+
504+
# reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis
505+
old_data = self.data
506+
new_shape = list(old_data.shape)
507+
for i, dim in enumerate(self.dims):
508+
if dim in coords:
509+
new_shape[i] = 1
510+
new_data = data.reshape(new_shape)
511+
512+
# broadcast data to repeat data along the selected dimensions to match mask
513+
new_data = new_data + np.zeros_like(old_data)
514+
515+
new_data = np.where(mask, new_data, old_data)
516+
517+
return self.copy(deep=True, data=new_data)
518+
492519

493520
class FreqDataArray(DataArray):
494521
"""Frequency-domain array.

tidy3d/plugins/autograd/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ We also support the following high-level features:
217217
- We automatically determine the number of adjoint simulations to run from a given forward simulation to maintain gradient accuracy.
218218
Adjoint sources are automatically grouped by either frequency or spatial port (whichever yields fewer adjoint simulations), and all adjoint simulations are run in a single batch (applies to both `run` and `run_async`).
219219
The parameter `max_num_adjoint_per_fwd` (default `10`) prevents launching unexpectedly large numbers of adjoint simulations automatically.
220+
- Differentiation of objective functions involving the scattering matrix produced by `tidy3d.plugins.smatrix.ComponentModeler`.
220221

221222
We currently have the following restrictions:
222223

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from abc import ABC, abstractmethod
77
from typing import Optional, Union, get_args
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import Tidy3dBaseModel, cached_property
@@ -22,6 +22,7 @@
2222
from tidy3d.plugins.smatrix.ports.modal import Port
2323
from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort
2424
from tidy3d.plugins.smatrix.ports.wave import WavePort
25+
from tidy3d.web import run_async
2526
from tidy3d.web.api.container import Batch, BatchData
2627

2728
# fwidth of gaussian pulse in units of central frequency
@@ -196,7 +197,21 @@ def batch_path(self) -> str:
196197
@cached_property
197198
def batch_data(self) -> BatchData:
198199
"""The :class:`.BatchData` associated with the simulations run for this component modeler."""
199-
return self.batch.run(path_dir=self.path_dir)
200+
201+
# NOTE: uses run_async because Batch is not differentiable.
202+
batch = self.batch
203+
run_async_kwargs = batch.dict(
204+
exclude={
205+
"type",
206+
"path_dir",
207+
"attrs",
208+
"solver_version",
209+
"jobs_cached",
210+
"num_workers",
211+
"simulations",
212+
}
213+
)
214+
return run_async(batch.simulations, **run_async_kwargs, path_dir=self.path_dir)
200215

201216
def get_path_dir(self, path_dir: str) -> None:
202217
"""Check whether the supplied 'path_dir' matches the internal field value."""

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import Optional
88

9-
import numpy as np
9+
import autograd.numpy as np
1010
import pydantic.v1 as pd
1111

1212
from tidy3d.components.base import cached_property
@@ -317,14 +317,15 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
317317
)
318318
source_norm = self._normalization_factor(port_in, sim_data)
319319
s_matrix_elements = np.array(amp.data) / np.array(source_norm)
320-
s_matrix.loc[
321-
{
322-
"port_in": port_name_in,
323-
"mode_index_in": mode_index_in,
324-
"port_out": port_name_out,
325-
"mode_index_out": mode_index_out,
326-
}
327-
] = s_matrix_elements
320+
321+
coords_set = {
322+
"port_in": port_name_in,
323+
"mode_index_in": mode_index_in,
324+
"port_out": port_name_out,
325+
"mode_index_out": mode_index_out,
326+
}
327+
328+
s_matrix = s_matrix._with_updated_data(data=s_matrix_elements, coords=coords_set)
328329

329330
# element can be determined by user-defined mapping
330331
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -339,12 +340,14 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
339340

340341
port_out_to, mode_index_out_to = row_out
341342
port_in_to, mode_index_in_to = col_out
343+
344+
elements_from = mult_by * s_matrix.loc[coords_from].values
342345
coords_to = {
343346
"port_in": port_in_to,
344347
"mode_index_in": mode_index_in_to,
345348
"port_out": port_out_to,
346349
"mode_index_out": mode_index_out_to,
347350
}
348-
s_matrix.loc[coords_to] = mult_by * s_matrix.loc[coords_from].values
351+
s_matrix = s_matrix._with_updated_data(data=elements_from, coords=coords_to)
349352

350353
return s_matrix

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,9 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
221221
for port_in in self.ports:
222222
sim_data = batch_data[self._task_name(port=port_in)]
223223
a, b = self.compute_power_wave_amplitudes_at_each_port(port_impedances, sim_data)
224-
indexer = {"f": a.f, "port_in": port_in.name, "port_out": a.port}
225-
a_matrix.loc[indexer] = a
226-
b_matrix.loc[indexer] = b
224+
indexer = {"port_in": port_in.name}
225+
a_matrix = a_matrix._with_updated_data(data=a.data, coords=indexer)
226+
b_matrix = b_matrix._with_updated_data(data=b.data, coords=indexer)
227227

228228
s_matrix = self.ab_to_s(a_matrix, b_matrix)
229229
return s_matrix
@@ -481,10 +481,14 @@ def _port_reference_impedances(self, batch_data: BatchData) -> PortDataArray:
481481
# WavePorts have a port impedance calculated from its associated modal field distribution
482482
# and is frequency dependent.
483483
impedances = port.compute_port_impedance(sim_data_port).values
484-
port_impedances.loc[{"port": port.name}] = impedances.squeeze()
484+
port_impedances = port_impedances._with_updated_data(
485+
data=impedances, coords={"port": port.name}
486+
)
485487
else:
486488
# LumpedPorts have a constant reference impedance
487-
port_impedances.loc[{"port": port.name}] = np.full(len(self.freqs), port.impedance)
489+
port_impedances = port_impedances._with_updated_data(
490+
data=np.full(len(self.freqs), port.impedance), coords={"port": port.name}
491+
)
488492

489493
port_impedances = TerminalComponentModeler._set_port_data_array_attributes(port_impedances)
490494
return port_impedances

0 commit comments

Comments
 (0)