|
1 | 1 | import base64 |
2 | 2 | import random |
3 | 3 | from io import BytesIO |
4 | | -from typing import Optional, List, OrderedDict, Set, Union, Tuple |
| 4 | +from typing import Optional, List, Set, Union, Tuple |
5 | 5 |
|
6 | 6 | import cv2 |
7 | 7 | import numpy as np |
8 | 8 | from PIL import Image |
9 | 9 | from loguru import logger |
10 | | -from pydantic import BaseModel, Field, model_validator, field_validator |
| 10 | +from pydantic import BaseModel, Field, model_validator, field_validator, model_serializer |
11 | 11 |
|
12 | 12 | from novelai_python.sdk.ai._enum import Sampler, UCPresetTypeAlias, NoiseSchedule, ImageBytesTypeAlias, ControlNetModel, \ |
13 | 13 | Model |
@@ -115,39 +115,26 @@ class Params(BaseModel): |
115 | 115 | """ControlNet Model""" |
116 | 116 | uncond_scale: Optional[float] = Field(None, ge=0, le=1.5, multiple_of=0.05) |
117 | 117 | """Undesired Content Strength""" |
118 | | - |
119 | | - ########## Next code wants custom BaseModel |
| 118 | + |
| 119 | + # region Custom serializer |
120 | 120 | __strong_values__: Set[str] = { |
121 | 121 | "skip_cfg_above_sigma", # See models before Anime V3 |
122 | 122 | } |
123 | 123 | """Settings with none value to strongly include""" |
124 | 124 |
|
125 | | - def force_include(self, value: str): |
| 125 | + @model_serializer(mode="wrap") |
| 126 | + def _serialize(self, handler): |
126 | 127 | """ |
127 | | - Force include some params to json |
128 | | - """ |
129 | | - if value not in self.__dict__: |
130 | | - raise ValueError("Wrong value in params:", value) |
131 | | - |
132 | | - self.__strong_values__.add(value) |
133 | | - |
134 | | - # Also integrated into GenerateImageInfer |
135 | | - def model_dump(self, *args, **kwargs): |
| 128 | + Custom serializer to force include specific fields even when they are None |
136 | 129 | """ |
137 | | - Overrides model_dump for own features |
138 | | - """ |
139 | | - data = super().model_dump(*args, **kwargs) |
140 | | - ordered_fields = list(self.__fields__.keys()) |
141 | | - new_data = OrderedDict() |
142 | | - |
143 | | - for field in ordered_fields: |
144 | | - if field in data: |
145 | | - new_data[field] = data[field] |
146 | | - elif field in self.__strong_values__: |
147 | | - new_data[field] = getattr(self, field, None) |
148 | | - return new_data |
149 | | - ######### |
150 | | - |
| 130 | + data = handler(self) |
| 131 | + # Just add None values for strong fields |
| 132 | + for field in self.__strong_values__: |
| 133 | + if field not in data: |
| 134 | + data[field] = getattr(self, field, None) |
| 135 | + return data |
| 136 | + # endregion |
| 137 | + |
151 | 138 | @model_validator(mode="after") |
152 | 139 | def v_character(self): |
153 | 140 | if len(self.characterPrompts) > 6: |
|
0 commit comments