@@ -276,81 +276,6 @@ def get_hardware_suffix(with_available_torch_build: bool = False, torch_version:
276276 return hardware_suffix
277277
278278
279- def add_hardware_suffix_to_torch (
280- requirement : Requirement ,
281- hardware_suffix : str | None = None ,
282- with_available_torch_build : bool = False ,
283- ) -> str :
284- """Add hardware suffix to the torch requirement.
285-
286- Args:
287- requirement (Requirement): Requirement object comprising requirement
288- details.
289- hardware_suffix (str | None): Hardware suffix. If None, it will be set
290- to the correct hardware suffix. Defaults to None.
291- with_available_torch_build (bool): To check whether the installed
292- CUDA version is supported by the latest available PyTorch build.
293- Defaults to False.
294-
295- Examples:
296- >>> from pkg_resources import Requirement
297- >>> req = "torch>=1.13.0, <=2.0.1"
298- >>> requirement = Requirement.parse(req)
299- >>> requirement.name, requirement.specs
300- ('torch', [('>=', '1.13.0'), ('<=', '2.0.1')])
301-
302- >>> add_hardware_suffix_to_torch(requirement)
303- 'torch>=1.13.0+cu121, <=2.0.1+cu121'
304-
305- ``with_available_torch_build=True`` will use the latest available PyTorch build.
306- >>> req = "torch==2.0.1"
307- >>> requirement = Requirement.parse(req)
308- >>> add_hardware_suffix_to_torch(requirement, with_available_torch_build=True)
309- 'torch==2.0.1+cu118'
310-
311- It is possible to pass the ``hardware_suffix`` manually.
312- >>> req = "torch==2.0.1"
313- >>> requirement = Requirement.parse(req)
314- >>> add_hardware_suffix_to_torch(requirement, hardware_suffix="cu121")
315- 'torch==2.0.1+cu111'
316-
317- Raises:
318- ValueError: When the requirement has more than two version criterion.
319-
320- Returns:
321- str: Updated torch package with the right cuda suffix.
322- """
323- name = requirement .unsafe_name
324- updated_specs : list [str ] = []
325-
326- for operator , version in requirement .specs :
327- hardware_suffix = hardware_suffix or get_hardware_suffix (with_available_torch_build , version )
328- updated_version = version + f"+{ hardware_suffix } " if not version .startswith (("2.1" , "2.2" )) else version
329-
330- # ``specs`` contains operators and versions as follows:
331- # These are to be concatenated again for the updated version.
332- updated_specs .append (operator + updated_version )
333-
334- updated_requirement : str = ""
335-
336- if updated_specs :
337- # This is the case when specs are e.g. ['<=1.9.1+cu111']
338- if len (updated_specs ) == 1 :
339- updated_requirement = name + updated_specs [0 ]
340- # This is the case when specs are e.g., ['<=1.9.1+cu111', '>=1.8.1+cu111']
341- elif len (updated_specs ) == 2 :
342- updated_requirement = name + updated_specs [0 ] + ", " + updated_specs [1 ]
343- else :
344- msg = (
345- "Requirement version can be a single value or a range. \n "
346- "For example it could be torch>=1.8.1 "
347- "or torch>=1.8.1, <=1.9.1\n "
348- f"Got { updated_specs } instead."
349- )
350- raise ValueError (msg )
351- return updated_requirement
352-
353-
354279def get_torch_install_args (requirement : str | Requirement ) -> list [str ]:
355280 """Get the install arguments for Torch requirement.
356281
@@ -368,7 +293,7 @@ def get_torch_install_args(requirement: str | Requirement) -> list[str]:
368293 >>> requriment = "torch>=1.13.0"
369294 >>> get_torch_install_args(requirement)
370295 ['--extra-index-url', 'https://download.pytorch.org/whl/cpu',
371- 'torch== 1.13.0+cpu ', 'torchvision==0.14.0+cpu ']
296+ 'torch>= 1.13.0', 'torchvision==0.14.0']
372297
373298 Returns:
374299 list[str]: The install arguments.
@@ -401,21 +326,15 @@ def get_torch_install_args(requirement: str | Requirement) -> list[str]:
401326 # Create the PyTorch Index URL to download the correct wheel.
402327 index_url = f"https://download.pytorch.org/whl/{ hardware_suffix } "
403328
404- # Create the PyTorch version depending on the CUDA version. For example,
405- # If CUDA version is 11.2, then the PyTorch version is 1.8.0+cu112.
406- # If CUDA version is None, then the PyTorch version is 1.8.0+cpu.
407- torch_version = add_hardware_suffix_to_torch (requirement , hardware_suffix , with_available_torch_build = True )
329+ torch_version = f"{ requirement .name } { operator } { version } " # eg: torch==1.13.0
408330
409331 # Get the torchvision version depending on the torch version.
410332 torchvision_version = AVAILABLE_TORCH_VERSIONS [version ]["torchvision" ]
411333 torchvision_requirement = f"torchvision{ operator } { torchvision_version } "
412- if isinstance (torchvision_version , str ) and not torchvision_version .startswith ("0.16" ):
413- torchvision_requirement += f"+{ hardware_suffix } "
414334
415335 # Return the install arguments.
416336 install_args += [
417337 "--extra-index-url" ,
418- # "--index-url",
419338 index_url ,
420339 torch_version ,
421340 torchvision_requirement ,
0 commit comments