Skip to content

Commit a7a981e

Browse files
committed
new POI class
Signed-off-by: Orit Davidovich <orit.davidovich@protonmail.com>
1 parent 4448d22 commit a7a981e

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

doframework/core/poi.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
#
2+
# Copyright IBM Corporation 2022
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from typing import List
18+
119
import numpy as np
220
import pandas as pd
3-
from scipy.stats import norm
421
from scipy.stats import multivariate_normal
522
from GPy.models import GPRegression
623

7-
from dataclasses import dataclass
8-
from typing import List
9-
1024
from doframework.core.utils import order_stats
1125

1226
def plot_joint_distribution(samples: np.array, **kwargs):
@@ -33,23 +47,37 @@ def plot_joint_distribution(samples: np.array, **kwargs):
3347
sns.jointplot(data=df, x=cols[0], y=cols[1], kind="hex", xlim=lims, ylim=lims)
3448
sns.lineplot(data=dl, x=cols[0], y=cols[1])
3549

36-
@dataclass
37-
class POI:
50+
class POI(object):
3851
'''
3952
Class for probability of improvement outcomes.
4053
'''
54+
55+
def __init__(self, point: np.array, probability: float, **kwargs):
4156

42-
solution: np.array
43-
reference: np.array
44-
probability: float
45-
is_minimum: bool
57+
self.point = point
58+
assert all([probability>=0.0,probability<=1.0]), f'Probability value should be in [0,1]. Received {probability:.2f}.'
59+
self.probability = probability
60+
61+
self.upper_bound = kwargs['upper_bound'] if 'upper_bound' in kwargs else True
62+
self.reference = kwargs['reference'] if 'reference' in kwargs else np.array([])
63+
self.threshold = kwargs['threshold'] if 'threshold' in kwargs else None
64+
65+
def __repr__(self):
66+
return 'POI('+''.join([f'point={self.point},',
67+
f' probability={self.probability},',
68+
f' upper_bound={self.upper_bound}',
69+
','*any([self.reference.size > 0]),
70+
f' reference={self.reference}'*(self.reference.size > 0),
71+
','*any([self.threshold is not None]),
72+
f' threshold={self.threshold}'*(self.threshold is not None)])+')'
4673

4774
def probability_of_improvement(solutions: np.array, references: np.array, model: GPRegression,
48-
sample_num: int=100000, is_constraint: bool=False, is_minimum: bool=True, plot_joint_gaussians: bool=False,
75+
sample_num: int=100000, is_constraint: bool=False, upper_bound: bool=True, plot_joint_gaussians: bool=False,
4976
**kwargs) -> List[POI]:
5077

5178
sols = np.atleast_2d(solutions)
5279
d = sols.shape[-1]
80+
is_minimum = not upper_bound
5381

5482
if is_constraint:
5583

@@ -81,8 +109,11 @@ def probability_of_improvement(solutions: np.array, references: np.array, model:
81109
else:
82110
mu, cov = model.predict(np.vstack([sols_rep[i],refs_rep[i]]),full_cov=True)
83111
samples = multivariate_normal(mean=mu.flatten(),cov=cov).rvs(size=sample_num)
84-
85-
pois.append(POI(sols_rep[i],refs_rep[i],order_stats(samples,is_minimum),is_minimum))
112+
113+
if is_constraint:
114+
pois.append(POI(sols_rep[i],order_stats(samples,is_minimum),upper_bound=upper_bound,threshold=refs_rep[i]))
115+
else:
116+
pois.append(POI(sols_rep[i],order_stats(samples,is_minimum),upper_bound=upper_bound,reference=refs_rep[i]))
86117

87118
if plot_joint_gaussians and not is_constraint:
88119

0 commit comments

Comments
 (0)