@@ -85,9 +85,41 @@ def neg_pairs_from_tuple(indices_tuple):
8585
8686
8787def get_all_triplets_indices (labels , ref_labels = None ):
88- matches , diffs = get_matches_and_diffs (labels , ref_labels )
89- triplets = matches .unsqueeze (2 ) * diffs .unsqueeze (1 )
90- return torch .where (triplets )
88+ all_matches , all_diffs = get_matches_and_diffs (labels , ref_labels )
89+
90+ if (all_matches .shape [0 ] * all_matches .shape [1 ] * all_matches .shape [1 ]
91+ < torch .iinfo (torch .int32 ).max ):
92+ # torch.nonzero is not supported for tensors with more than INT_MAX elements
93+ triplets = all_matches .unsqueeze (2 ) * all_diffs .unsqueeze (1 )
94+ return torch .where (triplets )
95+
96+ all_matches , all_diffs = all_matches .bool (), all_diffs .bool ()
97+
98+ # Find anchors with at least a positive and a negative
99+ indices = torch .arange (0 , len (labels ), device = labels .device )
100+ indices = indices [all_matches .any (dim = 1 ) & all_diffs .any (dim = 1 )]
101+
102+ # No triplets found
103+ if len (indices ) == 0 :
104+ return (torch .tensor ([], device = labels .device , dtype = labels .dtype ),
105+ torch .tensor ([], device = labels .device , dtype = labels .dtype ),
106+ torch .tensor ([], device = labels .device , dtype = labels .dtype ))
107+
108+ # Compute all triplets
109+ anchors = []
110+ positives = []
111+ negatives = []
112+ for i in indices :
113+ matches = all_matches [i ].nonzero (as_tuple = False ).squeeze (1 )
114+ diffs = all_diffs [i ].nonzero (as_tuple = False ).squeeze (1 )
115+ nd = len (diffs )
116+ nm = len (matches )
117+ matches = matches .repeat_interleave (nd )
118+ diffs = diffs .repeat (nm )
119+ anchors .append (torch .full ((len (matches ),), i , dtype = labels .dtype , device = labels .device ))
120+ positives .append (matches )
121+ negatives .append (diffs )
122+ return torch .cat (anchors ), torch .cat (positives ), torch .cat (negatives )
91123
92124
93125# sample triplets, with a weighted distribution if weights is specified.
0 commit comments