diff --git a/eend/pytorch_backend/infer.py b/eend/pytorch_backend/infer.py index d7d694a..d737398 100755 --- a/eend/pytorch_backend/infer.py +++ b/eend/pytorch_backend/infer.py @@ -280,13 +280,14 @@ def infer(args): n_samples = n_chunks * args.num_speakers - len(sil_lst) min_n_samples = 2 - if cls_num is not None: + if cls_num is not None and cls_num > min_n_samples: min_n_samples = cls_num if n_samples >= min_n_samples: # clustering (if cls_num is None, update cls_num) - clslab, cls_num =\ - clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst) + clslab, cls_num = clustering( + args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst + ) # merge acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num) # stitching