@@ -489,23 +489,23 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489
489
result = result .transpose (* out_dims )
490
490
return result
491
491
492
- def with_updated_data (
493
- self ,
494
- data : float ,
495
- coords : dict [str , Any ],
496
- ) -> DataArray :
492
+ def with_updated_data (self , data , coords : dict [str , Any ]) -> DataArray :
497
493
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd-approved."""
494
+
495
+ mask = xr .zeros_like (self , dtype = bool )
496
+ mask .loc [coords ] = True
497
+
498
498
old_values = self .values
499
- slice_indices = [ slice ( None )] * len ( self . dims ) # Start with full slices for all dims
500
- for dim_idx , dim_name in enumerate ( self .dims ):
501
- if dim_name in coords :
502
- coord_value = coords [ dim_name ]
503
- pos = int ( self . get_index ( dim_name ). get_loc ( coord_value ))
504
- slice_indices [ dim_idx ] = pos
505
- mask_slice = tuple ( slice_indices )
506
- mask = np . zeros_like ( old_values , dtype = bool )
507
- mask [ mask_slice ] = True
508
- modified_values = np . where ( mask , data , old_values )
499
+ replacement_data = np . zeros_like ( old_values )
500
+ axes_to_fill = [ self .get_axis_num ( key ) for key in coords ]
501
+
502
+ reshape_guide = [ 1 ] * len ( old_values . shape )
503
+ for idx , dim in enumerate ( axes_to_fill ):
504
+ reshape_guide [ dim ] = data . shape [ idx ]
505
+
506
+ replacement_data = replacement_data + data . reshape ( reshape_guide )
507
+ modified_values = np . where ( mask , replacement_data , old_values )
508
+
509
509
return self .copy (deep = True , data = modified_values )
510
510
511
511
0 commit comments