From 5e9b02b69b303d4eb4a93ab28c4b1dbfc16f2115 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 6 Aug 2024 11:47:53 +0900 Subject: [PATCH] add new taggers for BikiniPlusMetric --- sdeval/controllability/bikini_plus.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sdeval/controllability/bikini_plus.py b/sdeval/controllability/bikini_plus.py index 492d835..e7ba6af 100644 --- a/sdeval/controllability/bikini_plus.py +++ b/sdeval/controllability/bikini_plus.py @@ -254,6 +254,11 @@ def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general "wd14_convnextv2": "wd-v1-4-convnextv2-tagger-v2", "wd14_vit": "wd-v1-4-vit-tagger-v2", "wd14_moat": "wd-v1-4-moat-tagger-v2", + "wd14_swinv3": 'wd-v-1-4-swinv2-tagger-v3', + "wd14_convnextv3": 'wd-v-1-4-convnext-tagger-v3', + "wd14_vitv3": 'wd-v-1-4-vit-tagger-v3', + "wd14_eva02_large": 'wd-v-1-4-eva02-large-tagger-v1', + "wd14_vit_largev3": 'wd-v-1-4-vit-large-tagger-v1', } _TAGGING_METHODS = { 'deepdanbooru': _deepdanbooru_tagging, @@ -263,11 +268,17 @@ def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general 'wd14_swinv2': partial(_wd14_tagging, model_name='SwinV2'), 'wd14_moat': partial(_wd14_tagging, model_name='MOAT'), 'mldanbooru': _mldanbooru_tagging, + 'wd14_swinv3': partial(_wd14_tagging, model_name='SwinV2_v3'), + 'wd14_convnextv3': partial(_wd14_tagging, model_name='ConvNext_v3'), + 'wd14_vitv3': partial(_wd14_tagging, model_name='ViT_v3'), + 'wd14_eva02_large': partial(_wd14_tagging, model_name='EVA02_Large'), + 'wd14_vit_largev3': partial(_wd14_tagging, model_name='ViT_Large'), } TaggingMethodTyping = Literal[ 'deepdanbooru', 'mldanbooru', 'wd14_vit', 'wd14_convnext', 'wd14_convnextv2', 'wd14_swinv2', 'wd14_moat', + 'wd14_vitv3', 'wd14_convnextv3', 'wd14_swinv3', 'wd14_eva02_large', 'wd14_vit_largev3' ] PromptedImageTyping = Union[ @@ -339,7 +350,7 @@ class BikiniPlusMetrics: :type silent: bool """ - def __init__(self, tagger: TaggingMethodTyping = 'wd14_convnextv2', + def __init__(self, tagger: TaggingMethodTyping = 'wd14_convnextv3', tagger_cfgs: Optional[dict] = None, base_num: float = 1.5, tag_blacklist: Optional[List[str]] = None, silent: bool = False): self.tagger = tagger