11import argparse
2+ import argparse
23import json
34import logging
45import os
1011import torch
1112from sklearn .metrics import classification_report , confusion_matrix
1213from tqdm import tqdm
14+ from huggingface_hub import HfFolder
1315
1416from adaptive_classifier import AdaptiveClassifier
1517
@@ -52,6 +54,21 @@ def setup_args() -> argparse.Namespace:
5254 default = 'benchmark_results' ,
5355 help = 'Directory to save results'
5456 )
57+ parser .add_argument (
58+ '--push-to-hub' ,
59+ action = 'store_true' ,
60+ help = 'Push the trained model to HuggingFace Hub'
61+ )
62+ parser .add_argument (
63+ '--hub-repo' ,
64+ type = str ,
65+ help = 'HuggingFace Hub repository ID (e.g. "username/model-name") for pushing the model'
66+ )
67+ parser .add_argument (
68+ '--hub-token' ,
69+ type = str ,
70+ help = 'HuggingFace Hub token. If not provided, will look for the token in the environment'
71+ )
5572 return parser .parse_args ()
5673
5774def load_dataset (max_samples : int = None ) -> Tuple [datasets .Dataset , datasets .Dataset ]:
@@ -223,8 +240,47 @@ def evaluate_classifier(
223240
224241 return results
225242
226- def save_results (classifier , results : Dict [str , Any ], args : argparse .Namespace ):
227- """Save evaluation results."""
243+ def push_to_hub (
244+ classifier : AdaptiveClassifier ,
245+ repo_id : str ,
246+ token : str = None ,
247+ metrics : Dict [str , Any ] = None
248+ ) -> str :
249+ """Push the classifier to HuggingFace Hub.
250+
251+ Args:
252+ classifier: Trained classifier to push
253+ repo_id: HuggingFace Hub repository ID
254+ token: HuggingFace Hub token
255+ metrics: Optional evaluation metrics to add to model card
256+
257+ Returns:
258+ URL of the model on the Hub
259+ """
260+ logger .info (f"Pushing model to HuggingFace Hub: { repo_id } " )
261+
262+ # Set token if provided
263+ if token :
264+ HfFolder .save_token (token )
265+
266+ try :
267+ # Push to hub with evaluation results in model card
268+ url = classifier .push_to_hub (
269+ repo_id ,
270+ commit_message = "Upload from benchmark script" ,
271+ )
272+ logger .info (f"Successfully pushed model to Hub: { url } " )
273+ return url
274+ except Exception as e :
275+ logger .error (f"Error pushing to Hub: { str (e )} " )
276+ raise
277+
278+ def save_results (
279+ classifier : AdaptiveClassifier ,
280+ results : Dict [str , Any ],
281+ args : argparse .Namespace
282+ ):
283+ """Save evaluation results and optionally push to Hub."""
228284 # Create output directory
229285 os .makedirs (args .output_dir , exist_ok = True )
230286
@@ -241,15 +297,32 @@ def save_results(classifier, results: Dict[str, Any], args: argparse.Namespace):
241297 'timestamp' : timestamp
242298 }
243299
244- # Save classifier
300+ # Save classifier locally
245301 classifier .save (args .output_dir )
246302
247- # Save results
303+ # Save results locally
248304 with open (filepath , 'w' ) as f :
249305 json .dump (results , f , indent = 2 )
250306
251307 logger .info (f"Results saved to { filepath } " )
252308
309+ # Push to HuggingFace Hub if requested
310+ if args .push_to_hub :
311+ if not args .hub_repo :
312+ raise ValueError ("--hub-repo must be specified when using --push-to-hub" )
313+
314+ hub_url = push_to_hub (
315+ classifier ,
316+ args .hub_repo ,
317+ args .hub_token ,
318+ metrics = results ['metrics' ]
319+ )
320+ results ['hub_url' ] = hub_url
321+
322+ # Update saved results with hub URL
323+ with open (filepath , 'w' ) as f :
324+ json .dump (results , f , indent = 2 )
325+
253326 # Print summary to console
254327 print ("\n Evaluation Results:" )
255328 print ("-" * 50 )
@@ -267,6 +340,9 @@ def save_results(classifier, results: Dict[str, Any], args: argparse.Namespace):
267340 print (" HIGH LOW" )
268341 print (f"Actual HIGH { results ['confusion_matrix' ][0 ][0 ]:4d} { results ['confusion_matrix' ][0 ][1 ]:4d} " )
269342 print (f" LOW { results ['confusion_matrix' ][1 ][0 ]:4d} { results ['confusion_matrix' ][1 ][1 ]:4d} " )
343+
344+ if args .push_to_hub :
345+ print (f"\n Model pushed to HuggingFace Hub: { results ['hub_url' ]} " )
270346
271347def main ():
272348 """Main execution function."""
0 commit comments