Skip to content

Commit 8efb8ca

Browse files
committed
🔧 refactor(params): streamline custom serializer logic in params.py
Simplify custom serialization by using `model_serializer` to include specific fields even when they are `None`.
1 parent e02d2c7 commit 8efb8ca

File tree

1 file changed

+15
-28
lines changed
  • src/novelai_python/sdk/ai/generate_image

1 file changed

+15
-28
lines changed

‎src/novelai_python/sdk/ai/generate_image/params.py‎

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import base64
22
import random
33
from io import BytesIO
4-
from typing import Optional, List, OrderedDict, Set, Union, Tuple
4+
from typing import Optional, List, Set, Union, Tuple
55

66
import cv2
77
import numpy as np
88
from PIL import Image
99
from loguru import logger
10-
from pydantic import BaseModel, Field, model_validator, field_validator
10+
from pydantic import BaseModel, Field, model_validator, field_validator, model_serializer
1111

1212
from novelai_python.sdk.ai._enum import Sampler, UCPresetTypeAlias, NoiseSchedule, ImageBytesTypeAlias, ControlNetModel, \
1313
Model
@@ -115,39 +115,26 @@ class Params(BaseModel):
115115
"""ControlNet Model"""
116116
uncond_scale: Optional[float] = Field(None, ge=0, le=1.5, multiple_of=0.05)
117117
"""Undesired Content Strength"""
118-
119-
########## Next code wants custom BaseModel
118+
119+
# region Custom serializer
120120
__strong_values__: Set[str] = {
121121
"skip_cfg_above_sigma", # See models before Anime V3
122122
}
123123
"""Settings with none value to strongly include"""
124124

125-
def force_include(self, value: str):
125+
@model_serializer(mode="wrap")
126+
def _serialize(self, handler):
126127
"""
127-
Force include some params to json
128-
"""
129-
if value not in self.__dict__:
130-
raise ValueError("Wrong value in params:", value)
131-
132-
self.__strong_values__.add(value)
133-
134-
# Also integrated into GenerateImageInfer
135-
def model_dump(self, *args, **kwargs):
128+
Custom serializer to force include specific fields even when they are None
136129
"""
137-
Overrides model_dump for own features
138-
"""
139-
data = super().model_dump(*args, **kwargs)
140-
ordered_fields = list(self.__fields__.keys())
141-
new_data = OrderedDict()
142-
143-
for field in ordered_fields:
144-
if field in data:
145-
new_data[field] = data[field]
146-
elif field in self.__strong_values__:
147-
new_data[field] = getattr(self, field, None)
148-
return new_data
149-
#########
150-
130+
data = handler(self)
131+
# Just add None values for strong fields
132+
for field in self.__strong_values__:
133+
if field not in data:
134+
data[field] = getattr(self, field, None)
135+
return data
136+
# endregion
137+
151138
@model_validator(mode="after")
152139
def v_character(self):
153140
if len(self.characterPrompts) > 6:

0 commit comments

Comments
 (0)