Skip to content

Commit bb1d3a2

Browse files
authored
Merge pull request #14 from codelion/fix-model-card
Fix model card
2 parents d54865a + d2d2d9b commit bb1d3a2

File tree

2 files changed

+81
-5
lines changed

2 files changed

+81
-5
lines changed

scripts/eval_llmrouter_classifier.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import argparse
23
import json
34
import logging
45
import os
@@ -10,6 +11,7 @@
1011
import torch
1112
from sklearn.metrics import classification_report, confusion_matrix
1213
from tqdm import tqdm
14+
from huggingface_hub import HfFolder
1315

1416
from adaptive_classifier import AdaptiveClassifier
1517

@@ -52,6 +54,21 @@ def setup_args() -> argparse.Namespace:
5254
default='benchmark_results',
5355
help='Directory to save results'
5456
)
57+
parser.add_argument(
58+
'--push-to-hub',
59+
action='store_true',
60+
help='Push the trained model to HuggingFace Hub'
61+
)
62+
parser.add_argument(
63+
'--hub-repo',
64+
type=str,
65+
help='HuggingFace Hub repository ID (e.g. "username/model-name") for pushing the model'
66+
)
67+
parser.add_argument(
68+
'--hub-token',
69+
type=str,
70+
help='HuggingFace Hub token. If not provided, will look for the token in the environment'
71+
)
5572
return parser.parse_args()
5673

5774
def load_dataset(max_samples: int = None) -> Tuple[datasets.Dataset, datasets.Dataset]:
@@ -223,8 +240,47 @@ def evaluate_classifier(
223240

224241
return results
225242

226-
def save_results(classifier, results: Dict[str, Any], args: argparse.Namespace):
227-
"""Save evaluation results."""
243+
def push_to_hub(
    classifier: AdaptiveClassifier,
    repo_id: str,
    token: str = None,
    metrics: Dict[str, Any] = None
) -> str:
    """Push the classifier to HuggingFace Hub.

    Args:
        classifier: Trained classifier to push
        repo_id: HuggingFace Hub repository ID (e.g. "username/model-name")
        token: HuggingFace Hub token. When given, it is passed to this upload
            only and is not written to the local credential store. When
            omitted, the upload relies on whatever credentials
            ``classifier.push_to_hub`` resolves on its own.
        metrics: Optional evaluation metrics. Accepted for interface
            compatibility; it is currently NOT forwarded to the upload call,
            so it does not affect the generated model card.

    Returns:
        URL of the model on the Hub, as returned by ``classifier.push_to_hub``

    Raises:
        ValueError: If ``repo_id`` is empty.
        Exception: Any error raised by the upload is logged and re-raised.
    """
    # Fail early with a clear message instead of a confusing Hub-side error.
    if not repo_id:
        raise ValueError("repo_id must be a non-empty HuggingFace Hub repository ID")

    logger.info(f"Pushing model to HuggingFace Hub: {repo_id}")

    push_kwargs = {"commit_message": "Upload from benchmark script"}
    if token:
        # Hand the token to this single call rather than persisting it with
        # HfFolder.save_token(), which would overwrite the user's stored login.
        # NOTE(review): assumes classifier.push_to_hub accepts a `token`
        # keyword (as huggingface_hub's ModelHubMixin.push_to_hub does) -- confirm.
        push_kwargs["token"] = token

    try:
        url = classifier.push_to_hub(repo_id, **push_kwargs)
    except Exception as exc:
        # logger.exception keeps the traceback; the error is still propagated.
        logger.exception("Error pushing to Hub: %s", exc)
        raise

    logger.info(f"Successfully pushed model to Hub: {url}")
    return url
277+
278+
def save_results(
279+
classifier: AdaptiveClassifier,
280+
results: Dict[str, Any],
281+
args: argparse.Namespace
282+
):
283+
"""Save evaluation results and optionally push to Hub."""
228284
# Create output directory
229285
os.makedirs(args.output_dir, exist_ok=True)
230286

@@ -241,15 +297,32 @@ def save_results(classifier, results: Dict[str, Any], args: argparse.Namespace):
241297
'timestamp': timestamp
242298
}
243299

244-
# Save classifier
300+
# Save classifier locally
245301
classifier.save(args.output_dir)
246302

247-
# Save results
303+
# Save results locally
248304
with open(filepath, 'w') as f:
249305
json.dump(results, f, indent=2)
250306

251307
logger.info(f"Results saved to {filepath}")
252308

309+
# Push to HuggingFace Hub if requested
310+
if args.push_to_hub:
311+
if not args.hub_repo:
312+
raise ValueError("--hub-repo must be specified when using --push-to-hub")
313+
314+
hub_url = push_to_hub(
315+
classifier,
316+
args.hub_repo,
317+
args.hub_token,
318+
metrics=results['metrics']
319+
)
320+
results['hub_url'] = hub_url
321+
322+
# Update saved results with hub URL
323+
with open(filepath, 'w') as f:
324+
json.dump(results, f, indent=2)
325+
253326
# Print summary to console
254327
print("\nEvaluation Results:")
255328
print("-" * 50)
@@ -267,6 +340,9 @@ def save_results(classifier, results: Dict[str, Any], args: argparse.Namespace):
267340
print(" HIGH LOW")
268341
print(f"Actual HIGH {results['confusion_matrix'][0][0]:4d} {results['confusion_matrix'][0][1]:4d}")
269342
print(f" LOW {results['confusion_matrix'][1][0]:4d} {results['confusion_matrix'][1][1]:4d}")
343+
344+
if args.push_to_hub:
345+
print(f"\nModel pushed to HuggingFace Hub: {results['hub_url']}")
270346

271347
def main():
272348
"""Main execution function."""

src/adaptive_classifier/classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def _generate_model_card(self) -> str:
482482
from adaptive_classifier import AdaptiveClassifier
483483
484484
# Load the model
485-
classifier = AdaptiveClassifier.from_pretrained("{self.model.config._name_or_path}")
485+
classifier = AdaptiveClassifier.from_pretrained("adaptive-classifier/model-name")
486486
487487
# Make predictions
488488
text = "Your text here"

0 commit comments

Comments
 (0)