Skip to content

Commit 548953f

Browse files
committed
gregs refactor
1 parent a0c1dc2 commit 548953f

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

tidy3d/components/data/data_array.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -489,23 +489,23 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489489
result = result.transpose(*out_dims)
490490
return result
491491

492-
def with_updated_data(
493-
self,
494-
data: float,
495-
coords: dict[str, Any],
496-
) -> DataArray:
492+
def with_updated_data(self, data, coords: dict[str, Any]) -> DataArray:
497493
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd-approved."""
494+
495+
mask = xr.zeros_like(self, dtype=bool)
496+
mask.loc[coords] = True
497+
498498
old_values = self.values
499-
slice_indices = [slice(None)] * len(self.dims) # Start with full slices for all dims
500-
for dim_idx, dim_name in enumerate(self.dims):
501-
if dim_name in coords:
502-
coord_value = coords[dim_name]
503-
pos = int(self.get_index(dim_name).get_loc(coord_value))
504-
slice_indices[dim_idx] = pos
505-
mask_slice = tuple(slice_indices)
506-
mask = np.zeros_like(old_values, dtype=bool)
507-
mask[mask_slice] = True
508-
modified_values = np.where(mask, data, old_values)
499+
replacement_data = np.zeros_like(old_values)
500+
axes_to_fill = [self.get_axis_num(key) for key in coords]
501+
502+
reshape_guide = [1] * len(old_values.shape)
503+
for idx, dim in enumerate(axes_to_fill):
504+
reshape_guide[dim] = data.shape[idx]
505+
506+
replacement_data = replacement_data + data.reshape(reshape_guide)
507+
modified_values = np.where(mask, replacement_data, old_values)
508+
509509
return self.copy(deep=True, data=modified_values)
510510

511511

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
325325
"mode_index_out": mode_index_out,
326326
}
327327

328-
s_matrix = s_matrix.with_updated_data(data=s_matrix_elements, coords=coords_set)
328+
reshaped_smatrix_data = np.reshape(s_matrix_elements, (1, 1, 1, -1))
329+
s_matrix = s_matrix.with_updated_data(data=reshaped_smatrix_data, coords=coords_set)
329330

330331
# element can be determined by user-defined mapping
331332
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -349,6 +350,7 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
349350
"mode_index_out": mode_index_out_to,
350351
}
351352

352-
s_matrix = s_matrix.with_updated_data(data=elements_from, coords=coords_to)
353+
reshaped_elements_from = np.reshape(elements_from, (1, 1, 1, -1))
354+
s_matrix = s_matrix.with_updated_data(data=reshaped_elements_from, coords=coords_to)
353355

354356
return s_matrix

0 commit comments

Comments
 (0)