@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

-import itertools
 import logging
 import os
 import tempfile
@@ -38,6 +37,7 @@
     get_deployment_with_xai_head,
     image_to_distribution_data_item,
     load_annotations,
+    ood_metrics_to_string,
     split_data,
 )

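Note that `ood_metrics_to_string` itself is not part of this diff. As a rough sketch of what such a helper could look like — assuming a flat `{metric_name: float}` dict and mirroring the replaced logging loop below, which skipped entries carrying the `threshold_` prefix — it might be:

```python
# Hypothetical sketch only; the real ood_metrics_to_string is defined elsewhere.
# Assumes a flat {metric_name: float} dict and skips "threshold_"-prefixed
# entries, as the logging loop removed in this diff did.
def ood_metrics_to_string(metrics: dict, thresholds_prefix: str = "threshold_") -> str:
    """Format OOD metrics as one 'name: value' line per metric."""
    return "\n".join(
        f"  {name}: {value:.4f}"
        for name, value in metrics.items()
        if not name.startswith(thresholds_prefix)
    )
```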
@@ -88,6 +88,7 @@ def __init__(
         }
         self._thresholds_prefix = "threshold_"
         self.train_test_split_ratio = 0.7  # The ratio of train-test split
+        self.eval_metrics = {}

         if isinstance(project, str):
             project_name = project
@@ -189,10 +190,10 @@ def __init__(
         eval_metrics_test = self._test_model(
             test_id_data=test_id_data, test_ood_data=test_ood_data
         )
-        logging.info("COOD Model Metrics on Test Data: ")
-        for metric in eval_metrics_test:
-            if not metric.startswith(self._thresholds_prefix):
-                logging.info(f"{metric}: {eval_metrics_test[metric]}")
+        self.eval_metrics["test"] = eval_metrics_test
+        logging.info(
+            f"COOD Model Metrics on Test Data: {ood_metrics_to_string(eval_metrics_test)}"
+        )

    def _prepare_id_ood_data(self) -> dict:
        """
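Caching the test metrics on the instance (`self.eval_metrics["test"]` above) lets callers read results programmatically instead of scraping logs. A hypothetical usage sketch — the `COODModel` class name is an assumption, and the `"auroc"` key is taken from the removed HPO code further down, which reads `metrics["auroc"]`:

```python
# Hypothetical usage; the class name and constructor arguments are assumptions.
model = COODModel(project="my-project")  # training and testing run during init
test_metrics = model.eval_metrics["test"]
print(f"Test AUROC: {test_metrics['auroc']:.4f}")
```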
@@ -387,9 +388,18 @@ def _train_cood_model(
             y_pred_prob=pred_probabilities_val,
         )

-        logging.info(f"COOD Model Metrics on Train Data: {eval_metrics_train}")
-        logging.info(f"COOD Model Metrics on Validation Data: {eval_metrics_val}")
+        logging.info(
+            f"COOD Model Metrics on Training Data: \n{ood_metrics_to_string(eval_metrics_train)}"
+        )

+        logging.info(
+            f"COOD Model Metrics on Validation Data: \n{ood_metrics_to_string(eval_metrics_val)}"
+        )
+
+        self.eval_metrics["train"] = eval_metrics_train
+        self.eval_metrics["val"] = eval_metrics_val
+
+        # Update the prediction thresholds based on validation data
         self._update_thresholds(eval_metrics=eval_metrics_val)

    def _update_thresholds(self, eval_metrics: dict) -> None:
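For reference, the per-split dicts cached in `self.eval_metrics` come from `calculate_ood_metrics`. Judging from the keys read elsewhere in this file (`accuracy`, `f1_score`, `auroc`, plus the `threshold_` prefix consumed by `_update_thresholds`), each entry presumably looks roughly like:

```python
# Assumed shape of one eval_metrics entry; the values and the exact threshold
# key name are illustrative only.
eval_metrics_val = {
    "accuracy": 0.93,
    "f1_score": 0.91,
    "auroc": 0.97,
    "threshold_f1_score": 0.42,  # hypothetical "threshold_"-prefixed entry
}
```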
@@ -407,91 +417,6 @@ def _update_thresholds(self, eval_metrics: dict) -> None:
                 self._thresholds_prefix + threshold_name
             ]

-    def _train_cood_hpo(
-        self,
-        id_features_train,
-        id_features_val,
-        ood_features_train,
-        ood_features_val,
-    ) -> None:
-        """
-        Train the COOD model using the RandomForestClassifier with hyperparameter optimization.
-        :param id_features_train: Numpy array of COOD features for in-distribution training data
-        :param id_features_val: Numpy array of COOD features for in-distribution validation data
-        :param ood_features_train: Numpy array of COOD features for out-of-distribution training data
-        :param ood_features_val: Numpy array of COOD features for out-of-distribution validation data
-        """
-        logging.info("Training COOD Model with Hyperparameter Optimization")
-        logging.info(
-            f"Training data: ID - {len(id_features_train)}, OOD - {len(ood_features_train)}"
-        )
-
-        all_features_train = np.concatenate((id_features_train, ood_features_train))
-        all_labels_train = np.concatenate(
-            (np.zeros(len(id_features_train)), np.ones(len(ood_features_train)))
-        )
-        all_features_val = np.concatenate((id_features_val, ood_features_val))
-        all_labels_val = np.concatenate(
-            (np.zeros(len(id_features_val)), np.ones(len(ood_features_val)))
-        )
-
-        # Hyperparameter optimization
-        n_estimators = [10, 25, 50, 100, 250]
-        max_depth = [2, 4, 8, 16]
-
-        best_accuracy = 0
-        best_params = {}
-        best_model = None
-
-        all_combinations = itertools.product(n_estimators, max_depth)
-        # Iterate over each combination of hyperparameters
-        for params in all_combinations:
-            n_est, depth = params
-
-            # Initialize a RandomForestClassifier with current parameters
-            model = RandomForestClassifier(
-                n_estimators=n_est,
-                max_depth=depth,
-                random_state=42,
-            )
-
-            # Train the model on the training set
-            model.fit(all_features_train, all_labels_train)
-
-            # Validate the model on the validation set
-            y_val_pred_prob = model.predict_proba(all_features_val)[:, 1]
-
-            metrics = calculate_ood_metrics(
-                y_true=all_labels_val,
-                y_pred_prob=y_val_pred_prob,
-            )
-
-            accuracy = metrics["accuracy"]
-            f1_score_val = metrics["f1_score"]
-            auroc_val = metrics["auroc"]
-
-            logging.info(
-                f"n_estimators: {n_est}, max_depth: {depth}, Accuracy: {accuracy:.4f}, F1 Score: {f1_score_val:.4f}, "
-                f"AUROC: {auroc_val:.4f}"
-            )
-
-            # Update best model if the current one is better
-            if accuracy > best_accuracy:
-                best_accuracy = accuracy
-                best_params = {
-                    "n_estimators": n_est,
-                    "max_depth": depth,
-                }
-                best_model = model
-                if best_accuracy == 1:
-                    break
-
-        # Log the best parameters and accuracy
-        print(f"Best Parameters: {best_params}")
-        print(f"Best Validation Accuracy: {best_accuracy:.4f}")
-
-        self.ood_classifier = best_model
-
    def _cood_features_from_distribution_data(
        self, distribution_data: List[DistributionDataItem]
    ) -> np.ndarray:
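The deleted `_train_cood_hpo` above hand-rolled its grid search with `itertools.product`, which is why the `itertools` import was dropped at the top of the file. For comparison only — not the project's code — the same 5×4 grid could be searched with scikit-learn's `GridSearchCV`; note that cross-validation here stands in for the method's fixed validation split:

```python
# Sketch only: the removed manual loop, re-expressed with GridSearchCV.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

param_grid = {
    "n_estimators": [10, 25, 50, 100, 250],
    "max_depth": [2, 4, 8, 16],
}
search = GridSearchCV(
    RandomForestClassifier(random_state=42),
    param_grid,
    scoring="accuracy",  # the removed loop also selected on accuracy
    cv=3,                # CV in place of the fixed ID/OOD validation split
)

# Placeholder features/labels; 0 = in-distribution, 1 = out-of-distribution,
# matching the labels the removed method built with np.zeros/np.ones.
features = np.random.rand(200, 8)
labels = np.random.randint(0, 2, size=200)
search.fit(features, labels)
print(f"Best Parameters: {search.best_params_}")
print(f"Best CV Accuracy: {search.best_score_:.4f}")
```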