train/dataset/crop_image.py: 21 changes (12 additions & 9 deletions)
@@ -5,6 +5,7 @@
 import json
 import cv2
 import time
+from tqdm import tqdm
 
 parse = argparse.ArgumentParser(description='Generate training data (cropped) for DCFNet_pytorch')
 parse.add_argument('-v', '--visual', dest='visual', action='store_true', help='whether visualise crop')
@@ -13,7 +14,7 @@
 
 args = parse.parse_args()
 
-print args
+print (args)
 
 
 def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)):
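The print change above (and the ones further down) is the Python 2 to 3 fix: in Python 3 print is a function, so the bare statement form no longer parses. A minimal sketch of the parser and the fixed call, runnable under Python 3 and trimmed to the one flag visible in this diff:

    import argparse

    # Trimmed-down version of the parser above; only the flag visible in the diff is kept.
    parse = argparse.ArgumentParser(description='Generate training data (cropped) for DCFNet_pytorch')
    parse.add_argument('-v', '--visual', dest='visual', action='store_true', help='whether visualise crop')

    args = parse.parse_args([])   # empty list: parse defaults instead of sys.argv
    print(args)                   # Python 3: print is a function; "print args" would be a SyntaxError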
@@ -47,12 +48,14 @@ def cxy_wh_2_bbox(cxy, wh):
 
 count = 0
 begin_time = time.time()
-for snap in snaps:
+for snap in tqdm(snaps):
     frames = snap['frame']
     n_frames = len(frames)
-    for f, frame in enumerate(frames):
+    for f, frame in enumerate(tqdm(frames,leave=False)):
         img_path = join(snap['base_path'], frame['img_path'])
         im = cv2.imread(img_path)
+        #print (img_path)
+        #print(im.shape)
         avg_chans = np.mean(im, axis=(0, 1))
         bbox = frame['obj']['bbox']
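The nested tqdm bars added in this hunk replace the manual throughput print that is commented out in the next hunk: the outer bar counts snippets, the inner bar counts frames and is cleared on completion (leave=False), and tqdm reports the iteration rate on its own. A minimal standalone sketch of the pattern, with made-up data in place of the repository's snaps list:

    import time
    from tqdm import tqdm

    # Hypothetical stand-in for the snaps structure loaded from the JSON annotations.
    snaps = [list(range(40)) for _ in range(5)]

    for snap in tqdm(snaps, desc='snippets'):                  # outer bar: one tick per snippet
        for frame in tqdm(snap, desc='frames', leave=False):   # inner bar: cleared when the snippet ends
            time.sleep(0.01)                                   # placeholder for cv2.imread + crop work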

@@ -67,17 +70,17 @@ def cxy_wh_2_bbox(cxy, wh):
         lmdb['down_index'][count] = f
         lmdb['up_index'][count] = n_frames - f
         count += 1
-        if count % 100 == 0:
-            elapsed = time.time() - begin_time
-            print("Processed {} images in {:.2f} seconds. "
-                  "{:.2f} images/second.".format(count, elapsed, count / elapsed))
+        #if count % 100 == 0:
+        #    elapsed = time.time() - begin_time
+        #    print("Processed {} images in {:.2f} seconds. "
+        #          "{:.2f} images/second.".format(count, elapsed, count / elapsed))
 
 template_id = np.where(lmdb['up_index'] > 1)[0]  # NEVER use the last frame as template! I do not like bidirectional.
 rand_split = np.random.choice(len(template_id), len(template_id))
 lmdb['train_set'] = template_id[rand_split[:(len(template_id)-num_val)]]
 lmdb['val_set'] = template_id[rand_split[(len(template_id)-num_val):]]
-print len(lmdb['train_set'])
-print len(lmdb['val_set'])
+print ("train set: %i"%len(lmdb['train_set']))
+print ("val set: %i"%len(lmdb['val_set']))
 
 # to list for json
 lmdb['train_set'] = lmdb['train_set'].tolist()
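For context on the split logic in this hunk: np.random.choice(n, n) draws n indices with replacement by default, so rand_split can contain repeats and train_set / val_set are not guaranteed to be disjoint, whereas np.random.permutation(n) would give a true shuffle. A small standalone sketch with a toy template_id array and a made-up num_val, illustrating both:

    import numpy as np

    np.random.seed(0)
    template_id = np.arange(10)   # toy stand-in for np.where(lmdb['up_index'] > 1)[0]
    num_val = 3                   # hypothetical validation-set size

    # As in the script: choice(n, n) samples WITH replacement, so indices may repeat.
    rand_split = np.random.choice(len(template_id), len(template_id))
    train_set = template_id[rand_split[:len(template_id) - num_val]]
    val_set = template_id[rand_split[len(template_id) - num_val:]]
    print("with replacement:", train_set, val_set)   # duplicates / overlap are possible

    # A disjoint split would shuffle without replacement instead:
    perm = np.random.permutation(len(template_id))
    train_set = template_id[perm[:len(template_id) - num_val]]
    val_set = template_id[perm[len(template_id) - num_val:]]
    print("permutation:", train_set, val_set)        # every index appears exactly once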