Skip to content

Commit aa2da6c

Browse files
authored
Merge pull request #30 from DeepLabCut/niels/fix_tests_update_readme
update README to add new models
2 parents 5a22726 + 90a945a commit aa2da6c

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

README.md

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ download_huggingface_model("superanimal_quadruped", model_dir)
3636
```
3737

3838
PyTorch models available for a given dataset (compatible with DeepLabCut>=3.0) can be
39-
listed using the `dlclibrary.get_available_detectors` and
40-
`dlclibrary.get_available_models` methods. Example use:
39+
listed using the `dlclibrary.get_available_detectors` and
40+
`dlclibrary.get_available_models` methods. The datasets for which models are available
41+
can be listed using `dlclibrary.get_available_datasets`. Example use:
4142

4243
```python
4344
>>> import dlclibrary
45+
>>> dlclibrary.get_available_datasets()
46+
['superanimal_bird', 'superanimal_topviewmouse', 'superanimal_quadruped']
47+
4448
>>> dlclibrary.get_available_detectors("superanimal_bird")
4549
['fasterrcnn_mobilenet_v3_large_fpn', 'ssdlite']
4650

@@ -51,6 +55,8 @@ listed using the `dlclibrary.get_available_detectors` and
5155

5256
## How to add a new model?
5357

58+
### TensorFlow models
59+
5460
Pick a good model_name. Follow the (novel) naming convention (modeltype_species), e.g. ```superanimal_topviewmouse```.
5561

5662
1. Add the model_name with path and commit ID to: https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_urls.yaml
@@ -59,3 +65,31 @@ Pick a good model_name. Follow the (novel) naming convention (modeltype_species)
5965
https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_download.py#L15
6066

6167
3. For superanimal models also fill in the configs!
68+
69+
### PyTorch models (for `deeplabcut >= 3.0.0`)
70+
71+
PyTorch models are listed in [`dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml`](
72+
https://github.com/DeepLabCut/DLClibrary/blob/main/dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml
73+
). The file is organized as:
74+
75+
```yaml
76+
my_cool_dataset: # name of the dataset used to train the model
77+
detectors:
78+
detector_name: path/to/huggingface-detector.pt # add detectors under `detector`
79+
pose_models:
80+
pose_model_name: path/to/huggingface-pose-model.pt # add pose models under `pose_models`
81+
other_pose_model_name: path/to/huggingface-other-pose-model.pt
82+
```
83+
84+
This will allow users to download the models using the format `datatsetName_modelName`,
85+
i.e. for this example 3 models would be available: `my_cool_dataset_detector_name`,
86+
`my_cool_dataset_pose_model_name` and `my_cool_dataset_other_pose_model_name`.
87+
88+
To add a new model for `deeplabcut >= 3.0.0`, simply:
89+
90+
- add a new line under detectors or pose models if the dataset is already defined
91+
- add the structure if the model was trained on a new dataset
92+
93+
The models will then be listed when calling `dlclibrary.get_available_detectors` or
94+
`dlclibrary.get_available_models`! You can list the datasets for which models are
95+
available using `dlclibrary.get_available_datasets`.

dlclibrary/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from dlclibrary.dlcmodelzoo.modelzoo_download import (
1313
download_huggingface_model,
14+
get_available_datasets,
1415
get_available_detectors,
1516
get_available_models,
1617
parse_available_supermodels,

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"mouse_pupil_vclose",
2929
"horse_sideview",
3030
"full_macaque",
31-
"superanimal_bird",
3231
"superanimal_quadruped",
3332
"superanimal_topviewmouse",
3433
]
@@ -85,6 +84,15 @@ def parse_available_supermodels():
8584
return super_animal_models
8685

8786

87+
def get_available_datasets() -> list[str]:
88+
"""Only for PyTorch models.
89+
90+
Returns:
91+
The name of datasets for which models are available
92+
"""
93+
return list(_load_pytorch_models().keys())
94+
95+
8896
def get_available_detectors(dataset: str) -> list[str]:
8997
""" Only for PyTorch models.
9098

0 commit comments

Comments
 (0)