diff --git a/config.py b/config.py
index 60971e420..882308cef 100644
--- a/config.py
+++ b/config.py
@@ -216,10 +216,12 @@ def create_parser():
                        help='Num of cycles for cosine decay and cyclic (default=1)')
     group.add_argument('--cycle_decay', type=float, default=1.0,
                        help='Decay rate of lr max in each cosine cycle (default=1.0)')
+    group.add_argument('--layer_decay', type=float, default=None,
+                       help='Layer-wise lr decay rate for fine-tuning (default=None)')
 
     # Loss parameters
     group = parser.add_argument_group('Loss parameters')
-    group.add_argument('--loss', type=str, default='CE', choices=['BCE', 'CE'],
+    group.add_argument('--loss', type=str, default='CE', choices=['BCE', 'CE', 'None'],
                        help='Type of loss, BCE (BinaryCrossEntropy) or CE (CrossEntropy) (default="CE")')
     group.add_argument('--label_smoothing', type=float, default=0.0,
                        help='Use label smoothing (default=0.0)')
@@ -256,8 +258,26 @@ def create_parser():
     group.add_argument('--drop_overflow_update', type=bool, default=False,
                        help='Whether to execute optimizer if there is an overflow (default=False)')
 
+    # pre-train
+    group = parser.add_argument_group('pre-train')
+    group.add_argument('--pretrain_resize', type=list, default=[224],
+                       help='Crop size(s) of the image for pre-training. '
+                            'The list must have length 2 if a tokenizer is required. (default=[224])')
+    group.add_argument('--pretrain_interpolations', type=list, default=['bicubic', 'bilinear'],
+                       help='Image interpolation modes for the resize operators used in pre-training')
+    group.add_argument('--tokenizer', type=str, default=None,
+                       help='Name of the tokenizer model for pre-training')
+    group.add_argument('--tokenizer_ckpt_path', type=str, default='',
+                       help='Initialize the tokenizer model from this checkpoint')
+    group.add_argument('--mask_type', type=str, default='random',
+                       choices=['block_wise', 'patch_aligned', 'random'],
+                       help='Type of mask generator')
+    group.add_argument('--mask_ratio', type=float, default=0.75,
+                       help='Masking ratio')
+    group.add_argument('--mask_patch_size', type=int, default=32,
+                       help='Size of mask patch')
+
     return parser_config, parser
 
-# fmt: on
 
 def _check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser):
diff --git a/configs/mae/mae_b_16_224_finetune_ascend.yaml b/configs/mae/mae_b_16_224_finetune_ascend.yaml
new file mode 100644
index 000000000..7ddfcbcd9
--- /dev/null
+++ b/configs/mae/mae_b_16_224_finetune_ascend.yaml
@@ -0,0 +1,58 @@
+# system
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+
+# dataset
+dataset: "imagenet"
+data_dir: "/path/to/imagenet"
+shuffle: True
+dataset_download: False
+batch_size: 32
+drop_remainder: True
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+interpolation: "bicubic"
+auto_augment: "randaug-m9-mstd0.5-inc1"
+re_prob: 0.25
+mixup: 0.8
+cutmix: 1.0
+re_value: "random"
+
+# model
+model: "mae_b_16_224_finetune"
+drop_rate: 0.0
+drop_path_rate: 0.1
+pretrained: False
+ckpt_path: ""
+keep_checkpoint_max: 10
+ckpt_save_dir: "./ckpt"
+epoch_size: 100
+dataset_sink_mode: True
+amp_level: "O2"
+
+# loss
+loss: "CE"
+loss_scale: 1024.0
+label_smoothing: 0.1
+
+# lr scheduler
+scheduler: "warmup_cosine_decay"
+lr: 5e-4
+min_lr: 1e-6
+warmup_epochs: 5
+warmup_factor: 0
+decay_epochs: 95
+layer_decay: 0.65
+lr_epoch_stair: False
+
+# optimizer
+opt: "adamw"
+weight_decay: 0.05
+filter_bias_and_bn: True
+use_nesterov: False
diff --git a/configs/mae/mae_b_16_224_pretrain_ascend.yaml b/configs/mae/mae_b_16_224_pretrain_ascend.yaml
new file mode 100644
index
000000000..601f64d60 --- /dev/null +++ b/configs/mae/mae_b_16_224_pretrain_ascend.yaml @@ -0,0 +1,57 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 + +# dataset +dataset: "imagenet" +data_dir: "/path/to/imagenet" +shuffle: True +dataset_download: False +batch_size: 64 +drop_remainder: True + +# augmentation +scale: [0.2, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +color_jitter: [0.4, 0.4, 0.4] + +# model +model: "mae_b_16_224_pretrain" +drop_rate: 0.0 +drop_path_rate: 0.0 +pretrained: False +ckpt_path: "" +keep_checkpoint_max: 10 +ckpt_save_dir: "./ckpt" +epoch_size: 800 +dataset_sink_mode: True +amp_level: "O2" +clip_grad: True +clip_value: 3.0 + +# loss +loss: "None" +loss_scale: 1024.0 + +# lr scheduler +scheduler: "warmup_cosine_decay" +lr: 1.5e-4 +min_lr: 0 +warmup_epochs: 40 +warmup_factor: 0 +decay_epochs: 760 +lr_epoch_stair: False + +# optimizer +opt: "adamw" +weight_decay: 0.05 +filter_bias_and_bn: True +use_nesterov: False + +# pre-train +pretrain_resize: [224] +pretrain_interpolations: ["bicubic"] +mask_type: "random" +mask_ratio: 0.75 diff --git a/examples/ssl/finetune.py b/examples/ssl/finetune.py new file mode 100644 index 000000000..2db813a06 --- /dev/null +++ b/examples/ssl/finetune.py @@ -0,0 +1,321 @@ +""" Model training pipeline """ +import logging +import os +import sys + +mindcv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.append(mindcv_path) + +import mindspore as ms +from mindspore import Tensor +from mindspore.communication import get_group_size, get_rank, init + +from mindcv.data import create_dataset, create_loader, create_transforms +from mindcv.loss import create_loss +from mindcv.models import create_model +from mindcv.optim import create_finetune_optimizer +from mindcv.scheduler import create_scheduler +from mindcv.utils import ( + AllReduceSum, + StateMonitor, + create_trainer, + get_metrics, + require_customized_train_step, + set_logger, + set_seed, +) + +from config import parse_args, save_args # isort: skip + +logger = logging.getLogger("mindcv.train") + + +def main(): + args = parse_args() + ms.set_context(mode=args.mode) + if args.distribute: + init() + rank_id, device_num = get_rank(), get_group_size() + ms.set_auto_parallel_context( + device_num=device_num, + parallel_mode="data_parallel", + gradients_mean=True, + # we should but cannot set parameter_broadcast=True, which will cause error on gpu. 
+ ) + all_reduce = AllReduceSum() + else: + device_num, rank_id = None, None + all_reduce = None + + set_seed(args.seed) + set_logger(name="mindcv", output_dir=args.ckpt_save_dir, rank=rank_id, color=False) + logger.info( + "We recommend installing `termcolor` via `pip install termcolor` " + "and setup logger by `set_logger(..., color=True)`" + ) + + # create dataset + dataset_train = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.train_split, + shuffle=args.shuffle, + num_samples=args.num_samples, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + num_aug_repeats=args.aug_repeats, + ) + + if args.num_classes is None: + num_classes = dataset_train.num_classes() + else: + num_classes = args.num_classes + + # create transforms + num_aug_splits = 0 + if args.aug_splits > 0: + assert args.aug_splits == 3, "Currently, only support 3 splits of augmentation" + assert args.auto_augment is not None, "aug_splits should be set with one auto_augment" + num_aug_splits = args.aug_splits + transform_list = create_transforms( + dataset_name=args.dataset, + is_training=True, + image_resize=args.image_resize, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + vflip=args.vflip, + color_jitter=args.color_jitter, + interpolation=args.interpolation, + auto_augment=args.auto_augment, + mean=args.mean, + std=args.std, + re_prob=args.re_prob, + re_scale=args.re_scale, + re_ratio=args.re_ratio, + re_value=args.re_value, + re_max_attempts=args.re_max_attempts, + separate=num_aug_splits > 0, + ) + + # load dataset + loader_train = create_loader( + dataset=dataset_train, + batch_size=args.batch_size, + drop_remainder=args.drop_remainder, + is_training=True, + mixup=args.mixup, + cutmix=args.cutmix, + cutmix_prob=args.cutmix_prob, + num_classes=num_classes, + transform=transform_list, + num_parallel_workers=args.num_parallel_workers, + separate=num_aug_splits > 0, + ) + num_batches = loader_train.get_dataset_size() + train_count = dataset_train.get_dataset_size() + if args.distribute: + train_count = all_reduce(Tensor(train_count, ms.int32)) + + if args.val_while_train: + dataset_eval = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.val_split, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + ) + + transform_list_eval = create_transforms( + dataset_name=args.dataset, + is_training=False, + image_resize=args.image_resize, + crop_pct=args.crop_pct, + interpolation=args.interpolation, + mean=args.mean, + std=args.std, + ) + + loader_eval = create_loader( + dataset=dataset_eval, + batch_size=args.batch_size, + drop_remainder=False, + is_training=False, + transform=transform_list_eval, + num_parallel_workers=args.num_parallel_workers, + ) + # validation dataset count + eval_count = dataset_eval.get_dataset_size() + if args.distribute: + eval_count = all_reduce(Tensor(eval_count, ms.int32)) + else: + loader_eval = None + eval_count = None + + # create model + network = create_model( + model_name=args.model, + num_classes=num_classes, + in_channels=args.in_channels, + drop_rate=args.drop_rate, + drop_path_rate=args.drop_path_rate, + pretrained=args.pretrained, + checkpoint_path=args.ckpt_path, + ema=args.ema, + ) + + num_params = sum([param.size for param in network.get_parameters()]) + + # create loss + loss = create_loss( + name=args.loss, + reduction=args.reduction, + label_smoothing=args.label_smoothing, 
+ aux_factor=args.aux_factor, + ) + + # create learning rate schedule + lr_scheduler = create_scheduler( + num_batches, + scheduler=args.scheduler, + lr=args.lr, + min_lr=args.min_lr, + warmup_epochs=args.warmup_epochs, + warmup_factor=args.warmup_factor, + decay_epochs=args.decay_epochs, + decay_rate=args.decay_rate, + milestones=args.multi_step_decay_milestones, + num_epochs=args.epoch_size, + num_cycles=args.num_cycles, + cycle_decay=args.cycle_decay, + lr_epoch_stair=args.lr_epoch_stair, + ) + + # resume training if ckpt_path is given + if args.ckpt_path != "" and args.resume_opt: + opt_ckpt_path = os.path.join(args.ckpt_save_dir, f"optim_{args.model}.ckpt") + else: + opt_ckpt_path = "" + + # create optimizer + # TODO: consistent naming opt, name, dataset_name + if ( + args.loss_scale_type == "fixed" + and args.drop_overflow_update is False + and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps) + ): + optimizer_loss_scale = args.loss_scale + else: + optimizer_loss_scale = 1.0 + optimizer = create_finetune_optimizer( + network, + opt=args.opt, + lr=lr_scheduler, + weight_decay=args.weight_decay, + momentum=args.momentum, + nesterov=args.use_nesterov, + filter_bias_and_bn=args.filter_bias_and_bn, + loss_scale=optimizer_loss_scale, + checkpoint_path=opt_ckpt_path, + eps=args.eps, + scale=args.layer_decay, + ) + + # Define eval metrics. + metrics = get_metrics(num_classes) + + # create trainer + trainer = create_trainer( + network, + loss, + optimizer, + metrics, + amp_level=args.amp_level, + amp_cast_list=args.amp_cast_list, + loss_scale_type=args.loss_scale_type, + loss_scale=args.loss_scale, + drop_overflow_update=args.drop_overflow_update, + ema=args.ema, + ema_decay=args.ema_decay, + clip_grad=args.clip_grad, + clip_value=args.clip_value, + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + + # callback + # save checkpoint, summary training loss + # record val acc and do model selection if val dataset is available + begin_step = 0 + begin_epoch = 0 + if args.ckpt_path != "": + begin_step = optimizer.global_step.asnumpy()[0] + begin_epoch = args.ckpt_path.split("/")[-1].split("-")[1].split("_")[0] + begin_epoch = int(begin_epoch) + + summary_dir = f"./{args.ckpt_save_dir}/summary" + assert ( + args.ckpt_save_policy != "top_k" or args.val_while_train is True + ), "ckpt_save_policy is top_k, val_while_train must be True." 
+ state_cb = StateMonitor( + trainer, + model_name=args.model, + model_ema=args.ema, + last_epoch=begin_epoch, + dataset_sink_mode=args.dataset_sink_mode, + dataset_val=loader_eval, + metric_name=list(metrics.keys()), + val_interval=args.val_interval, + ckpt_save_dir=args.ckpt_save_dir, + ckpt_save_interval=args.ckpt_save_interval, + ckpt_save_policy=args.ckpt_save_policy, + ckpt_keep_max=args.keep_checkpoint_max, + summary_dir=summary_dir, + log_interval=args.log_interval, + rank_id=rank_id, + device_num=device_num, + ) + + callbacks = [state_cb] + essential_cfg_msg = "\n".join( + [ + "Essential Experiment Configurations:", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.distribute}", + f"Number of devices: {device_num if device_num is not None else 1}", + f"Number of training samples: {train_count}", + f"Number of validation samples: {eval_count}", + f"Number of classes: {num_classes}", + f"Number of batches: {num_batches}", + f"Batch size: {args.batch_size}", + f"Auto augment: {args.auto_augment}", + f"MixUp: {args.mixup}", + f"CutMix: {args.cutmix}", + f"Model: {args.model}", + f"Model parameters: {num_params}", + f"Number of epochs: {args.epoch_size}", + f"Optimizer: {args.opt}", + f"Learning rate: {args.lr}", + f"LR Scheduler: {args.scheduler}", + f"Momentum: {args.momentum}", + f"Weight decay: {args.weight_decay}", + f"Auto mixed precision: {args.amp_level}", + f"Loss scale: {args.loss_scale}({args.loss_scale_type})", + ] + ) + logger.info(essential_cfg_msg) + save_args(args, os.path.join(args.ckpt_save_dir, f"{args.model}.yaml"), rank_id) + + if args.ckpt_path != "": + logger.info(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}") + else: + logger.info("Start training") + + trainer.train(args.epoch_size, loader_train, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode) + + +if __name__ == "__main__": + main() diff --git a/examples/ssl/pretrain.py b/examples/ssl/pretrain.py new file mode 100644 index 000000000..ba28d6445 --- /dev/null +++ b/examples/ssl/pretrain.py @@ -0,0 +1,270 @@ +""" Model pre-training pipeline """ +import logging +import os +import sys + +mindcv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.append(mindcv_path) + +import mindspore as ms +from mindspore import Tensor +from mindspore.communication import get_group_size, get_rank, init + +from mindcv.data import create_dataset, create_loader_pretrain, create_transforms_pretrain +from mindcv.loss import create_loss +from mindcv.models import create_model +from mindcv.optim import create_pretrain_optimizer +from mindcv.scheduler import create_scheduler +from mindcv.utils import AllReduceSum, StateMonitor, create_trainer, require_customized_train_step, set_logger, set_seed + +from config import parse_args, save_args # isort: skip + +logger = logging.getLogger("mindcv.pre-train") + + +def main(): + args = parse_args() + ms.set_context(mode=args.mode) + if args.distribute: + init() + rank_id, device_num = get_rank(), get_group_size() + ms.set_auto_parallel_context( + device_num=device_num, + parallel_mode="data_parallel", + gradients_mean=True, + # we should but cannot set parameter_broadcast=True, which will cause error on gpu. 
+ ) + all_reduce = AllReduceSum() + else: + rank_id, device_num = None, None + all_reduce = None + + set_seed(args.seed) + set_logger(name="mindcv", output_dir=args.ckpt_save_dir, rank=rank_id, color=False) + logger.info( + "We recommend installing `termcolor` via `pip install termcolor` " + "and setup logger by `set_logger(..., color=True)`" + ) + + # create dataset + dataset_train = create_dataset( + name=args.dataset, + root=args.data_dir, + split=args.train_split, + shuffle=args.shuffle, + num_samples=args.num_samples, + num_shards=device_num, + shard_id=rank_id, + num_parallel_workers=args.num_parallel_workers, + download=args.dataset_download, + num_aug_repeats=args.aug_repeats, + ) + if args.num_classes is None: + num_classes = dataset_train.num_classes() + else: + num_classes = args.num_classes + + # create transforms + patch_size = int(args.model.split("_")[2]) # need to be more robust + transform_list = create_transforms_pretrain( + dataset_name=args.dataset, + resize_list=args.pretrain_resize, + tokenizer=args.tokenizer, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + color_jitter=args.color_jitter, + interpolations=args.pretrain_interpolations.copy(), + mean=args.mean, + std=args.std, + mask_type=args.mask_type, + mask_ratio=args.mask_ratio, + patch_size=patch_size, + mask_patch_size=args.mask_patch_size, + ) + + # load dataset + loader_train = create_loader_pretrain( + dataset=dataset_train, + batch_size=args.batch_size, + drop_remainder=args.drop_remainder, + transform=transform_list, + num_parallel_workers=args.num_parallel_workers, + ) + + loader_eval = None + + num_batches = loader_train.get_dataset_size() + # Train dataset count + train_count = dataset_train.get_dataset_size() + if args.distribute: + train_count = all_reduce(Tensor(train_count, ms.int32)) + + # create model + network = create_model( + model_name=args.model, + drop_rate=args.drop_rate, + drop_path_rate=args.drop_path_rate, + mask_ratio=args.mask_ratio, + pretrained=args.pretrained, + checkpoint_path=args.ckpt_path, + ema=args.ema, + ) + + if args.tokenizer is not None: + tokenizer = create_model(model_name=args.tokenizer, checkpoint_path=args.tokenizer_ckpt_path) + else: + tokenizer = None + + num_params = sum([param.size for param in network.get_parameters()]) + # create loss + if args.loss != "None": + loss = create_loss( + name=args.loss, + reduction=args.reduction, + label_smoothing=args.label_smoothing, + aux_factor=args.aux_factor, + ) + else: + loss = None + + # create learning rate schedule + lr_scheduler = create_scheduler( + num_batches, + scheduler=args.scheduler, + lr=args.lr, + min_lr=args.min_lr, + warmup_epochs=args.warmup_epochs, + warmup_factor=args.warmup_factor, + decay_epochs=args.decay_epochs, + decay_rate=args.decay_rate, + milestones=args.multi_step_decay_milestones, + num_epochs=args.epoch_size, + lr_epoch_stair=args.lr_epoch_stair, + num_cycles=args.num_cycles, + cycle_decay=args.cycle_decay, + ) + + # resume training if ckpt_path is given + if args.ckpt_path != "" and args.resume_opt: + opt_ckpt_path = os.path.join(args.ckpt_save_dir, f"optim_{args.model}.ckpt") + else: + opt_ckpt_path = "" + + # create optimizer + # TODO: consistent naming opt, name, dataset_name + if ( + args.loss_scale_type == "fixed" + and args.drop_overflow_update is False + and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps) + ): + optimizer_loss_scale = args.loss_scale + else: + optimizer_loss_scale = 1.0 + optimizer = create_pretrain_optimizer( + 
network, + opt=args.opt, + lr=lr_scheduler, + weight_decay=args.weight_decay, + momentum=args.momentum, + nesterov=args.use_nesterov, + filter_bias_and_bn=args.filter_bias_and_bn, + loss_scale=optimizer_loss_scale, + checkpoint_path=opt_ckpt_path, + eps=args.eps, + ) + + # Define eval metrics. + metrics = None + + # create trainer + trainer = create_trainer( + network, + loss, + optimizer, + metrics, + amp_level=args.amp_level, + amp_cast_list=args.amp_cast_list, + loss_scale_type=args.loss_scale_type, + loss_scale=args.loss_scale, + drop_overflow_update=args.drop_overflow_update, + ema=args.ema, + ema_decay=args.ema_decay, + clip_grad=args.clip_grad, + clip_value=args.clip_value, + gradient_accumulation_steps=args.gradient_accumulation_steps, + tokenizer=tokenizer, + ) + + # callback + # save checkpoint, summary training loss + # record val acc and do model selection if val dataset is available + begin_step = 0 + begin_epoch = 0 + if args.ckpt_path != "": + begin_step = optimizer.global_step.asnumpy()[0] + begin_epoch = args.ckpt_path.split("/")[-1].split("-")[1].split("_")[0] + begin_epoch = int(begin_epoch) + + summary_dir = f"./{args.ckpt_save_dir}/summary" + assert ( + args.ckpt_save_policy != "top_k" or args.val_while_train is True + ), "ckpt_save_policy is top_k, val_while_train must be True." + state_cb = StateMonitor( + trainer, + model_name=args.model, + model_ema=args.ema, + last_epoch=begin_epoch, + dataset_sink_mode=args.dataset_sink_mode, + dataset_val=loader_eval, + metric_name=[], + val_interval=args.val_interval, + ckpt_save_dir=args.ckpt_save_dir, + ckpt_save_interval=args.ckpt_save_interval, + ckpt_save_policy=args.ckpt_save_policy, + ckpt_keep_max=args.keep_checkpoint_max, + summary_dir=summary_dir, + log_interval=args.log_interval, + rank_id=rank_id, + device_num=device_num, + ) + + callbacks = [state_cb] + essential_cfg_msg = "\n".join( + [ + "Essential Experiment Configurations:", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.distribute}", + f"Number of devices: {device_num if device_num is not None else 1}", + f"Number of training samples: {train_count}", + f"Number of classes: {num_classes}", + f"Number of batches: {num_batches}", + f"Batch size: {args.batch_size}", + f"Auto augment: {args.auto_augment}", + f"MixUp: {args.mixup}", + f"CutMix: {args.cutmix}", + f"Model: {args.model}", + f"Model parameters: {num_params}", + f"Number of epochs: {args.epoch_size}", + f"Optimizer: {args.opt}", + f"Learning rate: {args.lr}", + f"LR Scheduler: {args.scheduler}", + f"Momentum: {args.momentum}", + f"Weight decay: {args.weight_decay}", + f"Auto mixed precision: {args.amp_level}", + f"Loss scale: {args.loss_scale}({args.loss_scale_type})", + ] + ) + logger.info(essential_cfg_msg) + save_args(args, os.path.join(args.ckpt_save_dir, f"{args.model}.yaml"), rank_id) + + if args.ckpt_path != "": + logger.info(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}") + else: + logger.info("Start training") + + trainer.train(args.epoch_size, loader_train, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode) + + +if __name__ == "__main__": + main() diff --git a/mindcv/data/__init__.py b/mindcv/data/__init__.py index 4ecf5e5d0..02a91024b 100644 --- a/mindcv/data/__init__.py +++ b/mindcv/data/__init__.py @@ -1,12 +1,21 @@ """ Data processing """ -from . import dataset_download, dataset_factory, loader, transforms_factory +from . 
import ( + dataset_download, + dataset_factory, + loader, + pretrain_loader, + pretrain_transforms_factory, + transforms_factory, +) from .auto_augment import * from .constants import * from .dataset_download import * from .dataset_factory import * from .loader import * +from .pretrain_loader import * +from .pretrain_transforms_factory import * from .transforms_factory import * __all__ = [] @@ -14,3 +23,5 @@ __all__.extend(dataset_factory.__all__) __all__.extend(loader.__all__) __all__.extend(transforms_factory.__all__) +__all__.extend(pretrain_loader.__all__) +__all__.extend(pretrain_transforms_factory.__all__) diff --git a/mindcv/data/mask_generator/__init__.py b/mindcv/data/mask_generator/__init__.py new file mode 100644 index 000000000..bf55764ba --- /dev/null +++ b/mindcv/data/mask_generator/__init__.py @@ -0,0 +1,5 @@ +from . import mask_factory +from .mask_factory import create_mask_generator + +__all__ = [] +__all__.extend(mask_factory.__all__) diff --git a/mindcv/data/mask_generator/block_wise_mask.py b/mindcv/data/mask_generator/block_wise_mask.py new file mode 100644 index 000000000..f6fc00cb3 --- /dev/null +++ b/mindcv/data/mask_generator/block_wise_mask.py @@ -0,0 +1,73 @@ +import math +import random +from typing import Optional, Tuple + +import numpy as np + + +class BlockWiseMaskGenerator: + def __init__( + self, + input_size: int = 224, + model_patch_size: int = 16, + mask_ratio: float = 0.4, + min_num_patches: int = 4, + max_num_patches: Optional[int] = None, + min_aspect: int = 0.3, + max_aspect: Optional[int] = None, + ): + assert input_size % model_patch_size == 0 + + grid_size = input_size // model_patch_size + self.height, self.width = (grid_size, grid_size) + + num_masking_patches = int(np.ceil(grid_size**2 * mask_ratio)) + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def _get_shape(self) -> Tuple[int, int]: + return self.height, self.width + + def _mask(self, mask: np.ndarray, max_mask_patches: int): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self) -> np.ndarray: + mask = np.zeros(shape=self._get_shape(), dtype=np.int32) + mask_count = 0 + while mask_count < self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask diff --git a/mindcv/data/mask_generator/mask_factory.py b/mindcv/data/mask_generator/mask_factory.py new file mode 100644 index 000000000..de632e9d5 --- /dev/null +++ b/mindcv/data/mask_generator/mask_factory.py @@ -0,0 +1,20 @@ +from 
.block_wise_mask import BlockWiseMaskGenerator
+from .patch_aligned_mask import PatchAlignedMaskGenerator
+from .random_mask import RandomMaskGenerator
+
+__all__ = ["create_mask_generator"]
+
+
+def create_mask_generator(
+    mask_name: str, input_size: int = 224, patch_size: int = 16, mask_ratio: float = 0.6, **kwargs
+):
+    if mask_name == "random":
+        mask_generator = RandomMaskGenerator(input_size, patch_size, mask_ratio)
+    elif mask_name == "block_wise":
+        mask_generator = BlockWiseMaskGenerator(input_size, patch_size, mask_ratio)
+    elif mask_name == "patch_aligned":
+        mask_generator = PatchAlignedMaskGenerator(input_size, patch_size, mask_ratio, **kwargs)
+    else:
+        raise NotImplementedError(f"{mask_name} mask generator is not implemented.")
+
+    return mask_generator
diff --git a/mindcv/data/mask_generator/patch_aligned_mask.py b/mindcv/data/mask_generator/patch_aligned_mask.py
new file mode 100644
index 000000000..651a43da5
--- /dev/null
+++ b/mindcv/data/mask_generator/patch_aligned_mask.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+
+class PatchAlignedMaskGenerator:
+    def __init__(
+        self, input_size: int = 192, model_patch_size: int = 4, mask_ratio: float = 0.6, mask_patch_size: int = 32
+    ):
+        assert input_size % mask_patch_size == 0
+        assert mask_patch_size % model_patch_size == 0
+
+        self.rand_size = input_size // mask_patch_size
+        self.scale = mask_patch_size // model_patch_size
+
+        self.token_count = self.rand_size**2
+        self.mask_count = int(np.ceil(self.token_count * mask_ratio))
+
+    def __call__(self):
+        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
+        mask = np.zeros(self.token_count, dtype=np.int32)
+        mask[mask_idx] = 1
+
+        mask = mask.reshape((self.rand_size, self.rand_size))
+        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+
+        return mask
diff --git a/mindcv/data/mask_generator/random_mask.py b/mindcv/data/mask_generator/random_mask.py
new file mode 100644
index 000000000..1077e54dc
--- /dev/null
+++ b/mindcv/data/mask_generator/random_mask.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+
+class RandomMaskGenerator:
+    def __init__(self, input_size: int = 224, model_patch_size: int = 16, mask_ratio: float = 0.75):
+        assert input_size % model_patch_size == 0
+
+        self.grid_size = input_size // model_patch_size
+        self.seq_len = self.grid_size**2
+        self.mask_count = int(np.ceil(self.seq_len * mask_ratio))
+
+    def __call__(self):
+        mask_idx = np.random.permutation(self.seq_len)[: self.mask_count]
+        mask = np.zeros(self.seq_len, dtype=np.int32)
+        mask[mask_idx] = 1
+
+        mask = mask.reshape((self.grid_size, self.grid_size))
+        return mask
diff --git a/mindcv/data/pretrain_loader.py b/mindcv/data/pretrain_loader.py
new file mode 100644
index 000000000..2f1e4dd1d
--- /dev/null
+++ b/mindcv/data/pretrain_loader.py
@@ -0,0 +1,32 @@
+"""
+Create dataloader for pre-training
+"""
+import inspect
+
+__all__ = ["create_loader_pretrain"]
+
+
+def create_loader_pretrain(
+    dataset, batch_size, drop_remainder=False, transform=None, num_parallel_workers=None, python_multiprocessing=False
+):
+    if transform is None:
+        raise ValueError("transform should not be None for pre-training.")
+
+    # note: MindSpore 2.0 removed the 'column_order' parameter of `dataset.map`
+    sig = inspect.signature(dataset.map)
+    pass_column_order = False if "kwargs" in sig.parameters else True
+
+    dataset = dataset.map(
+        operations=transform,
+        input_columns="image",
+        output_columns=transform.output_columns,
+        column_order=transform.output_columns if pass_column_order else None,
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+    if not pass_column_order:
+        dataset = dataset.project(transform.output_columns)
+
+    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+
+    return dataset
diff --git a/mindcv/data/pretrain_transforms_factory.py b/mindcv/data/pretrain_transforms_factory.py
new file mode 100644
index 000000000..7129ad8e4
--- /dev/null
+++ b/mindcv/data/pretrain_transforms_factory.py
@@ -0,0 +1,127 @@
+"""
+Transform operation for pre-training
+"""
+
+from typing import List, Tuple, Union
+
+from mindspore.dataset import vision
+from mindspore.dataset.transforms import Compose
+from mindspore.dataset.vision import Inter
+
+from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .mask_generator import create_mask_generator
+
+__all__ = ["create_transforms_pretrain"]
+
+
+class RandomResizedCropWithTwoResolution:
+    def __init__(self, resize_list: List, interpolations: Union[List, Tuple], scale=(0.08, 1.0), ratio=(0.75, 1.333)):
+        self.first_transform = vision.RandomResizedCrop(resize_list[0], scale, ratio, interpolations[0])
+        self.second_transform = vision.RandomResizedCrop(resize_list[1], scale, ratio, interpolations[1])
+
+    def __call__(self, img):
+        return self.first_transform(img), self.second_transform(img)
+
+
+class TransformsForPretrain:
+    def __init__(
+        self,
+        resize_list: List = [224],
+        tokenizer: str = "dall_e",
+        mask_type: str = "block_wise",
+        scale=(0.08, 1.0),
+        ratio=(0.75, 1.333),
+        hflip=0.5,
+        color_jitter=None,
+        interpolations: Union[List, Tuple] = ["bicubic", "bilinear"],  # lanczos is not implemented in MindSpore
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        patch_size: int = 16,
+        mask_ratio: float = 0.4,
+        **kwargs
+    ):
+        for i in range(len(interpolations)):
+            if hasattr(Inter, interpolations[i].upper()):
+                interpolations[i] = getattr(Inter, interpolations[i].upper())
+            else:
+                interpolations[i] = Inter.BILINEAR
+
+        if len(resize_list) == 2:
+            common_transform = [vision.Decode()]
+            if color_jitter is not None:
+                if isinstance(color_jitter, (list, tuple)):
+                    # color jitter should be a 3-tuple/list for brightness/contrast/saturation
+                    # or 4 if also augmenting hue
+                    assert len(color_jitter) in (3, 4)
+                else:
+                    color_jitter = (float(color_jitter),) * 3
+                common_transform += [vision.RandomColorAdjust(*color_jitter)]
+
+            if hflip > 0.0:
+                common_transform += [vision.RandomHorizontalFlip(prob=hflip)]
+
+            common_transform += [RandomResizedCropWithTwoResolution(resize_list, interpolations, scale, ratio)]
+            self.common_transform = Compose(common_transform)
+
+            self.patch_transform = Compose([vision.Normalize(mean=mean, std=std), vision.HWC2CHW()])
+
+            if tokenizer == "dall_e":  # beit
+                self.visual_token_transform = Compose([vision.ToTensor(), lambda x: (1 - 2 * 0.1) * x + 0.1])
+            elif tokenizer == "vqkd":  # beit v2
+                self.visual_token_transform = Compose([vision.ToTensor()])
+            elif tokenizer == "clip":  # eva, eva-02
+                self.visual_token_transform = Compose(
+                    [
+                        vision.ToTensor(),
+                        vision.Normalize(
+                            mean=(0.48145466, 0.4578275, 0.40821073),
+                            std=(0.26862954, 0.26130258, 0.27577711),
+                            is_hwc=False,
+                        ),
+                    ]
+                )
+
+            self.masked_position_generator = create_mask_generator(
+                mask_type, input_size=resize_list[0], patch_size=patch_size, mask_ratio=mask_ratio, **kwargs
+            )
+
+            self.output_columns = ["patch", "token", "mask"]
+        else:
+            self.common_transform = None
+
+            patch_transform = [
+                vision.RandomCropDecodeResize(
+                    size=resize_list[0],
scale=scale, ratio=ratio, interpolation=interpolations[0] + ) + ] + + if hflip > 0.0: + patch_transform += [vision.RandomHorizontalFlip(hflip)] + + patch_transform += [vision.Normalize(mean=mean, std=std), vision.HWC2CHW()] + self.patch_transform = Compose(patch_transform) + + self.masked_position_generator = create_mask_generator( + mask_type, input_size=resize_list[0], patch_size=patch_size, mask_ratio=mask_ratio, **kwargs + ) + + self.output_columns = ["patch", "mask"] + + def __call__(self, image): + if self.common_transform is not None: # for beit, beit v2, eva, eva-02 + patches, visual_tokens = self.common_transform(image) + patches = self.patch_transform(patches) + visual_tokens = self.visual_token_transform(visual_tokens) + masks = self.masked_position_generator() + return patches, visual_tokens, masks + else: + patches = self.patch_transform(image) # for MAE, SimMIM + masks = self.masked_position_generator() + return patches, masks + + +def create_transforms_pretrain(dataset_name="", **kwargs): + if dataset_name in ("imagenet", ""): + return TransformsForPretrain(**kwargs) + else: + raise NotImplementedError() diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py index 4a5cf887e..706fac716 100644 --- a/mindcv/models/mae.py +++ b/mindcv/models/mae.py @@ -35,6 +35,9 @@ def _cfg(url="", **kwargs): default_cfgs = { + "mae_b_16_224_pretrain": _cfg(url=""), + "mae_l_16_224_pretrain": _cfg(url=""), + "mae_h_16_224_pretrain": _cfg(url=""), "mae_b_16_224_finetune": _cfg( url="https://download.mindspore.cn/toolkits/mindcv/mae/mae_b_16_224_finetune-cc05b899.ckpt" ), diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index 5a679df72..5f90c22fb 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,11 +1,12 @@ """ViT""" +import functools from typing import Callable, Optional import numpy as np import mindspore as ms from mindspore import Parameter, Tensor, nn, ops -from mindspore.common.initializer import HeUniform, TruncatedNormal, initializer +from mindspore.common.initializer import TruncatedNormal, XavierUniform, initializer from .helpers import load_pretrained from .layers.compatibility import Dropout @@ -114,11 +115,11 @@ def construct(self, x): q, k, v = self.unstack(qkv) q, k = self.q_norm(q), self.k_norm(k) + q = self.mul(q, self.scale**0.5) + k = self.mul(k, self.scale**0.5) attn = self.q_matmul_k(q, k) - attn = self.mul(attn, self.scale) - attn = attn.astype(ms.float32) - attn = ops.softmax(attn, axis=-1) + attn = ops.softmax(attn.astype(ms.float32), axis=-1).astype(attn.dtype) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) @@ -329,10 +330,13 @@ def no_weight_decay(self): return {'pos_embed', 'cls_token'} def _init_weights(self): + w = self.patch_embed.proj.weight + w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:])) + w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape)) for _, cell in self.cells_and_names(): if isinstance(cell, nn.Dense): cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype) ) if cell.bias is not None: cell.bias.set_data( @@ -345,14 +349,6 @@ def _init_weights(self): cell.beta.set_data( initializer('zeros', cell.beta.shape, cell.beta.dtype) ) - elif isinstance(cell, nn.Conv2d): - cell.weight.set_data( - initializer(HeUniform(), cell.weight.shape, cell.weight.dtype) - ) - if cell.bias is not None: - cell.bias.set_data( - initializer("zeros", cell.bias.shape, 
cell.bias.dtype) - ) def _pos_embed(self, x): if self.dynamic_img_size or self.dynamic_img_pad: diff --git a/mindcv/optim/__init__.py b/mindcv/optim/__init__.py index 572a3c204..5a61bed91 100644 --- a/mindcv/optim/__init__.py +++ b/mindcv/optim/__init__.py @@ -1,6 +1,6 @@ """ optim init """ from . import optim_factory -from .optim_factory import create_optimizer +from .optim_factory import create_finetune_optimizer, create_optimizer, create_pretrain_optimizer __all__ = [] __all__.extend(optim_factory.__all__) diff --git a/mindcv/optim/optim_factory.py b/mindcv/optim/optim_factory.py index 7fe6bf282..754f79266 100644 --- a/mindcv/optim/optim_factory.py +++ b/mindcv/optim/optim_factory.py @@ -1,5 +1,6 @@ """ optim factory """ import os +from functools import partial from typing import Optional from mindspore import load_checkpoint, load_param_into_net, nn @@ -9,7 +10,7 @@ from .lion import Lion from .nadam import NAdam -__all__ = ["create_optimizer"] +__all__ = ["create_optimizer", "create_pretrain_optimizer", "create_finetune_optimizer"] def init_group_params(params, weight_decay): @@ -76,6 +77,219 @@ def create_optimizer( # if lr is not None: # opt_args.setdefault('lr', lr) + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_pretrain_param_groups(model, weight_decay, skip, skip_keywords): + """get pretrain param groups""" + has_decay, has_decay_name = [], [] + no_decay, no_decay_name = [], [] + + for param in model.trainable_params(): + if ( + len(param.shape) == 1 + or param.name.endswith(".bias") + or (param.name in skip) + or check_keywords_in_name(param.name, skip_keywords) + ): + no_decay.append(param) + no_decay_name.append(param.name) + else: + has_decay.append(param) + has_decay_name.append(param.name) + + return [ + {"params": has_decay, "weight_decay": weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + {"order_params": model.trainable_params()}, + ] + + +def create_pretrain_optimizer( + model, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + filter_bias_and_bn: bool = True, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, + **kwargs, +): + """build pretrain optimizer""" + + opt = opt.lower() + + skip = {} + skip_keywords = {} + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + if hasattr(model, "no_weight_decay_keywords"): + skip_keywords = model.no_weight_decay_keywords() + + params = get_pretrain_param_groups(model, weight_decay, skip, skip_keywords) + + opt_args = dict(**kwargs) + # if lr is not None: + # opt_args.setdefault('lr', lr) + + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_vit_layer(name, num_layers): + if name in ("cls_token", "mask_token", "pos_embed"): + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("rel_pos_bias"): + return num_layers - 1 + elif name.startswith("blocks"): + layer_id = int(name.split(".")[1]) + return layer_id + 1 + else: + return num_layers - 1 + + +def get_swin_layer(name, num_layers, depths): + if name in ("mask_token",): + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("layers"): + layer_id = int(name.split(".")[1]) + block_id = 
name.split(".")[3] + if block_id == "reduction" or block_id == "norm": + return sum(depths[: layer_id + 1]) + layer_id = sum(depths[:layer_id]) + int(block_id) + return layer_id + 1 + else: + return num_layers - 1 + + +def get_finetune_param_groups( + model, + lr, + weight_decay, + get_layer_func, + scales, + skip, + skip_keywords, +): + parameter_group_names = {} + parameter_group_vars = {} + + for param in model.trainable_params(): + if ( + len(param.shape) == 1 + or param.name.endswith(".bias") + or (param.name in skip) + or check_keywords_in_name(param.name, skip_keywords) + ): + group_name = "no_decay" + this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = weight_decay + if get_layer_func is not None: + layer_id = get_layer_func(param.name) + group_name = "layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if scales is not None: + scale = scales[layer_id] + else: + scale = 1.0 + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr": [learning_rate * scale for learning_rate in lr], + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr": [learning_rate * scale for learning_rate in lr], + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(param.name) + + return list(parameter_group_vars.values()) + + +def create_finetune_optimizer( + model, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + filter_bias_and_bn: bool = True, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, + scale: float = 0.75, + **kwargs, +): + if hasattr(model, "get_depths"): + depths = model.get_depths() + num_layers = model.get_num_layers() + get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) + elif hasattr(model, "get_num_layers"): + num_layers = model.get_num_layers() + get_layer_func = partial(get_vit_layer, num_layers=num_layers + 2) + else: + raise NotImplementedError() + + scales = list(scale**i for i in reversed(range(num_layers + 2))) + + skip = {} + skip_keywords = {} + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + if hasattr(model, "no_weight_decay_keywords"): + skip_keywords = model.no_weight_decay_keywords() + + params = get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip, skip_keywords) + + opt_args = dict(**kwargs) + # if lr is not None: + # opt_args.setdefault('lr', lr) + + optimizer = get_optimizer( + params, opt_args, opt, lr, weight_decay, momentum, nesterov, loss_scale, schedule_decay, checkpoint_path, eps + ) + + return optimizer + + +def get_optimizer( + params, + opt_args, + opt: str = "adam", + lr: Optional[float] = 1e-3, + weight_decay: float = 0, + momentum: float = 0.9, + nesterov: bool = False, + loss_scale: float = 1.0, + schedule_decay: float = 4e-3, + checkpoint_path: str = "", + eps: float = 1e-10, +): # non-adaptive: SGD, momentum, and nesterov if opt == "sgd": # note: nn.Momentum may perform better if momentum > 0. 
@@ -174,3 +388,11 @@ def create_optimizer( load_param_into_net(optimizer, param_dict) return optimizer + + +def check_keywords_in_name(name, keywords=()): + isin = False + for keyword in keywords: + if keyword in name: + isin = True + return isin diff --git a/mindcv/utils/callbacks.py b/mindcv/utils/callbacks.py index 5cce910e9..daf70e723 100644 --- a/mindcv/utils/callbacks.py +++ b/mindcv/utils/callbacks.py @@ -6,7 +6,7 @@ import numpy as np import mindspore as ms -from mindspore import ParameterTuple, Tensor, ops +from mindspore import ParameterTuple, Tensor, nn, ops from mindspore.train import Callback, SummaryRecord, load_param_into_net, save_checkpoint from .checkpoint_manager import CheckpointManager @@ -209,7 +209,7 @@ def on_train_epoch_end(self, run_context): self.ckpt_manager.save_ckpoint( cb_params.train_network, num_ckpt=self.ckpt_keep_max, - metric=res[0], + metric=res[0] if len(self.metric_name) > 0 else 0.0, save_path=ckpt_save_path, ) @@ -278,7 +278,7 @@ def _get_lr_from_cbp(self, cb_params): else: # if the optimizer is successfully called, the global_step will actually be the value of next step. optim_step = optimizer.global_step - 1 if optimizer.dynamic_lr: - if isinstance(optimizer.learning_rate, ms.nn.CellList): + if isinstance(optimizer.learning_rate, nn.CellList): # return the learning rates of the first parameter if dynamic_lr lr = optimizer.learning_rate[0](optim_step)[0] else: diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py index db47a48e6..b6fc8f2e8 100644 --- a/mindcv/utils/trainer_factory.py +++ b/mindcv/utils/trainer_factory.py @@ -4,7 +4,7 @@ import mindspore as ms from mindspore import Tensor, context from mindspore import dtype as mstype -from mindspore import nn +from mindspore import nn, ops from mindspore.ops import functional as F from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model @@ -88,6 +88,7 @@ def create_trainer( clip_grad: bool = False, clip_value: float = 15.0, gradient_accumulation_steps: int = 1, + tokenizer: Optional[nn.Cell] = None, ): """Create Trainer. @@ -123,11 +124,15 @@ def create_trainer( if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list): mindspore_kwargs = dict( network=network, - loss_fn=loss, + loss_fn=loss, # for MAE and SimMIM, loss is None and metric is None. 
            optimizer=optimizer,
-            metrics=metrics,
+            metrics=metrics,  # for beit, beit v2, eva and eva-02, metric is None
             amp_level=amp_level,
         )
+        if tokenizer is not None:
+            mindspore_kwargs["network"] = WithLossCellForPretrain(network, tokenizer, loss)
+            mindspore_kwargs.pop("loss_fn")
+
         if loss_scale_type.lower() == "fixed":
             mindspore_kwargs["loss_scale_manager"] = FixedLossScaleManager(
                 loss_scale=loss_scale, drop_overflow_update=drop_overflow_update
@@ -149,7 +154,14 @@
     else:  # require customized train step
         eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
         auto_mixed_precision(network, amp_level, amp_cast_list)
-        net_with_loss = add_loss_network(network, loss, amp_level)
+        if tokenizer is not None:
+            net_with_loss = WithLossCellForPretrain(
+                network, tokenizer, loss
+            )  # for beit, beit v2, eva, eva-02
+        elif loss is None:
+            net_with_loss = network  # for MAE, SimMIM
+        else:
+            net_with_loss = add_loss_network(network, loss, amp_level)
         train_step_kwargs = dict(
             network=net_with_loss,
             optimizer=optimizer,
@@ -185,3 +197,21 @@
         model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
     # todo: do we need to set model._loss_scale_manager
     return model
+
+
+class WithLossCellForPretrain(nn.WithLossCell):
+    def __init__(self, network: nn.Cell, tokenizer: nn.Cell, loss: nn.Cell):
+        super(WithLossCellForPretrain, self).__init__(network, loss)
+        self.tokenizer = tokenizer
+
+    def construct(self, x1, x2, mask):
+        bsz = x1.shape[0]
+        mask = ops.reshape(mask, (bsz, -1))
+        output = self._backbone(x1, mask)
+        output = ops.transpose(output, (0, 2, 1))
+
+        label = self.tokenizer(x2)
+        bool_mask = (1 - mask).astype(ms.bool_)
+        label = ops.masked_fill(label, bool_mask, value=-100)
+        label = F.mixed_precision_cast(mstype.float32, label)
+        return self._loss_fn(F.mixed_precision_cast(mstype.float32, output), label)
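
The layer-wise lr decay wired into create_finetune_optimizer assigns every parameter to a depth group via get_vit_layer (or get_swin_layer) and multiplies its lr by layer_decay raised to the distance from the deepest group, so the classifier head keeps the base lr while the patch embedding gets the smallest one. Below is a standalone sketch of that computation for a 12-block ViT with layer_decay=0.65, the value used in mae_b_16_224_finetune_ascend.yaml; it is illustrative only, and the parameter names are assumed rather than taken from the patch.

# Standalone sketch (not part of the diff): per-group lr scales for a 12-block ViT
# with layer_decay = 0.65, mirroring get_vit_layer() and the `scales` list built in
# create_finetune_optimizer(). Parameter names below are illustrative.

def vit_layer_id(name: str, num_groups: int) -> int:
    """Map a parameter name to its depth group (0 = embeddings, last = head/norm)."""
    if name in ("cls_token", "mask_token", "pos_embed") or name.startswith("patch_embed"):
        return 0
    if name.startswith("blocks"):
        return int(name.split(".")[1]) + 1
    return num_groups - 1  # rel_pos_bias, final norm and classifier head


num_blocks = 12
layer_decay = 0.65
num_groups = num_blocks + 2  # embeddings + transformer blocks + head
scales = [layer_decay**i for i in reversed(range(num_groups))]

for name in ("patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
             "blocks.11.mlp.fc2.weight", "head.weight"):
    gid = vit_layer_id(name, num_groups)
    print(f"{name:26s} group={gid:2d} lr_scale={scales[gid]:.4f}")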
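
For completeness, a hypothetical end-to-end use of the new pre-training data pipeline in the single-resolution MAE/SimMIM configuration (no visual tokenizer); the dataset root is a placeholder and the keyword values simply echo mae_b_16_224_pretrain_ascend.yaml.

# Hypothetical usage of the pre-training data pipeline added in this patch.
# The dataset root is a placeholder; values follow mae_b_16_224_pretrain_ascend.yaml.
from mindcv.data import create_dataset, create_loader_pretrain, create_transforms_pretrain

dataset = create_dataset(name="imagenet", root="/path/to/imagenet", split="train", shuffle=True)

transform = create_transforms_pretrain(
    dataset_name="imagenet",
    resize_list=[224],  # single resolution -> MAE/SimMIM branch, no visual tokenizer
    tokenizer=None,
    mask_type="random",
    mask_ratio=0.75,
    patch_size=16,      # ViT-B/16 patch size, so the mask grid is 14 x 14
)

loader = create_loader_pretrain(
    dataset=dataset,
    batch_size=64,
    drop_remainder=True,
    transform=transform,
    num_parallel_workers=8,
)
# Each batch yields the columns ["patch", "mask"]: a normalized CHW image and a binary 14x14 mask.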