@@ -298,23 +298,6 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
298
298
}
299
299
s_matrix = ModalPortDataArray (values , coords = coords )
300
300
301
- def set_new_values (
302
- values : np .ndarray ,
303
- new_values : np .ndarray ,
304
- port_name_in : str ,
305
- mode_index_in : int ,
306
- port_name_out : str ,
307
- mode_index_out : int ,
308
- ) -> np .ndarray :
309
- """Replace ``values`` with ``new_values`` at indices given by dims in ``s_matrix```"""
310
- port_in_pos = int (s_matrix .get_index ("port_in" ).get_loc (port_name_in ))
311
- mode_in_pos = int (s_matrix .get_index ("mode_index_in" ).get_loc (mode_index_in ))
312
- port_out_pos = int (s_matrix .get_index ("port_out" ).get_loc (port_name_out ))
313
- mode_out_pos = int (s_matrix .get_index ("mode_index_out" ).get_loc (mode_index_out ))
314
- mask = np .zeros_like (values )
315
- mask [port_out_pos , port_in_pos , mode_out_pos , mode_in_pos , :] = 1
316
- return np .where (mask , new_values , values )
317
-
318
301
# loop through source ports
319
302
for col_index in self .matrix_indices_run_sim :
320
303
port_name_in , mode_index_in = col_index
@@ -335,16 +318,14 @@ def set_new_values(
335
318
source_norm = self ._normalization_factor (port_in , sim_data )
336
319
s_matrix_elements = np .array (amp .data ) / np .array (source_norm )
337
320
338
- values = set_new_values (
339
- values = values ,
340
- new_values = s_matrix_elements ,
341
- port_name_in = port_name_in ,
342
- mode_index_in = mode_index_in ,
343
- port_name_out = port_name_out ,
344
- mode_index_out = mode_index_out ,
345
- )
321
+ coords_set = {
322
+ "port_in" : port_name_in ,
323
+ "mode_index_in" : mode_index_in ,
324
+ "port_out" : port_name_out ,
325
+ "mode_index_out" : mode_index_out ,
326
+ }
346
327
347
- s_matrix = ModalPortDataArray ( values , coords = coords )
328
+ s_matrix = s_matrix . with_updated_data ( data = s_matrix_elements , coords = coords_set )
348
329
349
330
# element can be determined by user-defined mapping
350
331
for (row_in , col_in ), (row_out , col_out ), mult_by in self .element_mappings :
@@ -361,16 +342,13 @@ def set_new_values(
361
342
port_in_to , mode_index_in_to = col_out
362
343
363
344
elements_from = mult_by * s_matrix .loc [coords_from ].values
345
+ coords_to = {
346
+ "port_in" : port_in_to ,
347
+ "mode_index_in" : mode_index_in_to ,
348
+ "port_out" : port_out_to ,
349
+ "mode_index_out" : mode_index_out_to ,
350
+ }
364
351
365
- values = set_new_values (
366
- values = values ,
367
- new_values = elements_from ,
368
- port_name_in = port_in_to ,
369
- mode_index_in = mode_index_in_to ,
370
- port_name_out = port_out_to ,
371
- mode_index_out = mode_index_out_to ,
372
- )
373
-
374
- s_matrix = ModalPortDataArray (values , coords = coords )
352
+ s_matrix = s_matrix .with_updated_data (data = elements_from , coords = coords_to )
375
353
376
354
return s_matrix
0 commit comments