@@ -64,23 +64,21 @@
 from datetime import datetime
 import gc
 import json
+import os
 import random
 import time
 from typing import Any, AsyncGenerator, Optional
-import os
-
 
+from benchmarks.eval_accuracy import eval_accuracy
+from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
-from benchmarks.metrics import EventMetric, CounterMetric
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 from jetstream.engine.token_utils import load_vocab
 from jetstream.external_tokenizers.llama3 import llama3_tokenizer
 import numpy as np
-from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 import pandas
-
-from eval_accuracy import eval_accuracy
+from tqdm.asyncio import tqdm  # pytype: disable=pyi-error
 from transformers import AutoTokenizer
 
 
@@ -706,136 +704,7 @@ def sample_warmup_requests(requests):
         break
 
 
-def main(args: argparse.Namespace):
-  print(args)
-  random.seed(args.seed)
-  np.random.seed(args.seed)
-
-  model_id = args.model
-  tokenizer_id = args.tokenizer
-  use_hf_tokenizer = args.use_hf_tokenizer
-
-  prefill_quota = AsyncCounter(init_value=3)
-  active_req_quota = AsyncCounter(init_value=450)
-
-  api_url = f"{args.server}:{args.port}"
-
-  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
-  if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(
-        args.total_mock_requests
-    )  # e.g. [("AB", 2, "AB", 3)]
-  else:
-    dataset = []
-    if args.dataset == "openorca":
-      dataset = load_openorca_dataset_pkl(args.dataset_path)
-    elif args.dataset == "sharegpt":
-      dataset = load_sharegpt_dataset(
-          args.dataset_path,
-          args.conversation_starter,
-      )
-
-    # A given args.max_output_length value is the max generation step;
-    # when args.max_output_length is left at its default of None, the
-    # sample's golden output length decides the generation step.
-    input_requests = sample_requests(
-        dataset=dataset,
-        tokenizer=tokenizer,
-        num_requests=args.num_prompts,
-        max_output_length=args.max_output_length,
-    )
-
-  warmup_requests = None
-  if args.warmup_mode == "full":
-    warmup_requests = input_requests
-  elif args.warmup_mode == "sampled":
-    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
-
-  if warmup_requests:
-    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
-    _, _ = asyncio.run(
-        benchmark(
-            api_url=api_url,
-            tokenizer=tokenizer,
-            input_requests=warmup_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
-            prefill_quota=prefill_quota,
-            active_req_quota=active_req_quota,
-            is_warmup=True,
-        )
-    )
-    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
-
-    # TODO: Replace this with a warmup-complete signal once supported.
-    # Wait for the server to fully warm up before running the benchmark.
-    time.sleep(5)
-
-  benchmark_result, request_outputs = asyncio.run(
-      benchmark(
-          api_url=api_url,
-          tokenizer=tokenizer,
-          input_requests=input_requests,
-          request_rate=args.request_rate,
-          disable_tqdm=args.disable_tqdm,
-          prefill_quota=prefill_quota,
-          active_req_quota=active_req_quota,
-      )
-  )
-
-  # Process output
-  output = [output.to_dict() for output in request_outputs]
-  if args.run_eval:
-    eval_json = eval_accuracy(output)
-
-  # Save config and results to json
-  if args.save_result:
-    # dimensions values are strings
-    dimensions_json = {}
-    # metrics values are numerical
-    metrics_json = {}
-
-    # Setup
-    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
-    dimensions_json["date"] = current_dt
-    dimensions_json["model_id"] = model_id
-    dimensions_json["tokenizer_id"] = tokenizer_id
-    if args.additional_metadata_metrics_to_save is not None:
-      dimensions_json = {
-          **dimensions_json,
-          **json.loads(args.additional_metadata_metrics_to_save),
-      }
-    metrics_json["num_prompts"] = args.num_prompts
-
-    # Traffic
-    metrics_json["request_rate"] = args.request_rate
-    metrics_json = {**metrics_json, **benchmark_result}
-    if args.run_eval:
-      metrics_json = {**metrics_json, **eval_json}
-
-    final_json = {}
-    final_json["metrics"] = metrics_json
-    final_json["dimensions"] = dimensions_json
-
-    # Save to file
-    base_model_id = model_id.split("/")[-1]
-    file_name = (
-        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
-    )
-    with open(file_name, "w", encoding="utf-8") as outfile:
-      json.dump(final_json, outfile)
-
-  if args.save_request_outputs:
-    file_path = args.request_outputs_file_path
-    with open(file_path, "w", encoding="utf-8") as output_file:
-      json.dump(
-          output,
-          output_file,
-          indent=4,
-      )
-
-
-if __name__ == "__main__":
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(
       description="Benchmark the online serving throughput."
   )
@@ -909,7 +778,6 @@ def main(args: argparse.Namespace):
       default=150,
       help="The maximum number of mock requests to send for benchmark testing.",
   )
-
   parser.add_argument(
       "--max-output-length",
       type=int,
@@ -926,7 +794,6 @@ def main(args: argparse.Namespace):
926794 "the output length of the golden dataset would be passed."
927795 ),
928796 )
929-
930797 parser .add_argument ("--seed" , type = int , default = 0 )
931798 parser .add_argument (
932799 "--disable-tqdm" ,
@@ -977,7 +844,138 @@ def main(args: argparse.Namespace):
       choices=["human", "gpt", "both"],
       help="What entity should be the one starting the conversations.",
   )
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  print(args)
+  random.seed(args.seed)
+  np.random.seed(args.seed)
+
+  model_id = args.model
+  tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
+
+  prefill_quota = AsyncCounter(init_value=3)
+  active_req_quota = AsyncCounter(init_value=450)
+
+  api_url = f"{args.server}:{args.port}"
+
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
+  if tokenizer == "test" or args.dataset == "test":
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
+  else:
+    dataset = []
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset_pkl(args.dataset_path)
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          args.conversation_starter,
+      )
+
+    # A given args.max_output_length value is the max generation step;
+    # when args.max_output_length is left at its default of None, the
+    # sample's golden output length decides the generation step.
+    input_requests = sample_requests(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length,
+    )
+
+  warmup_requests = None
+  if args.warmup_mode == "full":
+    warmup_requests = input_requests
+  elif args.warmup_mode == "sampled":
+    warmup_requests = list(sample_warmup_requests(input_requests)) * 2
+
+  if warmup_requests:
+    print(f"Warmup (mode: {args.warmup_mode}) is starting.")
+    _, _ = asyncio.run(
+        benchmark(
+            api_url=api_url,
+            tokenizer=tokenizer,
+            input_requests=warmup_requests,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+            prefill_quota=prefill_quota,
+            active_req_quota=active_req_quota,
+            is_warmup=True,
+        )
+    )
+    print(f"Warmup (mode: {args.warmup_mode}) has completed.")
+
+    # TODO: Replace this with a warmup-complete signal once supported.
+    # Wait for the server to fully warm up before running the benchmark.
+    time.sleep(5)
+
+  benchmark_result, request_outputs = asyncio.run(
+      benchmark(
+          api_url=api_url,
+          tokenizer=tokenizer,
+          input_requests=input_requests,
+          request_rate=args.request_rate,
+          disable_tqdm=args.disable_tqdm,
+          prefill_quota=prefill_quota,
+          active_req_quota=active_req_quota,
+      )
+  )
+
+  # Process output
+  output = [output.to_dict() for output in request_outputs]
+  if args.run_eval:
+    eval_json = eval_accuracy(output)
+
+  # Save config and results to json
+  if args.save_result:
+    # dimensions values are strings
+    dimensions_json = {}
+    # metrics values are numerical
+    metrics_json = {}
 
-  parsed_args = parser.parse_args()
+    # Setup
+    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+    dimensions_json["date"] = current_dt
+    dimensions_json["model_id"] = model_id
+    dimensions_json["tokenizer_id"] = tokenizer_id
+    if args.additional_metadata_metrics_to_save is not None:
+      dimensions_json = {
+          **dimensions_json,
+          **json.loads(args.additional_metadata_metrics_to_save),
+      }
+    metrics_json["num_prompts"] = args.num_prompts
+
+    # Traffic
+    metrics_json["request_rate"] = args.request_rate
+    metrics_json = {**metrics_json, **benchmark_result}
+    if args.run_eval:
+      metrics_json = {**metrics_json, **eval_json}
+
+    final_json = {}
+    final_json["metrics"] = metrics_json
+    final_json["dimensions"] = dimensions_json
+
+    # Save to file
+    base_model_id = model_id.split("/")[-1]
+    file_name = (
+        f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+    )
+    with open(file_name, "w", encoding="utf-8") as outfile:
+      json.dump(final_json, outfile)
+
+  if args.save_request_outputs:
+    file_path = args.request_outputs_file_path
+    with open(file_path, "w", encoding="utf-8") as output_file:
+      json.dump(
+          output,
+          output_file,
+          indent=4,
+      )
+
+
+if __name__ == "__main__":
   gc.disable()
-  main(parsed_args)
+  main(parse_args())
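
Note: the `AsyncCounter` used for `prefill_quota` and `active_req_quota` is defined elsewhere in this file and is untouched by this commit. As a mental model only, a quota counter with that constructor could look like the sketch below; the `inc`/`dec` method names and the blocking behavior are assumptions, not the repository's actual implementation.

```python
import asyncio


class AsyncCounter:
  """Sketch of an asyncio-friendly quota counter (assumed semantics)."""

  def __init__(self, init_value: int):
    self._value = init_value
    self._cond = asyncio.Condition()

  async def dec(self):
    # Wait until quota is available, then consume one unit.
    async with self._cond:
      await self._cond.wait_for(lambda: self._value > 0)
      self._value -= 1

  async def inc(self):
    # Return one unit of quota and wake a waiting task.
    async with self._cond:
      self._value += 1
      self._cond.notify()
```

Under these semantics, `AsyncCounter(init_value=3)` would cap concurrent prefills at three and `AsyncCounter(init_value=450)` would cap in-flight requests at 450, with each benchmark task acquiring quota before sending a request and releasing it when the response completes.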
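For reference, when result saving is enabled (`args.save_result`), the code above writes a two-section JSON file named `JetStream-{request_rate}qps-{model}-{timestamp}.json`. The structure follows directly from `final_json`; the values below are illustrative placeholders, and `metrics` additionally absorbs everything in `benchmark_result` (plus the `eval_accuracy` output when `args.run_eval` is set):

```json
{
  "metrics": {
    "num_prompts": 1000,
    "request_rate": 5.0
  },
  "dimensions": {
    "date": "20240102-030405",
    "model_id": "example-org/example-model",
    "tokenizer_id": "example-org/example-model"
  }
}
```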