Skip to content

Commit cd3f5b6

Browse files
committed
fix is_classifier for older xgboost
1 parent ab85a17 commit cd3f5b6

File tree

4 files changed

+27
-20
lines changed

4 files changed

+27
-20
lines changed

eli5/sklearn/explain_prediction.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
# -*- coding: utf-8 -*-
21
from functools import partial
32

43
import numpy as np
54
import scipy.sparse as sp
6-
from sklearn.base import BaseEstimator, is_classifier
5+
from sklearn.base import BaseEstimator
76
from sklearn.ensemble import (
87
ExtraTreesClassifier,
98
ExtraTreesRegressor,
@@ -54,14 +53,15 @@
5453
from eli5.base_utils import singledispatch
5554
from eli5.utils import (
5655
get_target_display_names,
57-
get_binary_target_scale_label_id
56+
get_binary_target_scale_label_id,
5857
)
5958
from eli5.sklearn.utils import (
6059
add_intercept,
6160
get_coef,
6261
get_default_target_names,
6362
get_X,
6463
get_X0,
64+
is_classifier,
6565
is_multiclass_classifier,
6666
is_multitarget_regressor,
6767
predict_proba,

eli5/sklearn/permutation_importance.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
1-
# -*- coding: utf-8 -*-
21
from functools import partial
3-
from typing import List
42

53
import numpy as np
64
from sklearn.model_selection import check_cv
75
from sklearn.utils.metaestimators import available_if
86
from sklearn.utils import check_array, check_random_state
9-
from sklearn.base import (
10-
BaseEstimator,
11-
MetaEstimatorMixin,
12-
clone,
13-
is_classifier
14-
)
7+
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
158
from sklearn.metrics import check_scoring
169

1710
from eli5.permutation_importance import get_score_importances
18-
from eli5.sklearn.utils import pandas_available
11+
from eli5.sklearn.utils import pandas_available, is_classifier
1912

2013
if pandas_available:
2114
import pandas as pd
@@ -157,7 +150,6 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin):
157150
"""
158151
def __init__(self, estimator, scoring=None, n_iter=5, random_state=None,
159152
cv='prefit', refit=True):
160-
# type: (...) -> None
161153
if isinstance(cv, str) and cv != "prefit":
162154
raise ValueError("Invalid cv value: {!r}".format(cv))
163155
self.refit = refit
@@ -174,8 +166,7 @@ def pd_scorer(model, X, y):
174166
return base_scorer(model, X, y)
175167
return pd_scorer
176168

177-
def fit(self, X, y, groups=None, **fit_params):
178-
# type: (...) -> PermutationImportance
169+
def fit(self, X, y, groups=None, **fit_params) -> 'PermutationImportance':
179170
"""Compute ``feature_importances_`` attribute and optionally
180171
fit the base estimator.
181172
@@ -224,8 +215,8 @@ def fit(self, X, y, groups=None, **fit_params):
224215
def _cv_scores_importances(self, X, y, groups=None, **fit_params):
225216
assert self.cv is not None
226217
cv = check_cv(self.cv, y, classifier=is_classifier(self.estimator))
227-
feature_importances = [] # type: List
228-
base_scores = [] # type: List[float]
218+
feature_importances: list = []
219+
base_scores: list[float] = []
229220
weights = fit_params.pop('sample_weight', None)
230221
fold_fit_params = fit_params.copy()
231222
for train, test in cv.split(X, y, groups):
@@ -249,8 +240,7 @@ def _get_score_importances(self, score_func, X, y):
249240
random_state=self.rng_)
250241

251242
@property
252-
def caveats_(self):
253-
# type: () -> str
243+
def caveats_(self) -> str:
254244
if self.cv == 'prefit':
255245
return CAVEATS_PREFIT
256246
elif self.cv is None:

eli5/sklearn/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,30 @@
22

33
import numpy as np
44
import scipy.sparse as sp
5+
import sklearn.base
56
from sklearn.multiclass import OneVsRestClassifier
67

78
from eli5.sklearn.unhashing import invert_hashing_and_fit, handle_hashing_vec
89
from eli5._feature_names import FeatureNames
910

1011

12+
def is_classifier(estimator):
13+
try:
14+
return sklearn.base.is_classifier(estimator)
15+
except AttributeError:
16+
# old xgboost < 2.0.0 is not compatible with new sklean here
17+
try:
18+
import xgboost
19+
except ImportError:
20+
pass
21+
else:
22+
if isinstance(estimator, xgboost.XGBClassifier):
23+
return True
24+
elif isinstance(estimator, (xgboost.XGBRanker, xgboost.XGBRegressor)):
25+
return False
26+
raise
27+
28+
1129
def is_multiclass_classifier(clf) -> bool:
1230
"""
1331
Return True if a classifier is multiclass or False if it is binary.

eli5/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import numpy as np
32
from scipy import sparse as sp
43

0 commit comments

Comments
 (0)