diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 20ad84efa2..8863f874d5 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -4,6 +4,8 @@ from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units +from spikeinterface.curation import train_model +from pathlib import Path job_kwargs = dict(n_jobs=-1) @@ -82,6 +84,31 @@ def sorting_analyzer_with_splits(): return make_sorting_analyzer_with_splits(sorting_analyzer) +@pytest.fixture(scope="module") +def trained_pipeline_path(): + """ + Trains a toy model and saves it in the "trained_pipeline" folder next to this file, for use by other tests in the module. + The model is trained every time this fixture is set up (once per module); no check for an existing model is made. + """ + trained_model_folder = Path(__file__).parent / Path("trained_pipeline") + analyzer = make_sorting_analyzer(sparse=True) + analyzer.compute( + { + "quality_metrics": {"metric_names": ["snr", "num_spikes"]}, + "template_metrics": {"metric_names": ["half_width"]}, + } + ) + train_model( + analyzers=[analyzer] * 5, + labels=[[1, 0, 1, 0, 1]] * 5, + folder=trained_model_folder, + classifiers=["RandomForestClassifier"], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + ) + yield trained_model_folder + + if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=False) print(sorting_analyzer) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 6156062a7c..9f845bb1c3 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -1,6 +1,6 @@ import pytest from pathlib import Path -from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from 
spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path from spikeinterface.curation.model_based_curation import ModelBasedClassification from spikeinterface.curation import auto_label_units, load_model from spikeinterface.curation.train_manual_curation import _get_computed_metrics @@ -14,12 +14,12 @@ @pytest.fixture -def model(): +def model(trained_pipeline_path): """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with the following labels: [1,0,1,0,1].""" - model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) + model = load_model(trained_pipeline_path, trusted=["numpy.dtype"]) return model @@ -38,18 +38,16 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model): assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) -def test_metric_ordering_independence(sorting_analyzer_for_curation, model): +def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path): """The function `auto_label_units` needs the correct metrics to have been computed. However, it should be independent of the order of computation. 
We test this here.""" sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) - model_folder = Path(__file__).parent / Path("trained_pipeline") - prediction_prob_dataframe_1 = auto_label_units( sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, + model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) @@ -57,7 +55,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model): prediction_prob_dataframe_2 = auto_label_units( sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, + model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) @@ -137,7 +135,7 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation @pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation") -def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation): +def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation, trained_pipeline_path): """We track whether the metric parameters used to compute the metrics used to train a model are the same as the parameters used to compute the metrics in the sorting analyzer which is being curated. 
If they are different, an error or warning will @@ -148,9 +146,7 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura ) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - model_folder = Path(__file__).parent / Path("trained_pipeline") - - model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) # an error should be raised if `enforce_metric_params` is True @@ -169,6 +165,6 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura ) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops deleted file mode 100644 index 362405f917..0000000000 Binary files a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops and /dev/null differ diff --git a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv deleted file mode 100644 index 46680a9e89..0000000000 --- a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv +++ /dev/null @@ -1,21 +0,0 @@ -unit_index,0 -0,1 -1,0 -2,1 -3,0 -4,1 -0,1 -1,0 -2,1 -3,0 -4,1 -0,1 -1,0 -2,1 -3,0 -4,1 -0,1 -1,0 -2,1 -3,0 -4,1 diff --git 
a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv deleted file mode 100644 index 7f015c380b..0000000000 --- a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv +++ /dev/null @@ -1,2 +0,0 @@ -,classifier name,imputation_strategy,scaling_strategy,accuracy,precision,recall,model_id,best_params -0,LogisticRegression,median,StandardScaler(),1.0000,1.0000,1.0000,0,"OrderedDict([('C', 4.811707275233983), ('max_iter', 384), ('solver', 'saga')])" diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json deleted file mode 100644 index 75ced28486..0000000000 --- a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "metric_params": { - "quality_metric_params": { - "metric_names": [ - "snr", - "num_spikes" - ], - "peak_sign": null, - "seed": null, - "metric_params": { - "num_spikes": {}, - "snr": { - "peak_sign": "neg", - "peak_mode": "extremum" - } - }, - "skip_pc_metrics": false, - "delete_existing_metrics": false, - "metrics_to_compute": [ - "snr", - "num_spikes" - ] - }, - "template_metric_params": { - "metric_names": [ - "half_width" - ], - "sparsity": null, - "peak_sign": "neg", - "upsampling_factor": 10, - "metric_params": { - "half_width": { - "recovery_window_ms": 0.7, - "peak_relative_threshold": 0.2, - "peak_width_ms": 0.1, - "depth_direction": "y", - "min_channels_for_velocity": 5, - "min_r2_velocity": 0.5, - "exp_peak_function": "ptp", - "min_r2_exp_decay": 0.5, - "spread_threshold": 0.2, - "spread_smooth_um": 20, - "column_range": null - } - }, - "delete_existing_metrics": false, - "metrics_to_compute": [ - "half_width" - ] - } - }, - "requirements": { - "spikeinterface": "0.101.1", - "scikit-learn": "1.3.2" - }, - "label_conversion": { - "1": 1, - "0": 0 - } -} diff --git 
a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv deleted file mode 100644 index c9efca17ad..0000000000 --- a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv +++ /dev/null @@ -1,21 +0,0 @@ -unit_id,snr,num_spikes,half_width -0,21.026926,5968.0,0.00027333334 -1,34.64474,5928.0,0.00023666666 -2,6.986315,5954.0,0.00026666667 -3,8.223127,6032.0,0.00020333333 -4,2.7464194,6002.0,0.00026666667 -0,21.026926,5968.0,0.00027333334 -1,34.64474,5928.0,0.00023666666 -2,6.986315,5954.0,0.00026666667 -3,8.223127,6032.0,0.00020333333 -4,2.7464194,6002.0,0.00026666667 -0,21.026926,5968.0,0.00027333334 -1,34.64474,5928.0,0.00023666666 -2,6.986315,5954.0,0.00026666667 -3,8.223127,6032.0,0.00020333333 -4,2.7464194,6002.0,0.00026666667 -0,21.026926,5968.0,0.00027333334 -1,34.64474,5928.0,0.00023666666 -2,6.986315,5954.0,0.00026666667 -3,8.223127,6032.0,0.00020333333 -4,2.7464194,6002.0,0.00026666667