Skip to content

Commit 00daf5e

Browse files
authored
Merge pull request #105 from roboflow/add-class-names
Add class names to args in checkpoint
2 parents 57e0f64 + 8bb8e61 commit 00daf5e

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

rfdetr/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,4 @@ class TrainConfig(BaseModel):
7979
wandb: bool = False
8080
project: Optional[str] = None
8181
run: Optional[str] = None
82+
class_names: List[str] = None

rfdetr/detr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def train_from_config(self, config: TrainConfig, **kwargs):
4444
) as f:
4545
anns = json.load(f)
4646
num_classes = len(anns["categories"])
47+
class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"]
4748

4849
if self.model_config.num_classes != num_classes:
4950
logger.warning(
@@ -52,9 +53,16 @@ def train_from_config(self, config: TrainConfig, **kwargs):
5253
)
5354
self.model.reinitialize_detection_head(num_classes)
5455

56+
5557
train_config = config.dict()
5658
model_config = self.model_config.dict()
5759
model_config.pop("num_classes")
60+
if "class_names" in model_config:
61+
model_config.pop("class_names")
62+
63+
if "class_names" in train_config and train_config["class_names"] is None:
64+
train_config["class_names"] = class_names
65+
5866
for k, v in train_config.items():
5967
if k in model_config:
6068
model_config.pop(k)

0 commit comments

Comments
 (0)