1+ import  argparse 
2+ import  json 
13import  os 
4+ import  random 
5+ import  re 
6+ 
27import  numpy  as  np 
38import  torch 
4- import  os 
5- import  re 
6- import  json 
7- import  argparse 
8- import  random 
9- from  transformers  import  T5Tokenizer , DataCollatorForSeq2Seq , Seq2SeqTrainingArguments , Seq2SeqTrainer , T5ForConditionalGeneration 
10- from  model  import  T5ForConditionalGeneration , T5ForMultimodalGeneration 
11- from  utils_data  import  img_shape , load_data_std , load_data_img , ScienceQADatasetStd , ScienceQADatasetImg 
12- from  utils_prompt  import  * 
13- from  utils_evaluate  import  get_scores 
14- from  rich .table  import  Column , Table 
159from  rich  import  box 
1610from  rich .console  import  Console 
11+ from  rich .table  import  Column , Table 
12+ from  transformers  import  (DataCollatorForSeq2Seq , Seq2SeqTrainer , Seq2SeqTrainingArguments , T5ForConditionalGeneration , T5Tokenizer )
13+ 
14+ from  model  import  T5ForConditionalGeneration , T5ForMultimodalGeneration 
15+ from  utils_data  import  (ScienceQADatasetImg , ScienceQADatasetStd , img_shape , load_data_img , load_data_std )
16+ from  utils_evaluate  import  get_scores 
17+ from  utils_prompt  import  * 
18+ 
1719console  =  Console (record = True )
18- from  torch  import  cuda 
19- import  nltk 
2020import  evaluate 
21+ import  nltk 
22+ from  torch  import  cuda 
2123
2224
2325def  parse_args ():
@@ -36,7 +38,7 @@ def parse_args():
3638    parser .add_argument ('--train_split' , type = str , default = 'train' , choices = ['train' , 'trainval' , 'minitrain' ])
3739    parser .add_argument ('--val_split' , type = str , default = 'val' , choices = ['test' , 'val' , 'minival' ])
3840    parser .add_argument ('--test_split' , type = str , default = 'test' , choices = ['test' , 'minitest' ])
39-      
41+ 
4042    parser .add_argument ('--use_generate' , action = 'store_true' , help = 'only for baseline to improve inference speed' )
4143    parser .add_argument ('--final_eval' , action = 'store_true' , help = 'only evaluate the model at the final epoch' )
4244    parser .add_argument ('--user_msg' , type = str , default = "baseline" , help = 'experiment type in the save_dir' )
@@ -50,16 +52,15 @@ def parse_args():
5052                        choices = ['QCM-A' , 'QCM-LE' , 'QCMG-A' , 'QCM-LEA' , 'QCM-ALE' ])
5153    parser .add_argument ('--seed' , type = int , default = 42 , help = 'random seed' )
5254
53-     args  =  parser .parse_args ()
54-     return  args 
55+     return  parser .parse_args ()
5556
5657def  T5Trainer (
5758    dataframe , args ,
5859):
5960    torch .manual_seed (args .seed )  # pytorch random seed 
6061    np .random .seed (args .seed )  # numpy random seed 
6162    torch .backends .cudnn .deterministic  =  True 
62-      
63+ 
6364    if  args .evaluate_dir  is  not   None :
6465        args .model  =  args .evaluate_dir 
6566
@@ -72,7 +73,7 @@ def T5Trainer(
7273    train_qids  =  qids ['train' ]
7374    test_qids  =  qids ['test' ]
7475    val_qids  =  qids ['val' ]
75-      
76+ 
7677    if  args .evaluate_dir  is  not   None :
7778        save_dir  =  args .evaluate_dir 
7879    else :
@@ -139,7 +140,7 @@ def T5Trainer(
139140            args ,
140141            args .eval_le ,
141142        )
142-          
143+ 
143144        test_set  =  ScienceQADatasetStd (
144145            problems ,
145146            test_qids ,
@@ -155,11 +156,8 @@ def T5Trainer(
def extract_ans(ans):
    """Extract the predicted answer letter from generated text.

    Scans *ans* for the phrase ``The answer is (X)`` where X is an
    uppercase letter. Returns that letter (e.g. ``'A'``) only when the
    phrase occurs exactly once; zero or multiple occurrences yield the
    sentinel ``"FAILED"`` so the caller can fall back to a random option.
    """
    matches = re.findall(r'The answer is \(([A-Z])\)', ans)
    if len(matches) == 1:
        return matches[0]
    return "FAILED"
164162
165163    # accuracy for answer inference 
@@ -184,7 +182,7 @@ def compute_metrics_acc(eval_preds):
184182            if  reference  ==  best_option :
185183                correct  += 1  
186184        return  {'accuracy' : 1.0 * correct / len (targets )}
187-      
185+ 
188186    # rougel for rationale generation 
189187    metric  =  evaluate .load ("rouge" )
190188    def  postprocess_text (preds , labels ):
@@ -218,13 +216,13 @@ def compute_metrics_rougel(eval_preds):
218216    if  args .final_eval :
219217        training_args  =  Seq2SeqTrainingArguments (
220218            save_dir ,
221-             do_train = True   if   args .evaluate_dir  is  None   else   False ,
219+             do_train = args .evaluate_dir  is  None ,
222220            do_eval = False ,
223221            evaluation_strategy = "no" ,
224222            logging_strategy = "steps" ,
225223            save_strategy = "epoch" ,
226-             save_total_limit   =   2 ,
227-             learning_rate =   args .lr ,
224+             save_total_limit = 2 ,
225+             learning_rate = args .lr ,
228226            eval_accumulation_steps = args .eval_acc ,
229227            per_device_train_batch_size = args .bs ,
230228            per_device_eval_batch_size = args .eval_bs ,
@@ -233,23 +231,24 @@ def compute_metrics_rougel(eval_preds):
233231            predict_with_generate = args .use_generate ,
234232            report_to = "none" ,
235233        )
236-     # evaluate at each epoch 
237234    else :
238235        training_args  =  Seq2SeqTrainingArguments (
239236            save_dir ,
240-             do_train = True   if   args .evaluate_dir  is  None   else   False ,
237+             do_train = args .evaluate_dir  is  None ,
241238            do_eval = True ,
242239            evaluation_strategy = "epoch" ,
243240            logging_strategy = "steps" ,
244241            save_strategy = "epoch" ,
245-             save_total_limit   =   2 ,
246-             learning_rate =   args .lr ,
242+             save_total_limit = 2 ,
243+             learning_rate = args .lr ,
247244            eval_accumulation_steps = args .eval_acc ,
248245            per_device_train_batch_size = args .bs ,
249246            per_device_eval_batch_size = args .eval_bs ,
250247            weight_decay = 0.01 ,
251248            num_train_epochs = args .epoch ,
252-             metric_for_best_model = "accuracy"  if  args .prompt_format  !=  "QCM-LE"  else  "rougeL" ,
249+             metric_for_best_model = "accuracy" 
250+             if  args .prompt_format  !=  "QCM-LE" 
251+             else  "rougeL" ,
253252            predict_with_generate = args .use_generate ,
254253            load_best_model_at_end = True ,
255254            report_to = "none" ,
@@ -268,12 +267,12 @@ def compute_metrics_rougel(eval_preds):
268267    if  args .evaluate_dir  is  None :
269268        trainer .train ()
270269        trainer .save_model (save_dir )
271-          
270+ 
272271    metrics  =  trainer .evaluate (eval_dataset  =  test_set )
273272    trainer .log_metrics ("test" , metrics )
274273    trainer .save_metrics ("test" , metrics )
275274
276-     predict_results  =  trainer .predict (test_dataset = test_set , max_length = args .output_len )  
275+     predict_results  =  trainer .predict (test_dataset = test_set , max_length = args .output_len )
277276    if  trainer .is_world_process_zero ():
278277        if  args .use_generate :
279278            preds , targets  =  predict_results .predictions , predict_results .label_ids 
@@ -292,7 +291,7 @@ def compute_metrics_rougel(eval_preds):
292291        results_ans  =  {}
293292        results_rationale  =  {}
294293        results_reference  =  {}
295-          
294+ 
296295        num_fail  =  0 
297296        for  idx , qid  in  enumerate (test_qids ):
298297            pred  =  preds [int (idx )]
@@ -302,7 +301,7 @@ def compute_metrics_rougel(eval_preds):
302301                if  extract_pred  in  args .options :
303302                    extract_pred  =  args .options .index (extract_pred )
304303                else :
305-                     extract_pred  =  random .choice (range (0 , len (args .options )))
304+                     extract_pred  =  random .choice (range (len (args .options )))
306305            else :
307306                num_fail  +=  1 
308307                extract_pred  =  random .choice (range (len (args .options ))) # random choose one option 
@@ -320,7 +319,7 @@ def compute_metrics_rougel(eval_preds):
320319        output_prediction_file  =  os .path .join (save_dir ,"predictions_ans_test.json" )
321320        with  open (output_prediction_file , "w" ) as  writer :
322321            writer .write (json .dumps (output_data , indent = 4 ))
323-      
322+ 
324323    # generate the rationale for the eval set 
325324    if  args .prompt_format  ==  "QCM-LE" :
326325        torch .cuda .empty_cache ()
0 commit comments