@@ -85,9 +85,34 @@ def neg_pairs_from_tuple(indices_tuple):
8585
8686
8787def get_all_triplets_indices (labels , ref_labels = None ):
88- matches , diffs = get_matches_and_diffs (labels , ref_labels )
89- triplets = matches .unsqueeze (2 ) * diffs .unsqueeze (1 )
90- return torch .where (triplets )
88+ all_matches , all_diffs = get_matches_and_diffs (labels , ref_labels )
89+ all_matches , all_diffs = all_matches .bool (), all_diffs .bool ()
90+
91+ # Find anchors with at least a positive and a negative
92+ indices = torch .arange (0 , len (labels ), device = labels .device )
93+ indices = indices [all_matches .any (dim = 1 ) & all_diffs .any (dim = 1 )]
94+
95+ # No triplets found
96+ if len (indices ) == 0 :
97+ return (torch .tensor ([], device = labels .device , dtype = labels .dtype ),
98+ torch .tensor ([], device = labels .device , dtype = labels .dtype ),
99+ torch .tensor ([], device = labels .device , dtype = labels .dtype ))
100+
101+ # Compute all triplets
102+ anchors = []
103+ positives = []
104+ negatives = []
105+ for i in indices :
106+ matches = all_matches [i ].nonzero (as_tuple = False ).squeeze (1 )
107+ diffs = all_diffs [i ].nonzero (as_tuple = False ).squeeze (1 )
108+ nd = len (diffs )
109+ nm = len (matches )
110+ matches = matches .repeat_interleave (nd )
111+ diffs = diffs .repeat (nm )
112+ anchors .append (torch .full ((len (matches ),), i , dtype = labels .dtype , device = labels .device ))
113+ positives .append (matches )
114+ negatives .append (diffs )
115+ return torch .cat (anchors ), torch .cat (positives ), torch .cat (negatives )
91116
92117
93118# sample triplets, with a weighted distribution if weights is specified.
0 commit comments