Add configs to run int4 inference #37
```diff
@@ -44,7 +44,7 @@
 parser = ArgumentParser()
 
 parser.add_argument("--name", required=True, type=str, help="model_name")
-parser.add_argument("--dtype", type=str, help="float16 or int8", choices=["int8", "float16"], default="float16")
+parser.add_argument("--dtype", type=str, help="float16 or int8 or int4", choices=["int8", "float16", "int4"], default="float16")
 parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
 parser.add_argument("--batch_size", default=1, type=int, help="batch size")
 parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
@@ -100,7 +100,7 @@ def get_checkpoint_files(model_name_or_path):
 
 model_name = args.name
-infer_dtype = args.dtype
+infer_dtype = args.dtype if args.dtype != 'int4' else 'int8'
 
 tp_presharded_mode = True if model_name in tp_presharded_models else False
 
@@ -191,6 +191,7 @@ def write_checkponts_json():
     mp_size=world_size,
     base_dir=repo_root,
     dtype=getattr(torch, infer_dtype),
+    quantization_bits=8 if args.dtype == 'int8' else 4,
     checkpoint=checkpoints_json,
     **kwargs,
 )
@@ -227,7 +228,7 @@ def write_checkponts_json():
 # dynamically extend to support larger bs by repetition
 input_sentences *= math.ceil(args.batch_size / len(input_sentences))
 
-generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
+generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=True)
 
 print_rank0(f"Generate args {generate_kwargs}")
```
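Taken together, the hunks implement the following int4 path. This is a condensed sketch of the diff above, not extra code from the PR; `args`, `world_size`, `repo_root`, `checkpoints_json`, and `kwargs` are names defined elsewhere in the script.

```python
import deepspeed
import torch

# --dtype now accepts "int4", but PyTorch has no torch.int4 dtype, so the
# script falls back to torch.int8 for init_inference's dtype argument and
# signals 4-bit quantization separately via quantization_bits.
infer_dtype = args.dtype if args.dtype != 'int4' else 'int8'

model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    base_dir=repo_root,
    dtype=getattr(torch, infer_dtype),  # torch.int8 when --dtype=int4
    quantization_bits=8 if args.dtype == 'int8' else 4,  # as in the diff
    checkpoint=checkpoints_json,
    **kwargs,
)
```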
---
would it make for a more user-friendly API to keep `dtype` intact and, instead of adding `quantization_bits`, have `deepspeed.init_inference` derive the number of bits from `dtype`?

not only is the currently suggested override confusing, I fail to see what purpose is served by carrying the same information twice, in `dtype` and `quantization_bits`
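A minimal sketch of that suggestion; `bits_from_dtype` is a hypothetical helper name, not part of DeepSpeed or this PR:

```python
def bits_from_dtype(dtype: str) -> int:
    """Derive quantization bits from the dtype string so the information
    lives in one place, e.g. 'int8' -> 8, 'int4' -> 4."""
    if dtype.startswith("int"):
        return int(dtype[len("int"):])
    return 16  # float16: quantization bits not applicable
```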
---
oh, wait, `torch.int4` still doesn't exist, does it? let's find the feature request.
---
still not implemented: pytorch/pytorch#74627

so that's why you had to do the odd workarounds, right?
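For reference, this is the gap the workaround papers over; a quick check against a current PyTorch build:

```python
import torch

# torch.float16 and torch.int8 exist, but torch.int4 does not
# (see pytorch/pytorch#74627), which is why the script maps
# 'int4' -> 'int8' before calling getattr(torch, infer_dtype).
try:
    getattr(torch, "int4")
except AttributeError:
    print("no torch.int4 dtype yet")
```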
---
I guess we can drop it once it's implemented, @stas00?

For now, this might be the best way to do it.
---
see #37 (comment)
---
it's pointless to wait, since they won't have `int3` and `int12`
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stas00 and @RezaYazdaniAminabadi - just clarifying that we have introduced a new DeepSpeedInferenceConfig that can be passed to init_inference. We are keeping it backwards compatible, but if we are okay with making changes to this file, I would advocate for writing a config dictionary for DeepSpeed and passing that to init_inference instead of the various kwargs. Please see here for an example: https://gist.github.com/awan-10/6e3d5c756be3a876522e860c6bbf702d#file-bloom-ds-inference-py-L173

Also, see the docs for the new config: https://deepspeed.readthedocs.io/en/latest/inference-init.html
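Roughly, the config-dict style looks like the sketch below. The key names follow the linked docs but may vary across DeepSpeed releases, so treat them as assumptions and check the docs for your version; `world_size` and `checkpoints_json` come from the script.

```python
import deepspeed

# Sketch only: a DeepSpeedInferenceConfig-style dict replacing the kwargs.
ds_config = {
    "tensor_parallel": {"tp_size": world_size},  # was mp_size=...
    "dtype": "fp16",                             # or torch.float16
    "checkpoint": checkpoints_json,
    "replace_with_kernel_inject": True,
}

model = deepspeed.init_inference(model, config=ds_config)
```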
---
That definitely works.

@awan-10, may I suggest you make the inference `config` accept `dict_or_path`, just like `zero` does? it might be easier for some users to write out a separate file.
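In code, the `dict_or_path` behaviour could look like this hypothetical helper (names are illustrative, not DeepSpeed API):

```python
import json

def load_inference_config(dict_or_path):
    """Accept either an in-memory dict or a path to a JSON file,
    mirroring how the ZeRO config is handled."""
    if isinstance(dict_or_path, dict):
        return dict_or_path
    with open(dict_or_path) as f:
        return json.load(f)
```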
---
@stas00 - thanks for the suggestion. Created an issue so we can track it: deepspeedai/DeepSpeed#2532. Mike and I will work on it.
---
Thank you very much, @awan-10