@@ -190,6 +190,7 @@ def __init__(
190190 self ._transformer = None
191191 self ._data_sampler = None
192192 self ._generator = None
193+ self ._discriminator = None
193194 self .loss_values = None
194195
195196 @staticmethod
@@ -330,7 +331,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
330331 self ._embedding_dim + self ._data_sampler .dim_cond_vec (), self ._generator_dim , data_dim
331332 ).to (self ._device )
332333
333- discriminator = Discriminator (
334+ self . _discriminator = Discriminator (
334335 data_dim + self ._data_sampler .dim_cond_vec (), self ._discriminator_dim , pac = self .pac
335336 ).to (self ._device )
336337
@@ -342,7 +343,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
342343 )
343344
344345 optimizerD = optim .Adam (
345- discriminator .parameters (),
346+ self . _discriminator .parameters (),
346347 lr = self ._discriminator_lr ,
347348 betas = (0.5 , 0.9 ),
348349 weight_decay = self ._discriminator_decay ,
@@ -395,10 +396,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
395396 real_cat = real
396397 fake_cat = fakeact
397398
398- y_fake = discriminator (fake_cat )
399- y_real = discriminator (real_cat )
399+ y_fake = self . _discriminator (fake_cat )
400+ y_real = self . _discriminator (real_cat )
400401
401- pen = discriminator .calc_gradient_penalty (
402+ pen = self . _discriminator .calc_gradient_penalty (
402403 real_cat , fake_cat , self ._device , self .pac
403404 )
404405 loss_d = - (torch .mean (y_real ) - torch .mean (y_fake ))
@@ -423,9 +424,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
423424 fakeact = self ._apply_activate (fake )
424425
425426 if c1 is not None :
426- y_fake = discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
427+ y_fake = self . _discriminator (torch .cat ([fakeact , c1 ], dim = 1 ))
427428 else :
428- y_fake = discriminator (fakeact )
429+ y_fake = self . _discriminator (fakeact )
429430
430431 if condvec is None :
431432 cross_entropy = 0
@@ -520,3 +521,5 @@ def set_device(self, device):
520521 self ._device = device
521522 if self ._generator is not None :
522523 self ._generator .to (self ._device )
524+ if self ._discriminator is not None :
525+ self ._discriminator .to (self ._device )
0 commit comments