Skip to content

Commit 5642099

Browse files
dunkeroni and lstein authored
Feat: SDXL Color Compensation (#8637)
* feat(nodes/UI): add SDXL color compensation option
* adjust value
* Better warnings on wrong VAE base model
* Restrict XL compensation to XL models

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>

* fix: BaseModelType missing import
* (chore): appease the ruff

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 382d85e commit 5642099

File tree

9 files changed

+84
-5
lines changed

9 files changed

+84
-5
lines changed

invokeai/app/invocations/image_to_latents.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import nullcontext
22
from functools import singledispatchmethod
3+
from typing import Literal
34

45
import einops
56
import torch
@@ -20,7 +21,7 @@
2021
Input,
2122
InputField,
2223
)
23-
from invokeai.app.invocations.model import VAEField
24+
from invokeai.app.invocations.model import BaseModelType, VAEField
2425
from invokeai.app.invocations.primitives import LatentsOutput
2526
from invokeai.app.services.shared.invocation_context import InvocationContext
2627
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -29,13 +30,21 @@
2930
from invokeai.backend.util.devices import TorchDevice
3031
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
3132

33+
"""
34+
SDXL VAE color compensation values determined experimentally to reduce color drift.
35+
If more reliable values are found in the future (e.g. individual color channels), they can be updated.
36+
SD1.5, TAESD, TAESDXL VAEs distort in less predictable ways, so no compensation is offered at this time.
37+
"""
38+
COMPENSATION_OPTIONS = Literal["None", "SDXL"]
39+
COLOR_COMPENSATION_MAP = {"None": [1, 0], "SDXL": [1.015, -0.002]}
40+
3241

3342
@invocation(
3443
"i2l",
3544
title="Image to Latents - SD1.5, SDXL",
3645
tags=["latents", "image", "vae", "i2l"],
3746
category="latents",
38-
version="1.1.1",
47+
version="1.2.0",
3948
)
4049
class ImageToLatentsInvocation(BaseInvocation):
4150
"""Encodes an image into latents."""
@@ -52,6 +61,10 @@ class ImageToLatentsInvocation(BaseInvocation):
5261
# offer a way to directly set None values.
5362
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
5463
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
64+
color_compensation: COMPENSATION_OPTIONS = InputField(
65+
default="None",
66+
description="Apply VAE scaling compensation when encoding images (reduces color drift).",
67+
)
5568

