diff --git a/evaluation/generation/generate.py b/evaluation/generation/generate.py index fde023bf..96ba9266 100644 --- a/evaluation/generation/generate.py +++ b/evaluation/generation/generate.py @@ -13,6 +13,8 @@ def get_args(): parser.add_argument("--greedy", action="store_true") parser.add_argument("--top-k", type=int, default=0) parser.add_argument("--offload_folder", type=str, help="offload folder for accelerate", default="./offload") + parser.add_argument("--max_memory", type=str, help="max memory per GPU", default="30GB") + parser.add_argument("--max_cpu_memory", type=str, help="max memory on CPU", default="300GB") return parser.parse_args() @@ -40,9 +42,12 @@ def main(): args.checkpoint, device_map="auto" if args.parallelize else None, torch_dtype=torch.bfloat16, - revision="gs{}".format(args.global_step) if args.global_step else None - offload_folder=args.offload_folder is args.parallelize else None, + revision="gs{}".format(args.global_step) if args.global_step else None, + max_memory=args.max_memory if args.parallelize else None, + max_cpu_memory=args.max_cpu_memory if args.parallelize else None, + offload_folder=args.offload_folder if args.parallelize else None, ) + print(f"Loaded model in {datetime.datetime.now() - start}") text = ''