Skip to content

Support for TensorFlow Keras #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
10 changes: 5 additions & 5 deletions eli5/formatters/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


def format_as_image(expl, # type: Explanation
resampling_filter=Image.LANCZOS, # type: int
resampling_filter=Image.Resampling.LANCZOS, # type: int
colormap=matplotlib.cm.viridis, # type: Callable[[np.ndarray], np.ndarray]
alpha_limit=0.65, # type: Optional[Union[float, int]]
):
# type: (...) -> Image
"""format_as_image(expl, resampling_filter=Image.LANCZOS, colormap=matplotlib.cm.viridis, alpha_limit=0.65)
"""format_as_image(expl, resampling_filter=Image.Resampling.LANCZOS, colormap=matplotlib.cm.viridis, alpha_limit=0.65)

Format a :class:`eli5.base.Explanation` object as an image.

Expand Down Expand Up @@ -50,7 +50,7 @@ def format_as_image(expl, # type: Explanation

*Note that these attributes are integer values*.

Default is ``PIL.Image.LANCZOS``.
Default is ``PIL.Image.Resampling.LANCZOS``.
:type resampling_filter: int, optional

:param colormap:
Expand Down Expand Up @@ -239,7 +239,7 @@ def _cap_alpha(alpha_arr, alpha_limit):
'got: {}'.format(alpha_limit))


def expand_heatmap(heatmap, image, resampling_filter=Image.LANCZOS):
def expand_heatmap(heatmap, image, resampling_filter=Image.Resampling.LANCZOS):
# type: (np.ndarray, Image, Union[None, int]) -> Image
"""
Resize the ``heatmap`` image array to fit over the original ``image``,
Expand Down Expand Up @@ -286,4 +286,4 @@ def _overlay_heatmap(heatmap, image):
"""
# note that the order of alpha_composite arguments matters
overlayed_image = Image.alpha_composite(image, heatmap)
return overlayed_image
return overlayed_image
34 changes: 22 additions & 12 deletions eli5/keras/explain_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@
import PIL

import numpy as np
import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Layer
from keras.layers import (
Conv2D,
MaxPooling2D,
AveragePooling2D,
GlobalMaxPooling2D,
GlobalAveragePooling2D,
)
from keras.preprocessing.image import array_to_img
import os

# Setting the TF_KERAS environment variable to "1" selects the Keras that
# ships inside TensorFlow; any other value (or no value) selects the
# standalone ``keras`` package.
if os.environ.get('TF_KERAS') == '1':
    from tensorflow import keras
else:
    import keras

# Bind the names the rest of the module uses off whichever ``keras`` module
# was selected above, so both backends are handled by the same code.
K = keras.backend
Model = keras.models.Model
Layer = keras.layers.Layer
Conv2D = keras.layers.Conv2D
MaxPooling2D = keras.layers.MaxPooling2D
AveragePooling2D = keras.layers.AveragePooling2D
GlobalMaxPooling2D = keras.layers.GlobalMaxPooling2D
GlobalAveragePooling2D = keras.layers.GlobalAveragePooling2D

try:
    # tensorflow<2.9
    array_to_img = keras.preprocessing.image.array_to_img
except (AttributeError, ImportError):
    # tensorflow>=2.9
    # reference: https://www.tensorflow.org/api_docs/python/tf/keras/utils/array_to_img
    # Resolved through the selected ``keras`` module (not a hard-coded
    # ``tensorflow.keras`` import) so the TF_KERAS choice above is respected.
    array_to_img = keras.utils.array_to_img


from eli5.base import Explanation, TargetExplanation
from eli5.explain import explain_prediction
Expand Down
12 changes: 8 additions & 4 deletions eli5/keras/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from typing import Union, Optional, Tuple, List

import numpy as np
import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Layer
import os

# Setting the TF_KERAS environment variable to "1" selects the Keras that
# ships inside TensorFlow; any other value (or no value) selects the
# standalone ``keras`` package.
if os.environ.get('TF_KERAS') == '1':
    from tensorflow import keras
else:
    import keras

# Names used by the rest of the module, taken from the selected backend.
K = keras.backend
Model = keras.models.Model
Layer = keras.layers.Layer


