@@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         if pin_memory:
             storage = storage.pin_memory()
         storage_cast = storage.to(device, non_blocking=True)
+        if is_dynamo_compiling():
+            return self._to_reconstruct_compiled(
+                storage, storage_cast, device, num_threads, non_blocking
+            )
+        return self._to_reconstruct(
+            storage, storage_cast, device, num_threads, non_blocking
+        )
+
+    def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking):
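+        # Eager path: rebuild each leaf as a view over the freshly cast
+        # consolidated storage, so no per-tensor copy happens beyond the
+        # single storage transfer above.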
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
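+            # Jagged nested tensors are reassembled recursively from their
+            # values/offsets/lengths components.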
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch.nested.nested_tensor_from_jagged(
+                    set_(values),
+                    offsets=set_(offsets),
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
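+            # Dense case: point a fresh empty tensor at the shared untyped
+            # storage with the original size/stride/offset.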
             storage_offset = x.storage_offset()
             stride = x.stride()
-            return torch.empty_like(x, device=device).set_(
+            return x.new_empty((0,), device=device).set_(
                 untyped_storage,
                 size=x.shape,
                 stride=stride,
                 storage_offset=storage_offset,
             )
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads
+        )
+        result._consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
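+        # A blocking copy across a CUDA boundary requires a stream sync; it is
+        # skipped only when the caller explicitly passed non_blocking=True.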
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA forces a sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from CUDA: sync unless the caller explicitly opted out
+                cuda_device = storage.device
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking):
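+        # Compiled variant: Tensor.set_() is not traceable by Dynamo, so each
+        # leaf is rebuilt as a dtype-view slice of the flat storage instead.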
+        def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch._nested_view_from_jagged(
+                    set_(values),
+                    set_(offsets),
+                    x,
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
+            storage_offset = x.storage_offset()
+            stride = x.stride()
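+            # Carve this leaf's elements out of the flat buffer reinterpreted
+            # as x.dtype, then restore the original shape.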
+            index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0])
+            return storage_cast.view(x.dtype)[index_slice].view(x.shape)
 
         result = self._fast_apply(
             set_, device=torch.device(device), num_threads=num_threads