1+ #
2+ # Copyright IBM Corporation 2022
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+ #
16+
17+ from typing import List
18+
119import numpy as np
220import pandas as pd
3- from scipy .stats import norm
421from scipy .stats import multivariate_normal
522from GPy .models import GPRegression
623
7- from dataclasses import dataclass
8- from typing import List
9-
1024from doframework .core .utils import order_stats
1125
1226def plot_joint_distribution (samples : np .array , ** kwargs ):
@@ -33,23 +47,37 @@ def plot_joint_distribution(samples: np.array, **kwargs):
3347 sns .jointplot (data = df , x = cols [0 ], y = cols [1 ], kind = "hex" , xlim = lims , ylim = lims )
3448 sns .lineplot (data = dl , x = cols [0 ], y = cols [1 ])
3549
36- @dataclass
37- class POI :
50+ class POI (object ):
3851 '''
3952 Class for probability of improvement outcomes.
4053 '''
54+
55+ def __init__ (self , point : np .array , probability : float , ** kwargs ):
4156
42- solution : np .array
43- reference : np .array
44- probability : float
45- is_minimum : bool
57+ self .point = point
58+ assert all ([probability >= 0.0 ,probability <= 1.0 ]), f'Probability value should be in [0,1]. Received { probability :.2f} .'
59+ self .probability = probability
60+
61+ self .upper_bound = kwargs ['upper_bound' ] if 'upper_bound' in kwargs else True
62+ self .reference = kwargs ['reference' ] if 'reference' in kwargs else np .array ([])
63+ self .threshold = kwargs ['threshold' ] if 'threshold' in kwargs else None
64+
65+ def __repr__ (self ):
66+ return 'POI(' + '' .join ([f'point={ self .point } ,' ,
67+ f' probability={ self .probability } ,' ,
68+ f' upper_bound={ self .upper_bound } ' ,
69+ ',' * any ([self .reference .size > 0 ]),
70+ f' reference={ self .reference } ' * (self .reference .size > 0 ),
71+ ',' * any ([self .threshold is not None ]),
72+ f' threshold={ self .threshold } ' * (self .threshold is not None )])+ ')'
4673
4774def probability_of_improvement (solutions : np .array , references : np .array , model : GPRegression ,
48- sample_num : int = 100000 , is_constraint : bool = False , is_minimum : bool = True , plot_joint_gaussians : bool = False ,
75+ sample_num : int = 100000 , is_constraint : bool = False , upper_bound : bool = True , plot_joint_gaussians : bool = False ,
4976 ** kwargs ) -> List [POI ]:
5077
5178 sols = np .atleast_2d (solutions )
5279 d = sols .shape [- 1 ]
80+ is_minimum = not upper_bound
5381
5482 if is_constraint :
5583
@@ -81,8 +109,11 @@ def probability_of_improvement(solutions: np.array, references: np.array, model:
81109 else :
82110 mu , cov = model .predict (np .vstack ([sols_rep [i ],refs_rep [i ]]),full_cov = True )
83111 samples = multivariate_normal (mean = mu .flatten (),cov = cov ).rvs (size = sample_num )
84-
85- pois .append (POI (sols_rep [i ],refs_rep [i ],order_stats (samples ,is_minimum ),is_minimum ))
112+
113+ if is_constraint :
114+ pois .append (POI (sols_rep [i ],order_stats (samples ,is_minimum ),upper_bound = upper_bound ,threshold = refs_rep [i ]))
115+ else :
116+ pois .append (POI (sols_rep [i ],order_stats (samples ,is_minimum ),upper_bound = upper_bound ,reference = refs_rep [i ]))
86117
87118 if plot_joint_gaussians and not is_constraint :
88119
0 commit comments