1
- # -*- coding: utf-8 -*-
2
1
from functools import partial
3
- from typing import List
4
2
5
3
import numpy as np
6
4
from sklearn .model_selection import check_cv
7
5
from sklearn .utils .metaestimators import available_if
8
6
from sklearn .utils import check_array , check_random_state
9
- from sklearn .base import (
10
- BaseEstimator ,
11
- MetaEstimatorMixin ,
12
- clone ,
13
- is_classifier
14
- )
7
+ from sklearn .base import BaseEstimator , MetaEstimatorMixin , clone
15
8
from sklearn .metrics import check_scoring
16
9
17
10
from eli5 .permutation_importance import get_score_importances
18
- from eli5 .sklearn .utils import pandas_available
11
+ from eli5 .sklearn .utils import pandas_available , is_classifier
19
12
20
13
if pandas_available :
21
14
import pandas as pd
@@ -157,7 +150,6 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin):
157
150
"""
158
151
def __init__ (self , estimator , scoring = None , n_iter = 5 , random_state = None ,
159
152
cv = 'prefit' , refit = True ):
160
- # type: (...) -> None
161
153
if isinstance (cv , str ) and cv != "prefit" :
162
154
raise ValueError ("Invalid cv value: {!r}" .format (cv ))
163
155
self .refit = refit
@@ -174,8 +166,7 @@ def pd_scorer(model, X, y):
174
166
return base_scorer (model , X , y )
175
167
return pd_scorer
176
168
177
- def fit (self , X , y , groups = None , ** fit_params ):
178
- # type: (...) -> PermutationImportance
169
+ def fit (self , X , y , groups = None , ** fit_params ) -> 'PermutationImportance' :
179
170
"""Compute ``feature_importances_`` attribute and optionally
180
171
fit the base estimator.
181
172
@@ -224,8 +215,8 @@ def fit(self, X, y, groups=None, **fit_params):
224
215
def _cv_scores_importances (self , X , y , groups = None , ** fit_params ):
225
216
assert self .cv is not None
226
217
cv = check_cv (self .cv , y , classifier = is_classifier (self .estimator ))
227
- feature_importances = [] # type: List
228
- base_scores = [] # type: List[float ]
218
+ feature_importances : list = []
219
+ base_scores : list [ float ] = [ ]
229
220
weights = fit_params .pop ('sample_weight' , None )
230
221
fold_fit_params = fit_params .copy ()
231
222
for train , test in cv .split (X , y , groups ):
@@ -249,8 +240,7 @@ def _get_score_importances(self, score_func, X, y):
249
240
random_state = self .rng_ )
250
241
251
242
@property
252
- def caveats_ (self ):
253
- # type: () -> str
243
+ def caveats_ (self ) -> str :
254
244
if self .cv == 'prefit' :
255
245
return CAVEATS_PREFIT
256
246
elif self .cv is None :
0 commit comments