def gradcam(weights, activations):
Expand Down
6 changes: 4 additions & 2 deletions eli5/lime/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
)
from eli5.lime._vectorizer import SingleDocumentVectorizer

from packaging.version import parse


class TextExplainer(BaseEstimator):
"""
Expand Down Expand Up @@ -320,12 +322,12 @@ def _fix_target_names(self, kwargs):

def _default_clf(self):
    """Return the default classifier: an ``SGDClassifier`` with logistic
    loss and an elastic-net penalty, seeded from ``self.rng_``.
    """
    version = sklearn_version()
    kwargs = dict(
        # scikit-learn 1.1 renamed the logistic loss from 'log' to
        # 'log_loss' (the old spelling was later removed), so pick the
        # name the installed version understands.
        loss='log_loss' if version >= parse('1.1') else 'log',
        penalty='elasticnet',
        alpha=1e-3,
        random_state=self.rng_,
    )
    # ``tol`` is only passed on scikit-learn >= 0.19.
    if version >= parse('0.19'):
        kwargs['tol'] = 1e-3
    return SGDClassifier(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions eli5/lime/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import six

import numpy as np
from scipy.stats import itemfreq
from sklearn.base import BaseEstimator, clone
from sklearn.neighbors import KernelDensity
from sklearn.metrics import pairwise_distances
Expand Down Expand Up @@ -188,7 +187,8 @@ def _sampler_n_samples(self, n_samples):
p=self.weights)
return [
(self.samplers[idx], freq)
for idx, freq in itemfreq(sampler_indices)
# use np.unique due to removal of scipy.stats.itemfreq
for idx, freq in np.vstack( np.unique(sampler_indices, return_counts=True) ).transpose()
]


Expand Down
4 changes: 2 additions & 2 deletions eli5/lime/textutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# The same as scikit-learn token pattern, but allows single-char tokens.
# There is deliberately no inline ``(?u)`` flag: the pattern gets wrapped in
# a capture group before use, which would leave the flag in the middle of the
# expression (rejected by Python 3.11+). Unicode matching is requested with
# ``flags=re.UNICODE`` at the call site instead.
DEFAULT_TOKEN_PATTERN = r'\b\w+\b'

# non-whitespace chars
CHAR_TOKEN_PATTERN = r'[^\s]'
Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(self, parts):
def fromtext(cls, text, token_pattern=DEFAULT_TOKEN_PATTERN):
    # type: (str, str) -> SplitResult
    """Split ``text`` on ``token_pattern`` and build an instance from the parts.

    :param text: the text to split.
    :param token_pattern: regular expression matching a single token.
    :return: ``cls(parts)``, where ``parts`` is the ``re.split`` output.
    """
    # Wrapping the pattern in a capture group makes ``re.split`` keep the
    # matched tokens in its result instead of dropping them.
    grouped_pattern = u"(%s)" % token_pattern
    # Unicode matching is requested via ``flags`` rather than an inline flag
    # inside the pattern, because the wrapping above would push an inline
    # flag away from the start of the expression.
    parts = re.split(grouped_pattern, text, flags=re.UNICODE)
    return cls(parts)

@property
Expand Down
19 changes: 15 additions & 4 deletions eli5/lime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import numpy as np
from scipy.stats import entropy
from sklearn.pipeline import Pipeline
from sklearn.utils import check_random_state, issparse
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils import check_random_state
from scipy.sparse import issparse
from sklearn.utils.metaestimators import available_if
from sklearn.utils import shuffle as _shuffle

from eli5.utils import vstack
from eli5.sklearn.utils import sklearn_version

from packaging.version import parse


def fit_proba(clf, X, y_proba, expand_factor=10, sample_weight=None,
shuffle=True, random_state=None,
Expand Down Expand Up @@ -73,7 +76,15 @@ def fix_multiclass_predict_proba(y_proba, # type: np.ndarray
class _PipelinePatched(Pipeline):
# Patch from https://github.com/scikit-learn/scikit-learn/pull/7723;
# only needed for scikit-learn < 0.19.
@if_delegate_has_method(delegate='_final_estimator')

# Reference: https://github.com/scikit-learn/scikit-learn/issues/20506
def _estimator_has(attr):
def check(self):
return hasattr(self.estimator, attr)

return check

@available_if(_estimator_has('_final_estimator'))
def score(self, X, y=None, **score_params):
Xt = X
for name, transform in self.steps[:-1]:
Expand All @@ -83,7 +94,7 @@ def score(self, X, y=None, **score_params):


def score_with_sample_weight(estimator, X, y=None, sample_weight=None):
if sklearn_version() < '0.19':
if sklearn_version() < parse('0.19'):
if isinstance(estimator, Pipeline) and sample_weight is not None:
estimator = _PipelinePatched(estimator.steps)
if sample_weight is None:
Expand Down
19 changes: 13 additions & 6 deletions eli5/sklearn/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from sklearn.model_selection import check_cv
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if
from sklearn.utils import check_array, check_random_state
from sklearn.base import (
BaseEstimator,
Expand Down Expand Up @@ -247,23 +247,30 @@ def caveats_(self):

# ============= Exposed methods of a wrapped estimator:

# Reference: https://github.com/scikit-learn/scikit-learn/issues/20506
def _estimator_has(attr):
    """Build an ``available_if`` check for a method delegated to
    ``self.wrapped_estimator_``.

    :param attr: name of the method the wrapped estimator must provide.
    :return: a callable taking the object and returning True when
        ``self.wrapped_estimator_`` has ``attr``.
    """
    def check(self):
        # Test the delegate itself for the delegated method. If
        # ``wrapped_estimator_`` is not set, the attribute access raises
        # AttributeError, which also signals "not available".
        return hasattr(self.wrapped_estimator_, attr)

    return check

@available_if(_estimator_has('score'))
def score(self, X, y=None, *args, **kwargs):
    """Delegate to ``self.wrapped_estimator_.score``."""
    return self.wrapped_estimator_.score(X, y, *args, **kwargs)

@available_if(_estimator_has('predict'))
def predict(self, X):
    """Delegate to ``self.wrapped_estimator_.predict``."""
    return self.wrapped_estimator_.predict(X)

@available_if(_estimator_has('predict_proba'))
def predict_proba(self, X):
    """Delegate to ``self.wrapped_estimator_.predict_proba``."""
    return self.wrapped_estimator_.predict_proba(X)

@available_if(_estimator_has('predict_log_proba'))
def predict_log_proba(self, X):
    """Delegate to ``self.wrapped_estimator_.predict_log_proba``."""
    return self.wrapped_estimator_.predict_log_proba(X)

@available_if(_estimator_has('decision_function'))
def decision_function(self, X):
    """Delegate to ``self.wrapped_estimator_.decision_function``."""
    return self.wrapped_estimator_.decision_function(X)

Expand Down
7 changes: 5 additions & 2 deletions eli5/sklearn/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from distutils.version import LooseVersion
#from distutils.version import LooseVersion # deprecated
from packaging.version import parse as LooseVersion
from typing import Any, Optional, List, Tuple

import numpy as np
Expand Down Expand Up @@ -80,7 +81,9 @@ def get_feature_names(clf, vec=None, bias_name='<BIAS>', feature_names=None,
bias_name = None

if feature_names is None:
if vec and hasattr(vec, 'get_feature_names'):
if vec and hasattr(vec, 'get_feature_names_out'):
return FeatureNames(vec.get_feature_names_out(), bias_name=bias_name)
elif vec and hasattr(vec, 'get_feature_names'):
return FeatureNames(vec.get_feature_names(), bias_name=bias_name)
else:
if estimator_feature_names is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_keras_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def assert_attention_over_area(expl, area):
heatmap = expl.targets[0].heatmap

# fit heatmap over image
heatmap = expand_heatmap(heatmap, image, Image.LANCZOS)
heatmap = expand_heatmap(heatmap, image, Image.Resampling.LANCZOS)
heatmap = np.array(heatmap)

# get a slice of the area
Expand Down Expand Up @@ -159,4 +159,4 @@ def test_show_prediction_nodeps(show_nodeps, keras_clf, cat_dog_image):
])
def test_explain_prediction_not_supported(model, doc):
res = eli5.explain_prediction(model, doc)
assert 'supported' in res.error
assert 'supported' in res.error