Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/spikeinterface/curation/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.core.generate import inject_some_split_units
from spikeinterface.curation import train_model
from pathlib import Path

job_kwargs = dict(n_jobs=-1)

Expand Down Expand Up @@ -82,6 +84,31 @@ def sorting_analyzer_with_splits():
return make_sorting_analyzer_with_splits(sorting_analyzer)


@pytest.fixture(scope="module")
def trained_pipeline_path():
"""
Makes a model saved at "./trained_pipeline" which will be used by other tests in the module.
If the model already exists, this function does nothing.
"""
trained_model_folder = Path(__file__).parent / Path("trained_pipeline")
analyzer = make_sorting_analyzer(sparse=True)
analyzer.compute(
{
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
"template_metrics": {"metric_names": ["half_width"]},
}
)
train_model(
analyzers=[analyzer] * 5,
labels=[[1, 0, 1, 0, 1]] * 5,
folder=trained_model_folder,
classifiers=["RandomForestClassifier"],
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
)
yield trained_model_folder


if __name__ == "__main__":
sorting_analyzer = make_sorting_analyzer(sparse=False)
print(sorting_analyzer)
22 changes: 9 additions & 13 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from pathlib import Path
from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
from spikeinterface.curation.model_based_curation import ModelBasedClassification
from spikeinterface.curation import auto_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
Expand All @@ -14,12 +14,12 @@


@pytest.fixture
def model():
def model(trained_pipeline_path):
"""A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`.
It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with
the following labels: [1,0,1,0,1]."""

model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"])
model = load_model(trained_pipeline_path, trusted=["numpy.dtype"])
return model


Expand All @@ -38,26 +38,24 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_)


def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
"""The function `auto_label_units` needs the correct metrics to have been computed. However,
it should be independent of the order of computation. We test this here."""

sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])

model_folder = Path(__file__).parent / Path("trained_pipeline")

prediction_prob_dataframe_1 = auto_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=model_folder,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
)

sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])

prediction_prob_dataframe_2 = auto_label_units(
sorting_analyzer=sorting_analyzer_for_curation,
model_folder=model_folder,
model_folder=trained_pipeline_path,
trusted=["numpy.dtype"],
)

Expand Down Expand Up @@ -137,7 +135,7 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation


@pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation")
def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation):
def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation, trained_pipeline_path):
"""We track whether the metric parameters used to compute the metrics used to train
a model are the same as the parameters used to compute the metrics in the sorting
analyzer which is being curated. If they are different, an error or warning will
Expand All @@ -148,9 +146,7 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
)
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])

model_folder = Path(__file__).parent / Path("trained_pipeline")

model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"])
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)

# an error should be raised if `enforce_metric_params` is True
Expand All @@ -169,6 +165,6 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
)
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])

model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"])
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
Binary file not shown.
21 changes: 0 additions & 21 deletions src/spikeinterface/curation/tests/trained_pipeline/labels.csv

This file was deleted.

This file was deleted.

60 changes: 0 additions & 60 deletions src/spikeinterface/curation/tests/trained_pipeline/model_info.json

This file was deleted.

This file was deleted.