Skip to content

Commit ef4421a

Browse files
committed
fix a warning
1 parent 4019202 commit ef4421a

9 files changed

+17
-17
lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
12-
from torch.cuda.amp import autocast
12+
from torch.amp import autocast
1313

1414
from einops import rearrange, reduce, repeat
1515
from einops.layers.torch import Rearrange
@@ -731,7 +731,7 @@ def interpolate(self, x1, x2, classes, t = None, lam = 0.5):
731731

732732
return img
733733

734-
@autocast(enabled = False)
734+
@autocast('cuda', enabled = False)
735735
def q_sample(self, x_start, t, noise=None):
736736
noise = default(noise, lambda: torch.randn_like(x_start))
737737

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import sqrt
44
from torch import nn, einsum
55
import torch.nn.functional as F
6-
from torch.cuda.amp import autocast
6+
from torch.amp import autocast
77
from torch.special import expm1
88

99
from tqdm import tqdm
@@ -234,7 +234,7 @@ def sample(self, batch_size = 16):
234234

235235
# training related functions - noise prediction
236236

237-
@autocast(enabled = False)
237+
@autocast('cuda', enabled = False)
238238
def q_sample(self, x_start, times, noise = None):
239239
noise = default(noise, lambda: torch.randn_like(x_start))
240240

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
1212
from torch.nn import Module, ModuleList
13-
from torch.cuda.amp import autocast
13+
from torch.amp import autocast
1414
from torch.utils.data import Dataset, DataLoader
1515

1616
from torch.optim import Adam
@@ -772,7 +772,7 @@ def noise_assignment(self, x_start, noise):
772772
_, assign = linear_sum_assignment(dist.cpu())
773773
return torch.from_numpy(assign).to(dist.device)
774774

775-
@autocast(enabled = False)
775+
@autocast('cuda', enabled = False)
776776
def q_sample(self, x_start, t, noise = None):
777777
noise = default(noise, lambda: torch.randn_like(x_start))
778778

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import nn, einsum, Tensor
1010
from torch.nn import Module, ModuleList
1111
import torch.nn.functional as F
12-
from torch.cuda.amp import autocast
12+
from torch.amp import autocast
1313
from torch.optim import Adam
1414
from torch.utils.data import Dataset, DataLoader
1515

@@ -660,7 +660,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
660660

661661
return img
662662

663-
@autocast(enabled = False)
663+
@autocast('cuda', enabled = False)
664664
def q_sample(self, x_start, t, noise=None):
665665
noise = default(noise, lambda: torch.randn_like(x_start))
666666

denoising_diffusion_pytorch/guided_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
12-
from torch.cuda.amp import autocast
12+
from torch.amp import autocast
1313
from torch.utils.data import Dataset, DataLoader
1414

1515
from torch.optim import Adam
@@ -709,7 +709,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
709709

710710
return img
711711

712-
@autocast(enabled = False)
712+
@autocast('cuda', enabled = False)
713713
def q_sample(self, x_start, t, noise=None):
714714
noise = default(noise, lambda: torch.randn_like(x_start))
715715

denoising_diffusion_pytorch/repaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch import nn, einsum
1111
import torch.nn.functional as F
1212
from torch.nn import Module, ModuleList
13-
from torch.cuda.amp import autocast
13+
from torch.amp import autocast
1414
from torch.utils.data import Dataset, DataLoader
1515

1616
from torch.optim import Adam
@@ -815,7 +815,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
815815

816816
return img
817817

818-
@autocast(enabled = False)
818+
@autocast('cuda', enabled = False)
819819
def q_sample(self, x_start, t, noise = None):
820820
noise = default(noise, lambda: torch.randn_like(x_start))
821821

denoising_diffusion_pytorch/simple_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn, einsum
77
import torch.nn.functional as F
88
from torch.special import expm1
9-
from torch.cuda.amp import autocast
9+
from torch.amp import autocast
1010

1111
from tqdm import tqdm
1212
from einops import rearrange, repeat, reduce, pack, unpack
@@ -651,7 +651,7 @@ def sample(self, batch_size = 16):
651651

652652
# training related functions - noise prediction
653653

654-
@autocast(enabled = False)
654+
@autocast('cuda', enabled = False)
655655
def q_sample(self, x_start, times, noise = None):
656656
noise = default(noise, lambda: torch.randn_like(x_start))
657657

denoising_diffusion_pytorch/v_param_continuous_time_gaussian_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn, einsum
55
import torch.nn.functional as F
66
from torch.special import expm1
7-
from torch.cuda.amp import autocast
7+
from torch.amp import autocast
88

99
from tqdm import tqdm
1010
from einops import rearrange, repeat, reduce
@@ -150,7 +150,7 @@ def sample(self, batch_size = 16):
150150

151151
# training related functions - noise prediction
152152

153-
@autocast(enabled = False)
153+
@autocast('cuda', enabled = False)
154154
def q_sample(self, x_start, times, noise = None):
155155
noise = default(noise, lambda: torch.randn_like(x_start))
156156

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.0.16'
1+
__version__ = '2.0.17'

0 commit comments

Comments
 (0)