From 4c6b32ea61749ccc79cc8d30c72a47020c71ba10 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 10 Dec 2025 11:30:55 +0000 Subject: [PATCH 1/2] make model instead of loading --- src/spikeinterface/curation/tests/common.py | 27 ++++++++ .../tests/test_model_based_curation.py | 5 +- .../tests/trained_pipeline/best_model.skops | Bin 34009 -> 0 bytes .../tests/trained_pipeline/labels.csv | 21 ------ .../trained_pipeline/model_accuracies.csv | 2 - .../tests/trained_pipeline/model_info.json | 60 ------------------ .../tests/trained_pipeline/training_data.csv | 21 ------ 7 files changed, 31 insertions(+), 105 deletions(-) delete mode 100644 src/spikeinterface/curation/tests/trained_pipeline/best_model.skops delete mode 100644 src/spikeinterface/curation/tests/trained_pipeline/labels.csv delete mode 100644 src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv delete mode 100644 src/spikeinterface/curation/tests/trained_pipeline/model_info.json delete mode 100644 src/spikeinterface/curation/tests/trained_pipeline/training_data.csv diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 20ad84efa2..06d16debcc 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -4,6 +4,8 @@ from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units +from spikeinterface.curation import train_model +from pathlib import Path job_kwargs = dict(n_jobs=-1) @@ -82,6 +84,31 @@ def sorting_analyzer_with_splits(): return make_sorting_analyzer_with_splits(sorting_analyzer) +def make_trained_pipeline(): + """ + Makes a model saved at "./trained_pipeline" which will be used by other tests in the module. + If the model already exists, this function does nothing. + """ + trained_model_folder = Path(__file__).parent / Path("trained_pipeline") + if not trained_model_folder.is_dir(): + analyzer = make_sorting_analyzer(sparse=True) + analyzer.compute( + { + "quality_metrics": {"metric_names": ["snr", "num_spikes"]}, + "template_metrics": {"metric_names": ["half_width"]}, + } + ) + train_model( + analyzers=[analyzer] * 5, + labels=[[1, 0, 1, 0, 1]] * 5, + folder=trained_model_folder, + classifiers=["RandomForestClassifier"], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + ) + return + + if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=False) print(sorting_analyzer) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index e2452c1d54..8ec646ef05 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -1,6 +1,6 @@ import pytest from pathlib import Path -from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, make_trained_pipeline from spikeinterface.curation.model_based_curation import ModelBasedClassification from spikeinterface.curation import auto_label_units, load_model from spikeinterface.curation.train_manual_curation import _get_computed_metrics @@ -19,6 +19,7 @@ def model(): It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with the following labels: [1,0,1,0,1].""" + make_trained_pipeline() model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) return model @@ -45,6 +46,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model): sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + make_trained_pipeline() model_folder = Path(__file__).parent / Path("trained_pipeline") prediction_prob_dataframe_1 = auto_label_units( @@ -147,6 +149,7 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat ) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + make_trained_pipeline() model_folder = Path(__file__).parent / Path("trained_pipeline") model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops deleted file mode 100644 index 362405f917e2f4c3dc41eca4160f2d8be69b1c32..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 34009 zcmeHQU8oz!6_)L!1v@PX_#sa7AoSvd8nJ(pR;fv^X?^mutv_X^JTgDd`WT51~(fErmR!c?lsjCVdDbV28Byt@tr10SDSMvpc&pqnXvN zuWY27m64^ZoilUhn=|K}IrF!A^1uU!Yv}c2uyyvvP4~+k_>SJ+gm+spOhs3_rqT3A z+pnJf$#bh`_th@eUOL)$!ru7kQ}v_A2gcE3^`nDe9F1+?3dVg0{r&^n3my19+_Faw z{QZP-?8#&G%k^En4z4o1w%)yZ{g=QadLv$2W=muGbGf3|yfX%d^+?j=T=PqBn)L*3ojX zF2?g5um*X3_S8Gd3NTlD`$vC8pZ_@is`6d74sU(*{%3+Zn1;Csb=dmnPk(pil?}vu z?d>Ojy;2+e=ilG@c>mLE9X>sD`<>}JsHV9Pby!)y@r&ozo>=+Dz?7EA4K3_x5NGIKT19Yn(p( z%WF3u{?~S`hTeyu6~f+@GqjuM!@z&(V7=b3te$6wp=C9msy9~Mk>k0((>NAC8V3D| z=fuy#3!Y<-{pN^E;{`VO0&$w>H_kh~=ycFea_9wq$|-OmxLeCWo=Ac zFLM1bdrFt`4A7-z04-k%{1iscC`^l;F4}WLQgre#r}!iV&B7_5r&z!#GB(cP%Az0Z zNyX#uJOI&T7dzYc3wakar~g96f`tX=uzAt=o@*lq7X*o1by-xeL|WU zK#|lOhtGn%2^hqAX23Pp8t!m3iJb9TLsYsOT1arAVWaWHK|;#DpisAuMbQ9fRx@-} zQ@BKzd8%1JoB@tJ=Z7&{BBoeKNH>Y=(X=+~kZ)pKX|@5mj>FX|18ACAUZ$u{#pS~C z$#t=*w6u6}9ga(z&mHhLbi>f~H?51dH*wfr$z{bTBe4X=6vgE*O;1-SM;ep#S*04M z?bCv0Ptwhj$h3GOm%6I9wU%P=av3L_lFo$Q?)NR%hqkq&V9e;XI05+|aG})sq98ax zqr`b)s3?lrF}iK8Z|7rp!Er{GGaN7VD4DD(=owp`JoA{OF6iu{)bR^SpldbL(Lhv`mktha=!>1-qyUt@n2N$#XAZHB z)ReAw#6}Y(mN*~i(XxXIqG*L9_kt7VsD4z@-C%1-oH~^69>#Xd_6F7^w;yfYLu^}Z zv692sN_4@xRP{vLBL}oz0}m#psqmymj&kJniD)>F_#-dS4JGvnr%hzTd{O}Vq?CzV zN9K2EzFxERBu-zlW|-Bx(3MOD8@klG#E_Ke&A1ZjAPo&L%(fyPq{Vq(mPlQ_E|ojF z44B8#+olTTNb-?t3=oQh&z7mS6;*|dgn3AEX!z;)r26N9p%%?b8QGf7ExPNNrW z?hHUf&uwTKfY!i`tbU?IWTp{Hp@!mdYN@JWHA5-K>6sTmnH6=LWXxj8q}ZM1>EZNr zm(OPzC#U_QBHkVl5j{OKcNk4!h0e}o=V*sa1~L+1C(1G1SsoB*Y`4J=H^rSH1q!OLUVn!g8XPGQo zN#NSL=!`=b$Jk}c0?MLP#iB^$T0?E>O)bXYLbm}usVq>eYFYm% z&*#K_s?y1oPL2iV9(#_0A%HXh2r`^_q&ViFsp6ovP#0vhEfcx2owhL@YY5?NgZl$6gXij6~x0{s+H(PL~`44jG(o}(BP>rj?A zu{J5j9m=2MFejKvF^OjSdbT&em2M_Kx?$R;2o7f=8%0O@XE0v^1;Rl>mb#6`&S)I; z9K6b_Y4z>M78BF$xTjaVU2>x>Yv|Z$r6SuaP&>r0A~Kk!EYMF#zGdVw$v+$RHx^=r z0){V)`W1R;om<)U9o5R1dOWXF8o6p>j5sKFIWgv2p^a8chgRq~B^DOBe$+MuHbbcY z@+X5hV_amSseDMD-pZH9b~;Z>2sqCquHTfQPO6Oz$ey3rNH!N1LJ{x$1cFj1UbM%1 z9>ipKV{s6HpcIJMRo&|_k`0!HVMGL_V5}zDV%0{+J*!DJf=T)*8!}+ns6n-O;7YRL zMr+k18v!QyRozd$dt%>R6t32QPtE0k^**W0>k4D+<8q;w6v+sx=i~Z4%jxrJTdKb zD%iw}J%8+M!q#29MsZr0p=oJzLZ&Ftp%xGafJctkg!!a6C&54u=)W?v6eGvCy=eJ% zQLZ(-xi0gBV5ipB(cCDbZjc0`*OZ{fKR5al+grLBipAJK9tF#XoYI87BC0@TB;rtP zQzmV+kfl`Gk_|vIsr%m2iJZ{{hXTyI2J;qawWJr`}?fhh-P z+;c|R;Y4ve(*Lu5LSD3mGdnY5Eu5LjSu&7F*P>0)rtn!PGz_P^ES*@+u`4IDWo7r_ z#XjgI$KBjoJ}F#?4M1=dX+MCS-LMY`ZU36~Ad}w1_>s#t*+`$8C^tXppda@CBnK^) z(urZ_ zlpD!ivd7qnjF-<0FUEMzx6TI}^39E9<}L51C%q6pjHg;~nwlFdUB6uj^r|I$>15B{ zon?UsA+TVtu=p+E272lcZ@I0Z9q5FhlmqBv_ir!bNFOL(2uDIt3ddgH4E8dLE$W!_ zLMRe~QYdEenLXUGLB7(H*9$>O2ueX&E!rb7QSi;u70oc?TAhqRfI*+K;8Pjr7pJA? zFJ}glF$FziSuKa58wI{+LTgm#fY8`T;s})_6Ocf1sAVE~;gYJ~(${qm{OC`>veHsg zAEnY#)j1%zw^TYA!%+EBvF@xa5N$)~?|@9@C+FLf%HHPVQ!SFGI~4g|Ss;u=Wr1RW zs4Nia82o2tfwBu4d2G<%fX_X&K*9*_`t*K)gp)+q!{(Mwx1}4xBLsw+4{AX|o?Te~ ztEnBqT2lcfMF*S2VUaKV<9IH7?iNI?yKR$Mou1KBb*9Ti8)B#pH7SR10oE`Wvupl= zP8#QcXEXdPNI|!=U@0rTwS=2Amah2C>JqQ%=w$apWpxp)c)6dI;TK;NuT6G$T~?di zd4~I0EzwEZ1*$f+kv|sFq=LdTu-RHMr(VOOT)~6qwAf9ZY)uK>xwPZaR-G-IkM7MX zzNm(SeB^#++V*VBwc>0!F-yd3omio)?O|hNw+jm`lTC1>OQf<9KUX{WaxVOoriI%= z+i(ZDg7R&^NqbSy!~176`~#?sU=&iv_BWKK3Y%3L2QROl+_(Smfx~d?bFKDbuyyvv zP4~+k_>SJ_j&SnY!B>iB7h#|qxOv*vyH~H%_jPBq*Wqh+>41p#k)3Og(RaaTw0$VK zWs=j8M6_PhJMuaZpf}6&qx;ixlS71fp5OZD{m;a-w=&w~>>Qr<(?hr45!3$Z04bLo z7{k-v`TWE!Nqrx{w8;W(QGNM4=rgPyeGt=9i&^+``8(h<+7B@;S&tzqm%ooaqkZ{n zq+D`4J6|q;CwfME{OgF8OuLK9 Date: Wed, 10 Dec 2025 12:48:55 +0000 Subject: [PATCH 2/2] use a fixture --- src/spikeinterface/curation/tests/common.py | 36 +++++++++---------- .../tests/test_model_based_curation.py | 25 +++++-------- 2 files changed, 27 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 06d16debcc..8863f874d5 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -84,29 +84,29 @@ def sorting_analyzer_with_splits(): return make_sorting_analyzer_with_splits(sorting_analyzer) -def make_trained_pipeline(): +@pytest.fixture(scope="module") +def trained_pipeline_path(): """ Makes a model saved at "./trained_pipeline" which will be used by other tests in the module. If the model already exists, this function does nothing. """ trained_model_folder = Path(__file__).parent / Path("trained_pipeline") - if not trained_model_folder.is_dir(): - analyzer = make_sorting_analyzer(sparse=True) - analyzer.compute( - { - "quality_metrics": {"metric_names": ["snr", "num_spikes"]}, - "template_metrics": {"metric_names": ["half_width"]}, - } - ) - train_model( - analyzers=[analyzer] * 5, - labels=[[1, 0, 1, 0, 1]] * 5, - folder=trained_model_folder, - classifiers=["RandomForestClassifier"], - imputation_strategies=["median"], - scaling_techniques=["standard_scaler"], - ) - return + analyzer = make_sorting_analyzer(sparse=True) + analyzer.compute( + { + "quality_metrics": {"metric_names": ["snr", "num_spikes"]}, + "template_metrics": {"metric_names": ["half_width"]}, + } + ) + train_model( + analyzers=[analyzer] * 5, + labels=[[1, 0, 1, 0, 1]] * 5, + folder=trained_model_folder, + classifiers=["RandomForestClassifier"], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + ) + yield trained_model_folder if __name__ == "__main__": diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 8ec646ef05..b3eeabc17c 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -1,6 +1,6 @@ import pytest from pathlib import Path -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, make_trained_pipeline +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path from spikeinterface.curation.model_based_curation import ModelBasedClassification from spikeinterface.curation import auto_label_units, load_model from spikeinterface.curation.train_manual_curation import _get_computed_metrics @@ -14,13 +14,12 @@ @pytest.fixture -def model(): +def model(trained_pipeline_path): """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with the following labels: [1,0,1,0,1].""" - make_trained_pipeline() - model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) + model = load_model(trained_pipeline_path, trusted=["numpy.dtype"]) return model @@ -39,19 +38,16 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model): assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) -def test_metric_ordering_independence(sorting_analyzer_for_curation, model): +def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path): """The function `auto_label_units` needs the correct metrics to have been computed. However, it should be independent of the order of computation. We test this here.""" sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) - make_trained_pipeline() - model_folder = Path(__file__).parent / Path("trained_pipeline") - prediction_prob_dataframe_1 = auto_label_units( sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, + model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) @@ -59,7 +55,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model): prediction_prob_dataframe_2 = auto_label_units( sorting_analyzer=sorting_analyzer_for_curation, - model_folder=model_folder, + model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) @@ -138,7 +134,7 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) -def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): +def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation, trained_pipeline_path): """We track whether the metric parameters used to compute the metrics used to train a model are the same as the parameters used to compute the metrics in the sorting analyzer which is being curated. If they are different, an error or warning will @@ -149,10 +145,7 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat ) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - make_trained_pipeline() - model_folder = Path(__file__).parent / Path("trained_pipeline") - - model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) # an error should be raised if `enforce_metric_params` is True @@ -167,6 +160,6 @@ def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curat sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)