
Commit 25aa6ec

ae inference align with master

ly015 authored and Ben-Louis committed
1 parent fd39228 commit 25aa6ec

5 files changed: +146 −41 lines changed

configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py

Lines changed: 11 additions & 5 deletions
@@ -96,7 +96,7 @@
         decoder=dict(codec, heatmap_size=codec['input_size'])),
     test_cfg=dict(
         multiscale_test=False,
-        flip_test=True,
+        flip_test=False,
         shift_heatmap=True,
         restore_heatmap_size=True,
         align_corners=False))
@@ -113,9 +113,14 @@
     dict(
         type='BottomupResize',
         input_size=codec['input_size'],
-        size_factor=32,
+        size_factor=64,
         resize_mode='expand'),
-    dict(type='PackPoseInputs')
+    dict(
+        type='PackPoseInputs',
+        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
+                   'img_shape', 'input_size', 'input_center', 'input_scale',
+                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
+                   'skeleton_links'))
 ]
 
 # data loaders
@@ -142,7 +147,7 @@
         type=dataset_type,
         data_root=data_root,
         data_mode=data_mode,
-        ann_file='annotations/person_keypoints_val2017.json',
+        ann_file='annotations/person_keypoints_val2017_tiny_clean.json',
         data_prefix=dict(img='val2017/'),
         test_mode=True,
         pipeline=val_pipeline,
@@ -152,7 +157,8 @@
 # evaluators
 val_evaluator = dict(
     type='CocoMetric',
-    ann_file=data_root + 'annotations/person_keypoints_val2017.json',
+    ann_file=data_root +
+    'annotations/person_keypoints_val2017_tiny_clean.json',
     nms_mode='none',
     score_mode='keypoint',
 )
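
One practical effect of raising size_factor from 32 to 64: whatever shape the 'expand' resize produces is padded up to a multiple of 64 rather than 32, so the network may see a slightly larger test-time input. A tiny sketch of that rounding (ceil_to_multiple is a hypothetical helper and the image size is made up; this is not the BottomupResize implementation):

import math


def ceil_to_multiple(size, size_factor):
    """Hypothetical helper: round a (w, h) pair up to multiples of size_factor."""
    return tuple(math.ceil(s / size_factor) * size_factor for s in size)


# e.g. a frame resized to 910x512 pixels before padding:
print(ceil_to_multiple((910, 512), 32))  # (928, 512)
print(ceil_to_multiple((910, 512), 64))  # (960, 512)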

demo/bottomup_demo.py

Lines changed: 57 additions & 12 deletions
@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+
 import mimetypes
 import os
 import time
+import os.path as osp
+import tempfile
 from argparse import ArgumentParser
 
 import cv2
@@ -208,19 +211,61 @@ def main():
         cap.release()
 
     else:
-        args.save_predictions = False
-        raise ValueError(
-            f'file {os.path.basename(args.input)} has invalid format.')
+        inputs = [osp.join(args.input, fn) for fn in os.listdir(args.input)]
 
-    if args.save_predictions:
-        with open(args.pred_save_path, 'w') as f:
-            json.dump(
-                dict(
-                    meta_info=model.dataset_meta,
-                    instance_info=pred_instances_list),
-                f,
-                indent='\t')
-        print(f'predictions have been saved at {args.pred_save_path}')
+        for fn in inputs:
+
+            input_type = mimetypes.guess_type(fn)[0].split('/')[0]
+            if input_type == 'image':
+                pred_instances = process_one_image(
+                    args, fn, model, visualizer, show_interval=0)
+                pred_instances_list = split_instances(pred_instances)
+
+            elif input_type == 'video':
+                tmp_folder = tempfile.TemporaryDirectory()
+                video = mmcv.VideoReader(fn)
+                progressbar = mmengine.ProgressBar(len(video))
+                video.cvt2frames(tmp_folder.name, show_progress=False)
+                output_root = args.output_root
+                args.output_root = tmp_folder.name
+                pred_instances_list = []
+
+                for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
+                    pred_instances = process_one_image(
+                        args,
+                        f'{tmp_folder.name}/{img_fname}',
+                        model,
+                        visualizer,
+                        show_interval=1)
+                    progressbar.update()
+                    pred_instances_list.append(
+                        dict(
+                            frame_id=frame_id,
+                            instances=split_instances(pred_instances)))
+
+                if output_root:
+                    mmcv.frames2video(
+                        tmp_folder.name,
+                        f'{output_root}/{os.path.basename(fn)}',
+                        fps=video.fps,
+                        fourcc='mp4v',
+                        show_progress=False)
+                tmp_folder.cleanup()
+
+            else:
+                args.save_predictions = False
+                raise ValueError(
+                    f'file {os.path.basename(fn)} has invalid format.')
+
+        if args.save_predictions:
+            with open(args.pred_save_path, 'w') as f:
+                json.dump(
+                    dict(
+                        meta_info=model.dataset_meta,
+                        instance_info=pred_instances_list),
+                    f,
+                    indent='\t')
+            print(f'predictions have been saved at {args.pred_save_path}')
 
 
 if __name__ == '__main__':
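
The new else branch treats args.input as a directory: each file is dispatched by its MIME type, and videos are unpacked to frames in a temporary folder before per-frame inference. A self-contained sketch of just the dispatch step (classify_inputs is a hypothetical helper; unlike the demo code it also tolerates files whose type mimetypes cannot guess):

import mimetypes
import os
import os.path as osp


def classify_inputs(path):
    """Hypothetical helper: expand `path` into (file, kind) pairs,
    where kind is 'image', 'video', or None for unrecognized files."""
    files = [osp.join(path, fn) for fn in sorted(os.listdir(path))]
    results = []
    for fn in files:
        mime = mimetypes.guess_type(fn)[0]  # e.g. 'image/jpeg', 'video/mp4'
        results.append((fn, mime.split('/')[0] if mime else None))
    return results


# Example (directory contents are illustrative):
# classify_inputs('data/demo_inputs')
# -> [('data/demo_inputs/a.jpg', 'image'), ('data/demo_inputs/b.mp4', 'video')]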

mmpose/codecs/associative_embedding.py

Lines changed: 57 additions & 21 deletions
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import namedtuple
+from copy import deepcopy
 from itertools import product
 from typing import Any, List, Optional, Tuple
 
 import numpy as np
 import torch
+from mmengine import dump
 from munkres import Munkres
 from torch import Tensor
 
@@ -75,7 +77,9 @@ def _init_group():
             tag_list=[])
         return _group
 
-    for i in keypoint_order:
+    group_history = []
+
+    for idx, i in enumerate(keypoint_order):
         # Get all valid candidate of the i-th keypoints
         valid = vals[i] > val_thr
         if not valid.any():
@@ -87,12 +91,22 @@ def _init_group():
 
         if len(groups) == 0:  # Initialize the group pool
             for tag, val, loc in zip(tags_i, vals_i, locs_i):
+
+                # Check if the keypoint belongs to existing groups
+                if len(groups):
+                    prev_tags = np.stack([g.tag_list[0] for g in groups])
+                    dists = np.linalg.norm(prev_tags - tag, ord=2, axis=1)
+                    if dists.min() < 1:
+                        continue
+
                 group = _init_group()
                 group.kpts[i] = loc
                 group.scores[i] = val
                 group.tag_list.append(tag)
 
                 groups.append(group)
+            costs_copy = None
+            matches = None
 
         else:  # Match keypoints to existing groups
             groups = groups[:max_groups]
@@ -101,17 +115,18 @@ def _init_group():
             # Calculate distance matrix between group tags and tag candidates
             # of the i-th keypoint
             # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
-            diff = tags_i[:, None] - np.array(group_tags)[None]
+            diff = (tags_i[:, None] -
+                    np.array(group_tags)[None]).astype(np.float64)
             dists = np.linalg.norm(diff, ord=2, axis=2)
             num_kpts, num_groups = dists.shape[:2]
 
-            # Experimental cost function for keypoint-group matching
+            # Experimental cost function for keypoint-group matching2
             costs = np.round(dists) * 100 - vals_i[..., None]
+
             if num_kpts > num_groups:
-                padding = np.full((num_kpts, num_kpts - num_groups),
-                                  1e10,
-                                  dtype=np.float32)
+                padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
                 costs = np.concatenate((costs, padding), axis=1)
+                costs_copy = costs.copy()
 
             # Match keypoints and groups by Munkres algorithm
             matches = munkres.compute(costs)
@@ -121,13 +136,30 @@ def _init_group():
                     # Add the keypoint to the matched group
                     group = groups[group_idx]
                 else:
-                    # Initialize a new group with unmatched keypoint
-                    group = _init_group()
-                    groups.append(group)
-
-                group.kpts[i] = locs_i[kpt_idx]
-                group.scores[i] = vals_i[kpt_idx]
-                group.tag_list.append(tags_i[kpt_idx])
+                    # if dists[kpt_idx].min() < 0.2:
+                    if False:
+                        group = None
+                    else:
+                        # Initialize a new group with unmatched keypoint
+                        group = _init_group()
+                        groups.append(group)
+                if group is not None:
+                    group.kpts[i] = locs_i[kpt_idx]
+                    group.scores[i] = vals_i[kpt_idx]
+                    group.tag_list.append(tags_i[kpt_idx])
+
+        out = {
+            'idx': idx,
+            'i': i,
+            'costs': costs_copy,
+            'matches': matches,
+            'kpts': np.array([g.kpts for g in groups]),
+            'scores': np.array([g.scores for g in groups]),
+            'tag_list': [np.array(g.tag_list) for g in groups],
+        }
+        group_history.append(deepcopy(out))
+
+    dump(group_history, 'group_history.pkl')
 
     groups = groups[:max_groups]
     if groups:
@@ -210,7 +242,7 @@ def __init__(
         decode_gaussian_kernel: int = 3,
        decode_keypoint_thr: float = 0.1,
        decode_tag_thr: float = 1.0,
-        decode_topk: int = 20,
+        decode_topk: int = 30,
        decode_max_instances: Optional[int] = None,
    ) -> None:
        super().__init__()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
         B, K, H, W = batch_heatmaps.shape
         L = batch_tags.shape[1] // K
 
+        # Heatmap NMS
+        dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
+        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
+                                           self.decode_nms_kernel)
+        dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
+
         # shape of topk_val, top_indices: (B, K, TopK)
         topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
             k, dim=-1)
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
                 cost_map = np.round(dist_map) * 100 - heatmaps[k]  # H, W
                 y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
                 keypoints[n, k] = [x, y]
-                keypoint_scores[n, k] = heatmaps[k, y, x]
 
-        return keypoints, keypoint_scores
+        return keypoints
 
     def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                      ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                 batch, each is in shape (N, K). It usually represents the
                 confidience of the keypoint prediction
         """
+
         B, _, H, W = batch_heatmaps.shape
         assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
             f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
             f'tagging map ({batch_tags.shape})')
 
-        # Heatmap NMS
-        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
-                                           self.decode_nms_kernel)
-
         # Get top-k in each heatmap and and convert to numpy
         batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
             self._get_batch_topk(
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
 
             if keypoints.size > 0:
                 # identify missing keypoints
-                keypoints, scores = self._fill_missing_keypoints(
+                keypoints = self._fill_missing_keypoints(
                     keypoints, scores, heatmaps, tags)
 
                 # refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,8 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                         blur_kernel_size=self.decode_gaussian_kernel)
                 else:
                     keypoints = refine_keypoints(keypoints, heatmaps)
+            # keypoints += 0.75
+            keypoints += 0.5
 
             batch_keypoints.append(keypoints)
             batch_keypoint_scores.append(scores)
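
The debug dump records, per keypoint index, the cost matrix and the Munkres matches used to assign tag candidates to groups; the cost is round(dist) * 100 - val, padded to a square matrix with a large constant when there are more candidates than groups. A standalone toy sketch of that matching step (match_candidates_to_groups is a hypothetical name; 1-D tags for readability):

import numpy as np
from munkres import Munkres


def match_candidates_to_groups(tags_i, vals_i, group_tags):
    """Hypothetical sketch of the tag-matching step: build the cost
    round(dist) * 100 - val between candidates and existing groups,
    pad to a square matrix, and solve the assignment with Munkres."""
    diff = (tags_i[:, None] - group_tags[None]).astype(np.float64)  # (M, G, L)
    dists = np.linalg.norm(diff, ord=2, axis=2)                     # (M, G)
    costs = np.round(dists) * 100 - vals_i[..., None]
    num_kpts, num_groups = costs.shape
    if num_kpts > num_groups:  # pad so every candidate can be assigned
        padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
        costs = np.concatenate((costs, padding), axis=1)
    matches = Munkres().compute(costs.tolist())
    return matches, num_groups


# Two candidates, one existing group; candidate 1 lands in a padded
# column (index >= num_groups), i.e. it would start a new group.
tags_i = np.array([[0.1], [2.3]])
vals_i = np.array([0.9, 0.8])
group_tags = np.array([[0.0]])
print(match_candidates_to_groups(tags_i, vals_i, group_tags))
# -> ([(0, 0), (1, 1)], 1)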

mmpose/datasets/transforms/bottomup_transforms.py

Lines changed: 10 additions & 2 deletions
@@ -484,6 +484,7 @@ def transform(self, results: Dict) -> Optional[dict]:
                     output_size=actual_input_size)
             else:
                 center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
+                center = np.round(center)
                 scale = np.array([
                     img_w * padded_input_size[0] / actual_input_size[0],
                     img_h * padded_input_size[1] / actual_input_size[1]
@@ -495,11 +496,18 @@
                 rot=0,
                 output_size=padded_input_size)
 
-            _img = cv2.warpAffine(
-                img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR)
+            _img = cv2.warpAffine(img, warp_mat, padded_input_size)
 
             imgs.append(_img)
 
+            # print('#' * 20)
+            # print('w,h: ', img_w, img_h, 'center: ', center, 'scale: ',
+            #       scale,
+            #       'actual_input_size: ', actual_input_size,
+            #       'padded_input_size: ', padded_input_size)
+            # print(warp_mat)
+            # print('#' * 20)
+
             # Store the transform information w.r.t. the main input size
             if i == 0:
                 results['img_shape'] = padded_input_size[::-1]
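
Two details worth noting: the crop center is now rounded to whole pixels, and cv2.warpAffine is called without an explicit flags argument, which falls back to cv2's default bilinear interpolation (cv2.INTER_LINEAR), so the interpolation itself is unchanged. Below is a rough, self-contained sketch of a center/scale-style warp (center_scale_warp is a hypothetical helper, not mmpose's warp-matrix code):

import cv2
import numpy as np


def center_scale_warp(img, center, scale, output_size):
    """Hypothetical sketch: map a (scale[0] x scale[1]) box centred at
    `center` onto `output_size` and warp with cv2's default interpolation."""
    dst_w, dst_h = output_size
    # three corresponding points: center, mid-right edge, mid-top edge
    src = np.array([center,
                    center + [scale[0] / 2, 0],
                    center + [0, -scale[1] / 2]], dtype=np.float32)
    dst = np.array([[dst_w / 2, dst_h / 2],
                    [dst_w, dst_h / 2],
                    [dst_w / 2, 0]], dtype=np.float32)
    warp_mat = cv2.getAffineTransform(src, dst)
    return cv2.warpAffine(img, warp_mat, output_size)  # default: INTER_LINEAR


img = np.zeros((360, 640, 3), dtype=np.uint8)
center = np.round(np.array([640 / 2, 360 / 2], dtype=np.float32))  # rounded center
out = center_scale_warp(img, center, (960.0, 512.0), (960, 512))
print(out.shape)  # (512, 960, 3)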

mmpose/models/heads/heatmap_heads/ae_head.py

Lines changed: 11 additions & 1 deletion
@@ -2,6 +2,7 @@
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from mmengine.structures import PixelData
 from mmengine.utils import is_list_of
 from torch import Tensor
@@ -110,7 +111,7 @@ def predict(self,
             # TTA: multi-scale test
             assert is_list_of(feats, list if flip_test else tuple)
         else:
-            assert is_list_of(feats, tuple if flip_test else Tensor)
+            assert isinstance(feats, list if flip_test else tuple)
             feats = [feats]
 
         # resize heatmaps to align with with input size
@@ -129,6 +130,15 @@ def predict(self,
         for scale_idx, _feats in enumerate(feats):
             if not flip_test:
                 _heatmaps, _tags = self.forward(_feats)
+                if heatmap_size:
+                    _heatmaps = F.interpolate(
+                        _heatmaps, (img_h, img_w),
+                        mode='bilinear',
+                        align_corners=align_corners)
+                    _tags = F.interpolate(
+                        _tags, (img_h, img_w),
+                        mode='bilinear',
+                        align_corners=align_corners)
 
             else:
                 # TTA: flip test
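
In the non-flip branch the raw head outputs are now upsampled to the input resolution with torch.nn.functional.interpolate before decoding. A minimal standalone sketch with made-up shapes (17 COCO keypoints, one embedding dimension, a 512x512 input):

import torch
import torch.nn.functional as F

# Made-up shapes: (B, K, h, w) heatmaps and (B, K * L, h, w) tags with L = 1.
heatmaps = torch.rand(1, 17, 128, 128)
tags = torch.rand(1, 17, 128, 128)
img_h, img_w = 512, 512  # assumed network input size

heatmaps = F.interpolate(
    heatmaps, (img_h, img_w), mode='bilinear', align_corners=False)
tags = F.interpolate(
    tags, (img_h, img_w), mode='bilinear', align_corners=False)
print(heatmaps.shape, tags.shape)  # torch.Size([1, 17, 512, 512]) twice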