5669
@classmethod
5770
def vae_encode(
@@ -62,7 +75,7 @@ def vae_encode(
6275
image_tensor: torch.Tensor,
6376
tile_size: int = 0,
6477
) -> torch.Tensor:
65-
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
78+
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
6679
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
6780
operation="encode",
6881
image_tensor=image_tensor,
@@ -71,7 +84,7 @@ def vae_encode(
7184
fp32=upcast,
7285
)
7386
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
74-
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
87+
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
7588
orig_dtype = vae.dtype
7689
if upcast:
7790
vae.to(dtype=torch.float32)
@@ -127,9 +140,14 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
127140
image = context.images.get_pil(self.image.image_name)
128141

129142
vae_info = context.models.load(self.vae.vae)
130-
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
143+
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
131144

132145
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
146+
147+
if self.color_compensation != "None" and vae_info.config.base == BaseModelType.StableDiffusionXL:
148+
scale, bias = COLOR_COMPENSATION_MAP[self.color_compensation]
149+
image_tensor = image_tensor * scale + bias
150+
133151
if image_tensor.dim() == 3:
134152
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
135153

invokeai/frontend/web/public/locales/en.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,7 @@
13171317
"scheduler": "Scheduler",
13181318
"seamlessXAxis": "Seamless X Axis",
13191319
"seamlessYAxis": "Seamless Y Axis",
1320+
"colorCompensation": "Color Compensation",
13201321
"seed": "Seed",
13211322
"imageActions": "Image Actions",
13221323
"sendToCanvas": "Send To Canvas",
@@ -1860,6 +1861,10 @@
18601861
"heading": "Seamless Tiling Y Axis",
18611862
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
18621863
},
1864+
"colorCompensation": {
1865+
"heading": "Color Compensation",
1866+
"paragraphs": ["Adjust the input image to reduce color shifts during inpainting or img2img (SDXL Only)."]
1867+
},
18631868
"upscaleModel": {
18641869
"heading": "Upscale Model",
18651870
"paragraphs": [

invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ export type Feature =
6262
| 'scaleBeforeProcessing'
6363
| 'seamlessTilingXAxis'
6464
| 'seamlessTilingYAxis'
65+
| 'colorCompensation'
6566
| 'upscaleModel'
6667
| 'scale'
6768
| 'creativity'

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ const slice = createSlice({
170170
shouldUseCpuNoiseChanged: (state, action: PayloadAction<boolean>) => {
171171
state.shouldUseCpuNoise = action.payload;
172172
},
173+
setColorCompensation: (state, action: PayloadAction<boolean>) => {
174+
state.colorCompensation = action.payload;
175+
},
173176
positivePromptChanged: (state, action: PayloadAction<ParameterPositivePrompt>) => {
174177
state.positivePrompt = action.payload;
175178
},
@@ -436,6 +439,7 @@ export const {
436439
clipGEmbedModelSelected,
437440
setClipSkip,
438441
shouldUseCpuNoiseChanged,
442+
setColorCompensation,
439443
positivePromptChanged,
440444
positivePromptAddedToHistory,
441445
promptRemovedFromHistory,
@@ -557,6 +561,7 @@ export const selectShouldRandomizeSeed = createParamsSelector((params) => params
557561
export const selectVAEPrecision = createParamsSelector((params) => params.vaePrecision);
558562
export const selectIterations = createParamsSelector((params) => params.iterations);
559563
export const selectShouldUseCPUNoise = createParamsSelector((params) => params.shouldUseCpuNoise);
564+
export const selectColorCompensation = createParamsSelector((params) => params.colorCompensation);
560565

561566
export const selectUpscaleScheduler = createParamsSelector((params) => params.upscaleScheduler);
562567
export const selectUpscaleCfgScale = createParamsSelector((params) => params.upscaleCfgScale);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ export const zParamsState = z.object({
596596
seamlessYAxis: z.boolean(),
597597
clipSkip: z.number(),
598598
shouldUseCpuNoise: z.boolean(),
599+
colorCompensation: z.boolean(),
599600
positivePrompt: zParameterPositivePrompt,
600601
positivePromptHistory: zPositivePromptHistory,
601602
negativePrompt: zParameterNegativePrompt,
@@ -645,6 +646,7 @@ export const getInitialParamsState = (): ParamsState => ({
645646
seamlessYAxis: false,
646647
clipSkip: 0,
647648
shouldUseCpuNoise: true,
649+
colorCompensation: false,
648650
positivePrompt: '',
649651
positivePromptHistory: [],
650652
negativePrompt: null,

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
4545
scheduler,
4646
steps,
4747
shouldUseCpuNoise,
48+
colorCompensation,
4849
vaePrecision,
4950
vae,
5051
refinerModel,
5152
} = params;
5253

5354
const fp32 = vaePrecision === 'fp32';
55+
const compensation = colorCompensation ? 'SDXL' : 'None';
5456
const prompts = selectPresetModifiedPrompts(state);
5557

5658
const g = new Graph(getPrefixedId('sdxl_graph'));
@@ -178,6 +180,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
178180
type: 'i2l',
179181
id: getPrefixedId('i2l'),
180182
fp32,
183+
color_compensation: compensation,
181184
});
182185
canvasOutput = await addImageToImage({
183186
g,
@@ -196,6 +199,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
196199
type: 'i2l',
197200
id: getPrefixedId('i2l'),
198201
fp32,
202+
color_compensation: compensation,
199203
});
200204
canvasOutput = await addInpaint({
201205
g,
@@ -216,6 +220,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
216220
type: 'i2l',
217221
id: getPrefixedId('i2l'),
218222
fp32,
223+
color_compensation: compensation,
219224
});
220225
canvasOutput = await addOutpaint({
221226
g,

invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamColorCompensation.tsx

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
2+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
3+
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
4+
import { selectColorCompensation, setColorCompensation } from 'features/controlLayers/store/paramsSlice';
5+
import type { ChangeEvent } from 'react';
6+
import { memo, useCallback } from 'react';
7+
import { useTranslation } from 'react-i18next';
8+
9+
const ParamColorCompensation = () => {
10+
const { t } = useTranslation();
11+
const colorCompensation = useAppSelector(selectColorCompensation);
12+
13+
const dispatch = useAppDispatch();
14+
15+
const handleChange = useCallback(
16+
(e: ChangeEvent<HTMLInputElement>) => {
17+
dispatch(setColorCompensation(e.target.checked));
18+
},
19+
[dispatch]
20+
);
21+
22+
return (
23+
<FormControl>
24+
<InformationalPopover feature="colorCompensation">
25+
<FormLabel>{t('parameters.colorCompensation')}</FormLabel>
26+
</InformationalPopover>
27+
<Switch isChecked={colorCompensation} onChange={handleChange} />
28+
</FormControl>
29+
);
30+
};
31+
32+
export default memo(ParamColorCompensation);

invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip
1212
import ParamT5EncoderModelSelect from 'features/parameters/components/Advanced/ParamT5EncoderModelSelect';
1313
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
1414
import ParamSeamlessYAxis from 'features/parameters/components/Seamless/ParamSeamlessYAxis';
15+
import ParamColorCompensation from 'features/parameters/components/VAEModel/ParamColorCompensation';
1516
import ParamFLUXVAEModelSelect from 'features/parameters/components/VAEModel/ParamFLUXVAEModelSelect';
1617
import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVAEModelSelect';
1718
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
@@ -97,6 +98,9 @@ export const AdvancedSettingsAccordion = memo(() => {
9798
<ParamSeamlessYAxis />
9899
</FormControlGroup>
99100
</Flex>
101+
<FormControlGroup formLabelProps={formLabelProps}>
102+
<ParamColorCompensation />
103+
</FormControlGroup>
100104
</>
101105
)}
102106
{isFLUX && (

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11685,6 +11685,13 @@ export type components = {
1168511685
* @default false
1168611686
*/
1168711687
fp32?: boolean;
11688+
/**
11689+
* Color Compensation
11690+
* @description Apply VAE scaling compensation when encoding images (reduces color drift).
11691+
* @default None
11692+
* @enum {string}
11693+
*/
11694+
color_compensation?: "None" | "SDXL";
1168811695
/**
1168911696
* type
1169011697
* @default i2l

0 commit comments

Comments
 (0)