In my experiments (a different setup, but also using the discriminative loss), the loss computation is very slow. I suspect the bottleneck is tf.unsorted_segment_sum. Have you noticed similar slowdowns in your experiments? Is there a way to speed up the loss computation?
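For reference, the step in question computes each instance's mean embedding with a segment sum. Below is a minimal NumPy sketch of that step and the variance ("pull") term built on it. It assumes embeddings of shape [N, D], integer instance labels in [0, K), and that every instance has at least one pixel. np.add.at plays the role of tf.unsorted_segment_sum. The one-hot matrix product is a commonly tried alternative that gives the same means. Whether it is actually faster depends on the hardware and on K, so it is worth profiling both. The function names and the delta_v default are illustrative, not taken from any particular implementation.

```python
import numpy as np

def cluster_means(embeddings, labels, num_clusters):
    # Segment sum (NumPy analogue of tf.unsorted_segment_sum):
    # add each pixel embedding into the row of its instance.
    sums = np.zeros((num_clusters, embeddings.shape[1]), dtype=embeddings.dtype)
    np.add.at(sums, labels, embeddings)
    # Assumes every instance id in [0, K) occurs at least once (no zero counts).
    counts = np.bincount(labels, minlength=num_clusters).astype(embeddings.dtype)
    return sums / counts[:, None]

def cluster_means_onehot(embeddings, labels, num_clusters):
    # Same means via a dense [N, K] one-hot matrix and a single matmul.
    one_hot = np.eye(num_clusters, dtype=embeddings.dtype)[labels]
    sums = one_hot.T @ embeddings          # [K, D]
    counts = one_hot.sum(axis=0)           # [K]
    return sums / counts[:, None]

def variance_term(embeddings, labels, num_clusters, delta_v=0.5):
    # Pull each embedding toward its instance mean, with a hinge at delta_v.
    mu = cluster_means(embeddings, labels, num_clusters)
    dist = np.linalg.norm(mu[labels] - embeddings, axis=1)
    hinge = np.maximum(dist - delta_v, 0.0) ** 2
    # Average the hinge per instance, then across instances.
    per_cluster = np.zeros(num_clusters)
    np.add.at(per_cluster, labels, hinge)
    counts = np.bincount(labels, minlength=num_clusters)
    return float(np.mean(per_cluster / counts))
```

For example, with embeddings [[0, 0], [2, 0], [10, 10]] and labels [0, 0, 1], both mean functions return [[1, 0], [10, 10]], and variance_term gives 0.125.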