Skip to content

Commit 7baa866

Browse files
committed
nicer
1 parent 6cfbb87 commit 7baa866

File tree

2 files changed

+33
-36
lines changed

2 files changed

+33
-36
lines changed

tidy3d/components/data/data_array.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,25 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489489
result = result.transpose(*out_dims)
490490
return result
491491

492+
def with_updated_data(
493+
self,
494+
data: float,
495+
coords: dict[str, Any],
496+
) -> DataArray:
497+
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd-approved."""
498+
old_values = self.values
499+
slice_indices = [slice(None)] * len(self.dims) # Start with full slices for all dims
500+
for dim_idx, dim_name in enumerate(self.dims):
501+
if dim_name in coords:
502+
coord_value = coords[dim_name]
503+
pos = int(self.get_index(dim_name).get_loc(coord_value))
504+
slice_indices[dim_idx] = pos
505+
mask_slice = tuple(slice_indices)
506+
mask = np.zeros_like(old_values, dtype=bool)
507+
mask[mask_slice] = True
508+
modified_values = np.where(mask, data, old_values)
509+
return self.copy(deep=True, data=modified_values)
510+
492511

493512
class FreqDataArray(DataArray):
494513
"""Frequency-domain array.

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -298,23 +298,6 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
298298
}
299299
s_matrix = ModalPortDataArray(values, coords=coords)
300300

301-
def set_new_values(
302-
values: np.ndarray,
303-
new_values: np.ndarray,
304-
port_name_in: str,
305-
mode_index_in: int,
306-
port_name_out: str,
307-
mode_index_out: int,
308-
) -> np.ndarray:
309-
"""Replace ``values`` with ``new_values`` at indices given by dims in ``s_matrix```"""
310-
port_in_pos = int(s_matrix.get_index("port_in").get_loc(port_name_in))
311-
mode_in_pos = int(s_matrix.get_index("mode_index_in").get_loc(mode_index_in))
312-
port_out_pos = int(s_matrix.get_index("port_out").get_loc(port_name_out))
313-
mode_out_pos = int(s_matrix.get_index("mode_index_out").get_loc(mode_index_out))
314-
mask = np.zeros_like(values)
315-
mask[port_out_pos, port_in_pos, mode_out_pos, mode_in_pos, :] = 1
316-
return np.where(mask, new_values, values)
317-
318301
# loop through source ports
319302
for col_index in self.matrix_indices_run_sim:
320303
port_name_in, mode_index_in = col_index
@@ -335,16 +318,14 @@ def set_new_values(
335318
source_norm = self._normalization_factor(port_in, sim_data)
336319
s_matrix_elements = np.array(amp.data) / np.array(source_norm)
337320

338-
values = set_new_values(
339-
values=values,
340-
new_values=s_matrix_elements,
341-
port_name_in=port_name_in,
342-
mode_index_in=mode_index_in,
343-
port_name_out=port_name_out,
344-
mode_index_out=mode_index_out,
345-
)
321+
coords_set = {
322+
"port_in": port_name_in,
323+
"mode_index_in": mode_index_in,
324+
"port_out": port_name_out,
325+
"mode_index_out": mode_index_out,
326+
}
346327

347-
s_matrix = ModalPortDataArray(values, coords=coords)
328+
s_matrix = s_matrix.with_updated_data(data=s_matrix_elements, coords=coords_set)
348329

349330
# element can be determined by user-defined mapping
350331
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -361,16 +342,13 @@ def set_new_values(
361342
port_in_to, mode_index_in_to = col_out
362343

363344
elements_from = mult_by * s_matrix.loc[coords_from].values
345+
coords_to = {
346+
"port_in": port_in_to,
347+
"mode_index_in": mode_index_in_to,
348+
"port_out": port_out_to,
349+
"mode_index_out": mode_index_out_to,
350+
}
364351

365-
values = set_new_values(
366-
values=values,
367-
new_values=elements_from,
368-
port_name_in=port_in_to,
369-
mode_index_in=mode_index_in_to,
370-
port_name_out=port_out_to,
371-
mode_index_out=mode_index_out_to,
372-
)
373-
374-
s_matrix = ModalPortDataArray(values, coords=coords)
352+
s_matrix = s_matrix.with_updated_data(data=elements_from, coords=coords_to)
375353

376354
return s_matrix

0 commit comments

Comments
 (0)