@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
494494 _optional_components = []
495495 _exclude_from_cpu_offload = []
496496 _load_connected_pipes = False
497+ _is_onnx = False
497498
498499 def register_modules (self , ** kwargs ):
499500 # import it here to avoid circular import
@@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
839840 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
840841 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
841842 weights. If set to `False`, safetensors weights are not loaded.
843+ use_onnx (`bool`, *optional*, defaults to `None`):
844+ If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
845+ will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
846+ `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
847+ with `.onnx` and `.pb`.
842848 kwargs (remaining dictionary of keyword arguments, *optional*):
843849 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
844850 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12681274 variant (`str`, *optional*):
12691275 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
12701276 loading `from_flax`.
1277+ use_safetensors (`bool`, *optional*, defaults to `None`):
1278+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1279+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1280+ weights. If set to `False`, safetensors weights are not loaded.
1281+ use_onnx (`bool`, *optional*, defaults to `False`):
1282+ If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
1283+ will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
1284+ `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
1285+ with `.onnx` and `.pb`.
12711286
12721287 Returns:
12731288 `os.PathLike`:
@@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12931308 custom_revision = kwargs .pop ("custom_revision" , None )
12941309 variant = kwargs .pop ("variant" , None )
12951310 use_safetensors = kwargs .pop ("use_safetensors" , None )
1311+ use_onnx = kwargs .pop ("use_onnx" , None )
12961312 load_connected_pipeline = kwargs .pop ("load_connected_pipeline" , False )
12971313
12981314 if use_safetensors and not is_safetensors_available ():
@@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13641380 pretrained_model_name , use_auth_token , variant , revision , model_filenames
13651381 )
13661382
1367- model_folder_names = {os .path .split (f )[0 ] for f in model_filenames }
1383+ model_folder_names = {os .path .split (f )[0 ] for f in model_filenames if os . path . split ( f )[ 0 ] in folder_names }
13681384
13691385 # all filenames compatible with variant will be added
13701386 allow_patterns = list (model_filenames )
@@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14111427 ):
14121428 ignore_patterns = ["*.bin" , "*.msgpack" ]
14131429
1430+ use_onnx = use_onnx if use_onnx is not None else pipeline_class ._is_onnx
1431+ if not use_onnx :
1432+ ignore_patterns += ["*.onnx" , "*.pb" ]
1433+
14141434 safetensors_variant_filenames = {f for f in variant_filenames if f .endswith (".safetensors" )}
14151435 safetensors_model_filenames = {f for f in model_filenames if f .endswith (".safetensors" )}
14161436 if (
@@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14231443 else :
14241444 ignore_patterns = ["*.safetensors" , "*.msgpack" ]
14251445
1446+ use_onnx = use_onnx if use_onnx is not None else pipeline_class ._is_onnx
1447+ if not use_onnx :
1448+ ignore_patterns += ["*.onnx" , "*.pb" ]
1449+
14261450 bin_variant_filenames = {f for f in variant_filenames if f .endswith (".bin" )}
14271451 bin_model_filenames = {f for f in model_filenames if f .endswith (".bin" )}
14281452 if len (bin_variant_filenames ) > 0 and bin_model_filenames != bin_variant_filenames :
0 commit comments