Skip to content

Commit ed54776

Browse files
authored
Merge pull request #252 from roboflow/use_best_checkpoint_for_test
load best checkpoint for test metrics
2 parents 4094f00 + 6288885 commit ed54776

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

rfdetr/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,4 @@ class TrainConfig(BaseModel):
8888
project: Optional[str] = None
8989
run: Optional[str] = None
9090
class_names: List[str] = None
91+
run_test: bool = True

rfdetr/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,14 @@ def lr_lambda(current_step: int):
471471

472472

473473
if args.run_test:
474+
best_state_dict = torch.load(output_dir / 'checkpoint_best_total.pth', map_location='cpu', weights_only=False)['model']
475+
model.load_state_dict(best_state_dict)
476+
model.eval()
477+
474478
test_stats, _ = evaluate(
475479
model, criterion, postprocessors, data_loader_test, base_ds_test, device, args=args
476480
)
481+
print(f"Test results: {test_stats}")
477482
with open(output_dir / "results.json", "r") as f:
478483
results = json.load(f)
479484
test_metrics = test_stats["results_json"]["class_map"]

0 commit comments

Comments
 (0)