Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 1f307a1

Browse files
authored
Support IsolationForest (#693)
close iterative/mlem.ai#353 See also #423 (comment)
1 parent 5936765 commit 1f307a1

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ repos:
4444
- types-requests
4545
- types-six
4646
- types-PyYAML
47-
- pydantic
47+
- pydantic>=1.9.0,<2
4848
- types-filelock
4949
- types-emoji
5050
- repo: local

mlem/contrib/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, ClassVar, Dict, List, Optional, Union
77

88
import sklearn
9-
from sklearn.base import ClassifierMixin, RegressorMixin
9+
from sklearn.base import ClassifierMixin, OutlierMixin, RegressorMixin
1010
from sklearn.feature_extraction.text import TransformerMixin, _VectorizerMixin
1111
from sklearn.pipeline import Pipeline
1212
from sklearn.preprocessing._encoders import _BaseEncoder
@@ -28,7 +28,7 @@ class SklearnModel(ModelType, ModelHook, IsInstanceHookMixin):
2828
"""ModelType implementation for `scikit-learn` models"""
2929

3030
type: ClassVar[str] = "sklearn"
31-
valid_types: ClassVar = (RegressorMixin, ClassifierMixin)
31+
valid_types: ClassVar = (RegressorMixin, ClassifierMixin, OutlierMixin)
3232

3333
io: ModelIO = SimplePickleIO()
3434
"""IO"""

tests/contrib/test_sklearn.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import lightgbm as lgb
44
import numpy as np
55
import pytest
6+
from sklearn.ensemble import IsolationForest
67
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
78
from sklearn.linear_model import LinearRegression, LogisticRegression
89
from sklearn.pipeline import Pipeline
@@ -54,6 +55,13 @@ def regressor(inp_data, out_data):
5455
return lr
5556

5657

58+
@pytest.fixture
59+
def outlier(inp_data):
60+
model = IsolationForest()
61+
model.fit(inp_data)
62+
return model
63+
64+
5765
@pytest.fixture
5866
def count_vectorizer(text_inp_data):
5967
vectorizer = CountVectorizer()
@@ -195,7 +203,9 @@ def test_hook_lgb(lgbm_model, inp_data):
195203
assert signature.returns == returns
196204

197205

198-
@pytest.mark.parametrize("model", ["classifier", "regressor", "pipeline"])
206+
@pytest.mark.parametrize(
207+
"model", ["classifier", "regressor", "pipeline", "outlier"]
208+
)
199209
def test_model_type__predict(model, inp_data, request):
200210
model = request.getfixturevalue(model)
201211
model_type = ModelAnalyzer.analyze(model, sample_data=inp_data)
@@ -221,7 +231,14 @@ def test_model_type__reg_predict_proba(regressor, inp_data):
221231
model_type.call_method("predict_proba", inp_data)
222232

223233

224-
@pytest.mark.parametrize("model", ["classifier", "regressor"])
234+
def test_model_type__outlier_predict_proba(outlier, inp_data):
235+
model_type = ModelAnalyzer.analyze(outlier, sample_data=inp_data)
236+
237+
with pytest.raises(ValueError):
238+
model_type.call_method("predict_proba", inp_data)
239+
240+
241+
@pytest.mark.parametrize("model", ["classifier", "regressor", "outlier"])
225242
def test_model_type__dump_load(tmpdir, model, inp_data, request):
226243
model = request.getfixturevalue(model)
227244
model_type = ModelAnalyzer.analyze(model, sample_data=inp_data)

0 commit comments

Comments
 (0)