
Similar problem in the denoising section, seeking help on what part of my code could be wrong. #37

@MahanVeisi8

Description


> I fixed the problem, thank you!

Hi Ariz Mohammadi,

May I kindly ask how you managed to solve the issue you mentioned? I am currently facing a somewhat similar problem and would deeply appreciate your insights.

I'm training a model on the CIFAR-10 dataset. The VQVAE part works well, and I've checked the reconstructions; they look fine. However, when it came to training the UNet LDM, I initially struggled to find a learning rate that suited my dataset. After some trial and error, I used the Celeb configuration as a reference, adjusted it for my case, and managed to reach a loss of around 0.08.
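Roughly, the reconstruction check I mentioned looked like the sketch below (`dataloader` is a placeholder for my CIFAR-10 loader, and I'm assuming here that `vae.encode` returns the quantized latent as its first output; adjust to your VQVAE API):

```py
# Sketch of the VQVAE reconstruction sanity check (placeholder names, see note above)
import torch
from torchvision.utils import save_image

with torch.no_grad():
    im = next(iter(dataloader)).float().to(device)  # one batch of images in [-1, 1]
    z = vae.encode(im)[0]                           # quantized latent (first output assumed)
    recon = vae.decode(z)                           # decoded reconstruction
    side_by_side = torch.clamp(torch.cat([im, recon], dim=0), -1., 1.)
    save_image((side_by_side + 1) / 2, 'recon_check.png', nrow=im.shape[0])
```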
Actually, it may be clearer to just show you the config class:

```py
class Config:
    def __init__(self):
        self.config = {
            "dataset_params": {
                "im_path": "/kaggle/input/cifar10-64x64-resized-via-cai-super-resolution/cifar10-64/train",
                "im_channels": 3,
                "im_size": 64,
                "name": "CIFAR10"
            },
            "diffusion_params": {
                "num_timesteps": 1000,
                "beta_start": 0.00085,
                "beta_end": 0.012
            },
            "ldm_params": {
                "down_channels": [256, 384, 512, 768],
                "mid_channels": [768, 512],
                "down_sample": [True, False, False],
                "attn_down": [True, True, True],
                "time_emb_dim": 512,
                "norm_channels": 32,
                "num_heads": 16,
                "conv_out_channels": 128,
                "num_down_layers": 3,
                "num_mid_layers": 3,
                "num_up_layers": 3,
                "condition_config": {
                    "condition_types": ["class"],
                    "class_condition_config": {
                        "num_classes": 10,
                        "cond_drop_prob": 0.1
                    }
                }
            },
            "autoencoder_params": {
                "z_channels": 128,
                "codebook_size": 100,
                "down_channels": [32, 64, 128],
                "mid_channels": [128, 128],
                "down_sample": [True, True],
                "attn_down": [False, False],
                "norm_channels": 32,
                "num_heads": 16,
                "num_down_layers": 1,
                "num_mid_layers": 1,
                "num_up_layers": 1
            },
            "train_params": {
                "seed": 1111,
                "task_name": "/kaggle/working/VQVAE/models",
                "ldm_batch_size": 64,
                "autoencoder_batch_size": 64,
                "disc_start": 1000,
                "disc_weight": 0.5,
                "codebook_weight": 1,
                "commitment_beta": 0.25,
                "perceptual_weight": 1,
                "kl_weight": 0.000005,
                "ldm_epochs": 100,
                "autoencoder_epochs": 20,
                "num_samples": 36,
                "num_grid_rows": 6,
                "ldm_lr": 0.00005,
                "autoencoder_lr": 0.0001,
                "autoencoder_acc_steps": 1,
                "autoencoder_img_save_steps": 8,
                "save_latents": True,
                "cf_guidance_scale": 1.0,
                "vae_latent_dir_name": "/vae_latents",
                "vqvae_latent_dir_name": "/vqvae_latents",
                "ldm_ckpt_name": "cifar10_ddpm_ckpt_class_cond.pth",
                "vqvae_autoencoder_ckpt_name": "/kaggle/input/vqvae/pytorch/default/1/cifar10_vqvae_autoencoder_ckpt.pth",
                "vqvae_discriminator_ckpt_name": "/kaggle/input/vqvae/pytorch/default/1/cifar10_vqvae_discriminator_ckpt.pth",
                "lpips_weights_path": "/kaggle/input/kaggleinputdata/pytorch/default/1/vgg.pth",
            }
        }
```
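For reference, this config puts the LDM in a 16x16 latent space, which matches what the sampling code below computes:

```py
# Latent shape implied by the config above (same formula as in the sampling code)
im_size = 64 // 2 ** sum([True, True])   # im_size // 2**sum(down_sample) = 16
latent_shape = (128, im_size, im_size)   # (z_channels, 16, 16)
```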

Despite this, the denoising and sampling results are far from satisfactory. It has been quite discouraging, especially after dedicating about three weeks to this work.

[Attached image: x0 prediction results from sampling]

As you can see, I think my x0 results look somewhat like yours, even though the datasets are different.
Here is the sampling code:

```py
import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from PIL import Image
from tqdm import tqdm
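# NOTE: Config (defined above), LinearNoiseScheduler, Unet, VQVAE,
# get_config_value and validate_class_config are used below without imports;
# they come from the repo's modules / earlier cells in my notebook.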

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def sample(model, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config, vae):
    r"""
    Sample stepwise by going backward one timestep at a time.
    We save the x0 predictions.
    """
    im_size = dataset_config['im_size'] // 2 ** sum(autoencoder_model_config['down_sample'])

    ########### Sample random noise latent ##########
    xt = torch.randn((train_config['num_samples'],
                      autoencoder_model_config['z_channels'],
                      im_size,
                      im_size)).to(device)
    ###############################################

    ############# Validate the config #################
    condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
    assert condition_config is not None, ("This sampling script is for class conditional "
                                          "but no conditioning config found")
    condition_types = get_config_value(condition_config, 'condition_types', [])
    assert 'class' in condition_types, ("This sampling script is for class conditional "
                                        "but no class condition found in config")
    validate_class_config(condition_config)
    ###############################################

    ############ Create conditional input ###############
    num_classes = condition_config['class_condition_config']['num_classes']
    sample_classes = torch.randint(0, num_classes, (train_config['num_samples'],))
    print('Generating images for {}'.format(list(sample_classes.numpy())))
    cond_input = {
        'class': torch.nn.functional.one_hot(sample_classes, num_classes).to(device)
    }
    # Unconditional input for classifier-free guidance
    uncond_input = {
        'class': cond_input['class'] * 0
    }
    ###############################################

    # By default classifier-free guidance is disabled
    # Change the value in the config (or the default here) to enable it
    cf_guidance_scale = get_config_value(train_config, 'cf_guidance_scale', 1.0)

    ################# Sampling loop ########################
    for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
        # Get prediction of noise
        t = (torch.ones((xt.shape[0],)) * i).long().to(device)
        noise_pred_cond = model(xt, t, cond_input)

        if cf_guidance_scale > 1:
            noise_pred_uncond = model(xt, t, uncond_input)
            noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)
        else:
            noise_pred = noise_pred_cond

        # Use scheduler to get x0 and xt-1
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        # if i == 0:
        #     # Decode ONLY the final image to save time
        #     ims = vae.decode(xt)
        # else:
        #     ims = x0_pred

        # NOTE: unlike the commented block above, this decodes the current xt
        # at every timestep, which is slow but lets me inspect each step
        ims = vae.decode(xt)
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=1)
        img = torchvision.transforms.ToPILImage()(grid)

        # makedirs handles the nested path; mkdir would fail if parents were missing
        os.makedirs(os.path.join(train_config['task_name'], 'cond_class_samples'), exist_ok=True)
        img.save(os.path.join(train_config['task_name'], 'cond_class_samples', 'x0_{}.png'.format(i)))
        img.close()
    ##############################################################

def infer():
    # Read the config
    config_instance = Config()
    config = config_instance.config

    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    diffusion_model_config = config['ldm_params']
    autoencoder_model_config = config['autoencoder_params']
    train_config = config['train_params']
    print(config)

    ########## Create the noise scheduler #############
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])
    ###############################################

    ########## Load Unet #############
    model = Unet(im_channels=autoencoder_model_config['z_channels'],
                 model_config=diffusion_model_config).to(device)
    model.eval()
    if os.path.exists(os.path.join(train_config['task_name'],
                                   train_config['ldm_ckpt_name'])):
        print('Loaded unet checkpoint')
        model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                      train_config['ldm_ckpt_name']),
                                         map_location=device))
    else:
        raise Exception('Model checkpoint {} not found'.format(os.path.join(train_config['task_name'],
                                                                            train_config['ldm_ckpt_name'])))
    #####################################

    # Create output directories (makedirs handles the nested task_name path)
    os.makedirs(train_config['task_name'], exist_ok=True)

    ########## Load VQVAE #############
    vae = VQVAE(im_channels=dataset_config['im_channels'],
                model_config=autoencoder_model_config).to(device)
    vae.eval()

    # Load vae checkpoint if found
    if os.path.exists(train_config['vqvae_autoencoder_ckpt_name']):
        print('Loaded vae checkpoint')
        checkpoint = torch.load(train_config['vqvae_autoencoder_ckpt_name'], map_location=device)
        vae.load_state_dict(checkpoint)
    else:
        raise Exception('VAE checkpoint {} not found'.format(train_config['vqvae_autoencoder_ckpt_name']))
    #####################################

    with torch.no_grad():
        sample(model, scheduler, train_config, diffusion_model_config,
               autoencoder_model_config, diffusion_config, dataset_config, vae)


infer()

```
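One detail about the loop above: since my config leaves `cf_guidance_scale` at 1.0, the `cf_guidance_scale > 1` branch never runs and only the conditional prediction is used. That is mathematically the same as the guidance mix with scale 1, as this tiny self-contained illustration shows (dummy tensors, hypothetical scale):

```py
# Self-contained illustration of the classifier-free guidance mix used above
import torch

s = 3.0                                      # hypothetical guidance scale (> 1)
noise_pred_cond = torch.randn(1, 4, 4, 4)    # dummy conditional prediction
noise_pred_uncond = torch.randn(1, 4, 4, 4)  # dummy unconditional prediction
noise_pred = noise_pred_uncond + s * (noise_pred_cond - noise_pred_uncond)
# With s = 1.0 this reduces to noise_pred_cond exactly, which is why the
# sampling code skips the extra unconditional forward pass for scale <= 1.
```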

I’ve attached an example image generated after training using my current setup. It doesn’t look good at all. Any guidance you can offer—whether related to your loss values or how you resolved your own issue—would mean a lot to me.

Thank you so much in advance for your time and help!

Originally posted by @MahanVeisi8 in #36
