Skip to content

Commit fe687c2

Browse files
Refactor - Move functions into utils, move sub ood methods
1 parent f97a151 commit fe687c2

File tree

6 files changed

+615
-568
lines changed

6 files changed

+615
-568
lines changed

geti_sdk/detect_ood/ood_data.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions
13+
# and limitations under the License.
14+
15+
from enum import Enum
16+
from typing import Union
17+
18+
import numpy as np
19+
20+
from geti_sdk.data_models import Prediction
21+
22+
23+
class DistributionDataItemPurpose(Enum):
24+
"""
25+
Enum to represent the purpose of the DistributionDataItem.
26+
This is used during splitting of the data into TRAIN, VAL, TEST
27+
"""
28+
29+
TRAIN = "train"
30+
VAL = "val"
31+
TEST = "test"
32+
33+
34+
class DistributionDataItem:
35+
"""
36+
A class to store the data for the COOD model.
37+
An DistributionDataItem for an image contains the following:
38+
- media_name: Name of the media (optional)
39+
- image_path: Path to the image (optional)
40+
- annotated_label: Annotated label for the image (optional)
41+
- raw_prediction: Prediction object for the image (required)
42+
- feature_vector: Feature vector extracted from the image (extracted from raw_prediction)
43+
44+
All OOD models take a list of DistributionDataItems as input for training and inference.
45+
"""
46+
47+
def __init__(
48+
self,
49+
raw_prediction: Prediction,
50+
media_name: Union[str, None],
51+
media_path: Union[str, None],
52+
annotated_label: Union[str, None],
53+
normalise_feature_vector: bool = True,
54+
purpose: Union[DistributionDataItemPurpose, None] = None,
55+
):
56+
self.media_name = media_name
57+
self.image_path = media_path
58+
self.annotated_label = annotated_label
59+
self.raw_prediction = raw_prediction
60+
self.purpose = purpose
61+
62+
feature_vector = raw_prediction.feature_vector
63+
64+
if len(feature_vector.shape) != 1:
65+
feature_vector = feature_vector.flatten()
66+
67+
if normalise_feature_vector:
68+
feature_vector = self.normalise_features(feature_vector)[0]
69+
70+
self._normalise_feature_vector = normalise_feature_vector
71+
self.feature_vector = feature_vector
72+
self.max_prediction_probability = (
73+
raw_prediction.annotations[0].labels[0].probability,
74+
)
75+
self.predicted_label = raw_prediction.annotations[0].labels[0].name
76+
77+
@property
78+
def is_feature_vector_normalised(self) -> bool:
79+
"""
80+
Return True if the feature vector is normalised.
81+
"""
82+
return self._normalise_feature_vector
83+
84+
@staticmethod
85+
def normalise_features(feature_vectors: np.ndarray) -> np.ndarray:
86+
"""
87+
Feature embeddings are normalised by dividing each feature embedding vector by its respective 2nd-order vector
88+
norm (vector Euclidean norm). It has been shown that normalising feature embeddings lead to a significant
89+
improvement in OOD detection.
90+
:param feature_vectors: Feature vectors to normalise
91+
:return: Normalised feature vectors.
92+
"""
93+
if len(feature_vectors.shape) == 1:
94+
feature_vectors = feature_vectors.reshape(1, -1)
95+
96+
return feature_vectors / (
97+
np.linalg.norm(feature_vectors, axis=1, keepdims=True) + 1e-10
98+
)
99+
100+
def __repr__(self):
101+
"""
102+
Return a string representation of the DistributionDataItem.
103+
"""
104+
return (
105+
f"DataItem(media_name={self.media_name}, "
106+
f"shape(feature_vector)={self.feature_vector.shape}), "
107+
f"feature_vector normalised={self.is_feature_vector_normalised})"
108+
)

0 commit comments

Comments
 (0)