Skip to content

Commit 8dca8cc

Browse files
🚀 feat(model): Enable Patchcore Training Half Precision (#3031)
patchcore infer dtype from model dtype Co-authored-by: Rajesh Gangireddy <rajesh.gangireddy@intel.com>
1 parent b066588 commit 8dca8cc

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

src/anomalib/models/components/dimensionality_reduction/random_projection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,4 @@ def transform(self, embedding: torch.Tensor) -> torch.Tensor:
193193
msg = "`fit()` has not been called on SparseRandomProjection yet."
194194
raise NotFittedError(msg)
195195

196-
return embedding @ self.sparse_random_matrix.T.float()
196+
return embedding @ self.sparse_random_matrix.T.type(embedding.dtype)

src/anomalib/models/image/patchcore/torch_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
153153
... else:
154154
... assert isinstance(output, InferenceBatch)
155155
"""
156+
input_tensor = input_tensor.type(self.memory_bank.dtype)
156157
output_size = input_tensor.shape[-2:]
157158
if self.tiler:
158159
input_tensor = self.tiler.tile(input_tensor)

0 commit comments

Comments
 (0)