Skip to content

Commit f283457

Browse files
author
Laurent Erignoux
committed
Making the resolutions parameters optional to avoid distortion when using images which size is not multiple of 64
1 parent f1a73cd commit f283457

File tree

15 files changed

+45
-30
lines changed

15 files changed

+45
-30
lines changed

src/controlnet_aux/canny/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, dete
2020
output_type = output_type or "np"
2121

2222
input_image = HWC3(input_image)
23-
input_image = resize_image(input_image, detect_resolution)
23+
if detect_resolution is not None:
24+
input_image = resize_image(input_image, detect_resolution)
2425

2526
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
2627
detected_map = HWC3(detected_map)
2728

28-
img = resize_image(input_image, image_resolution)
29+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
2930
H, W, C = img.shape
3031

3132
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/dwpose/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
4545
input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
4646

4747
input_image = HWC3(input_image)
48-
input_image = resize_image(input_image, detect_resolution)
48+
if detect_resolution is not None:
49+
input_image = resize_image(input_image, detect_resolution)
4950
H, W, C = input_image.shape
5051

5152
with torch.no_grad():
@@ -80,7 +81,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
8081
detected_map = draw_pose(pose, H, W)
8182
detected_map = HWC3(detected_map)
8283

83-
img = resize_image(input_image, image_resolution)
84+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
8485
H, W, C = img.shape
8586

8687
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/hed/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
9393
input_image = np.array(input_image, dtype=np.uint8)
9494

9595
input_image = HWC3(input_image)
96-
input_image = resize_image(input_image, detect_resolution)
96+
if detect_resolution is not None:
97+
input_image = resize_image(input_image, detect_resolution)
9798

9899
assert input_image.ndim == 3
99100
H, W, C = input_image.shape
@@ -112,7 +113,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, saf
112113
detected_map = edge
113114
detected_map = HWC3(detected_map)
114115

115-
img = resize_image(input_image, image_resolution)
116+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
116117
H, W, C = img.shape
117118

118119
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/leres/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution
6262
input_image = np.array(input_image, dtype=np.uint8)
6363

6464
input_image = HWC3(input_image)
65-
input_image = resize_image(input_image, detect_resolution)
65+
if detect_resolution is not None:
66+
input_image = resize_image(input_image, detect_resolution)
6667

6768
assert input_image.ndim == 3
6869
height, width, dim = input_image.shape
@@ -107,7 +108,7 @@ def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution
107108
detected_map = depth_image
108109
detected_map = HWC3(detected_map)
109110

110-
img = resize_image(input_image, image_resolution)
111+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
111112
H, W, C = img.shape
112113

113114
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/lineart/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def __call__(self, input_image, coarse=False, detect_resolution=512, image_resol
137137
input_image = np.array(input_image, dtype=np.uint8)
138138

139139
input_image = HWC3(input_image)
140-
input_image = resize_image(input_image, detect_resolution)
140+
if detect_resolution is not None:
141+
input_image = resize_image(input_image, detect_resolution)
141142

142143
model = self.model_coarse if coarse else self.model
143144
assert input_image.ndim == 3
@@ -155,7 +156,7 @@ def __call__(self, input_image, coarse=False, detect_resolution=512, image_resol
155156

156157
detected_map = HWC3(detected_map)
157158

158-
img = resize_image(input_image, image_resolution)
159+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
159160
H, W, C = img.shape
160161

161162
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/lineart_anime/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
156156
input_image = np.array(input_image, dtype=np.uint8)
157157

158158
input_image = HWC3(input_image)
159-
input_image = resize_image(input_image, detect_resolution)
159+
if detect_resolution is not None:
160+
input_image = resize_image(input_image, detect_resolution)
160161

161162
H, W, C = input_image.shape
162163
Hn = 256 * int(np.ceil(float(H) / 256.0))
@@ -177,7 +178,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
177178

178179
detected_map = HWC3(detected_map)
179180

180-
img = resize_image(input_image, image_resolution)
181+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
181182
H, W, C = img.shape
182183

183184
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/mediapipe_face/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ def __call__(self,
3737
input_image = np.array(input_image, dtype=np.uint8)
3838

3939
input_image = HWC3(input_image)
40-
input_image = resize_image(input_image, detect_resolution)
40+
if detect_resolution is not None:
41+
input_image = resize_image(input_image, detect_resolution)
4142

4243
detected_map = generate_annotation(input_image, max_faces, min_confidence)
4344
detected_map = HWC3(detected_map)
4445

45-
img = resize_image(input_image, image_resolution)
46+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
4647
H, W, C = img.shape
4748

4849
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/midas/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False
4545
output_type = output_type or "np"
4646

4747
input_image = HWC3(input_image)
48-
input_image = resize_image(input_image, detect_resolution)
48+
if detect_resolution is not None:
49+
input_image = resize_image(input_image, detect_resolution)
4950

5051
assert input_image.ndim == 3
5152
image_depth = input_image
@@ -77,7 +78,7 @@ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False
7778
if depth_and_normal:
7879
normal_image = HWC3(normal_image)
7980

80-
img = resize_image(input_image, image_resolution)
81+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
8182
H, W, C = img.shape
8283

8384
depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/mlsd/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, ima
5151
input_image = np.array(input_image, dtype=np.uint8)
5252

5353
input_image = HWC3(input_image)
54-
input_image = resize_image(input_image, detect_resolution)
54+
if detect_resolution is not None:
55+
input_image = resize_image(input_image, detect_resolution)
5556

5657
assert input_image.ndim == 3
5758
img = input_image
@@ -68,7 +69,7 @@ def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, ima
6869
detected_map = img_output[:, :, 0]
6970
detected_map = HWC3(detected_map)
7071

71-
img = resize_image(input_image, image_resolution)
72+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
7273
H, W, C = img.shape
7374

7475
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

src/controlnet_aux/normalbae/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
7474
input_image = np.array(input_image, dtype=np.uint8)
7575

7676
input_image = HWC3(input_image)
77-
input_image = resize_image(input_image, detect_resolution)
77+
if detect_resolution is not None:
78+
input_image = resize_image(input_image, detect_resolution)
7879

7980
assert input_image.ndim == 3
8081
image_normal = input_image
@@ -97,7 +98,7 @@ def __call__(self, input_image, detect_resolution=512, image_resolution=512, out
9798
detected_map = normal_image
9899
detected_map = HWC3(detected_map)
99100

100-
img = resize_image(input_image, image_resolution)
101+
img = resize_image(input_image, image_resolution) if image_resolution is not None else input_image
101102
H, W, C = img.shape
102103

103104
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

0 commit comments

Comments
 (0)