103103from torch .utils ._pytree import tree_map
104104
105105try :
106- from torch .compiler import is_dynamo_compiling
106+ from torch .compiler import is_compiling
107107except ImportError : # torch 2.0
108- from torch ._dynamo import is_compiling as is_dynamo_compiling
108+ from torch ._dynamo import is_compiling
109109
110110try :
111111 from torch import _foreach_copy_
@@ -5247,7 +5247,7 @@ def _view_and_pad(tensor):
52475247 if v .device != storage .device :
52485248 v = v .to (storage .device , non_blocking = non_blocking )
52495249 stride = v .stride ()
5250- if is_dynamo_compiling ():
5250+ if is_compiling ():
52515251 if not v .is_contiguous ():
52525252 v = v .clone (memory_format = torch .contiguous_format )
52535253 elif (stride and stride [- 1 ] != 1 ) or v .storage_offset ():
@@ -6963,7 +6963,7 @@ def _values_list(
69636963 is_leaf = is_leaf ,
69646964 collapse = collapse ,
69656965 )
6966- if is_dynamo_compiling ():
6966+ if is_compiling ():
69676967 key_to_index = {key : i for i , key in enumerate (keys )}
69686968 return [vals [key_to_index [key ]] for key in sorting_keys ]
69696969 else :
@@ -6994,7 +6994,7 @@ def _items_list(
69946994 return list (keys ), list (vals )
69956995 if default is None :
69966996 # TODO: check that lists are identical
6997- if is_dynamo_compiling ():
6997+ if is_compiling ():
69986998 key_to_index = {key : i for i , key in enumerate (keys )}
69996999 new_vals = [vals [key_to_index [key ]] for key in sorting_keys ]
70007000 if len (new_vals ) < len (vals ):
@@ -7015,12 +7015,9 @@ def _items_list(
70157015 ] # intersection does not keep the sorting
70167016 else :
70177017 new_keys = list (set (sorting_keys ).union (keys ))
7018- if is_dynamo_compiling ():
7019- ...
7020- else :
7021- source = dict (zip (keys , vals ))
7022- vals = [source .get (key , default ) for key in new_keys ]
7023- return new_keys , vals
7018+ source = dict (zip (keys , vals ))
7019+ vals = [source .get (key , default ) for key in new_keys ]
7020+ return new_keys , vals
70247021
70257022 def _grad (self ):
70267023 # We can't cache this because zero_grad can be called outside (eg from optimizer) and we want the tensors
@@ -11931,7 +11928,7 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
1193111928 result .lock_ ()
1193211929 return result
1193311930 else :
11934- if not is_dynamo_compiling ():
11931+ if not is_compiling ():
1193511932 key_list = list (self .keys ())
1193611933 else :
1193711934 key_list = [k for k in self .keys ()] # noqa
@@ -12196,10 +12193,10 @@ def lock_(self) -> T:
1219612193 """
1219712194 if self .is_locked :
1219812195 return self
12199- is_compiling = is_dynamo_compiling ()
12200- if is_compiling :
12196+ is_comp = is_compiling ()
12197+ if is_comp :
1220112198 _lock_warn ()
12202- self ._propagate_lock (is_compiling = is_compiling )
12199+ self ._propagate_lock (is_compiling = is_comp )
1220312200 return self
1220412201
1220512202 @erase_cache
@@ -12611,7 +12608,7 @@ def copy_dict(d):
1261112608 def _sync_all (self ):
1261212609 if _has_cuda :
1261312610 # TODO: dynamo doesn't like torch.cuda.is_initialized
12614- if not is_dynamo_compiling () and torch .cuda .is_initialized ():
12611+ if not is_compiling () and torch .cuda .is_initialized ():
1261512612 torch .cuda .synchronize ()
1261612613 elif _has_mps :
1261712614 mps = getattr (torch , "mps" , None )
@@ -12799,7 +12796,7 @@ def _register_tensor_class(cls):
1279912796
1280012797
1280112798def _is_tensor_collection (datatype : type ) -> bool :
12802- is_dynamo = is_dynamo_compiling ()
12799+ is_dynamo = is_compiling ()
1280312800 out = None
1280412801 if not is_dynamo :
1280512802 out = _TENSOR_COLLECTION_MEMO .get (datatype )
0 commit comments