diff --git a/eli5/keras/explain_prediction.py b/eli5/keras/explain_prediction.py index 73deb25b..7f64379b 100644 --- a/eli5/keras/explain_prediction.py +++ b/eli5/keras/explain_prediction.py @@ -53,7 +53,7 @@ def explain_prediction_keras(model, # type: Model The tensor must be of suitable shape for the ``model``. - Check ``model.input_shape`` to confirm the required dimensions of the input tensor. + Check ``model.input.shape`` to confirm the required dimensions of the input tensor. :raises TypeError: if ``doc`` is not a numpy array. @@ -260,7 +260,7 @@ def _validate_doc(model, doc): """ if not isinstance(doc, np.ndarray): raise TypeError('doc must be a numpy.ndarray, got: {}'.format(doc)) - input_sh = model.input_shape + input_sh = model.input.shape doc_sh = doc.shape if len(input_sh) == 4: # rank 4 with (batch, ...) shape @@ -337,6 +337,6 @@ def _is_suitable_activation_layer(model, layer): # check layer name # a check that asks "can we resize this activation layer over the image?" - rank = len(layer.output_shape) - required_rank = len(model.input_shape) + rank = len(layer.output.shape) + required_rank = len(model.input.shape) return rank == required_rank diff --git a/tests/test_keras.py b/tests/test_keras.py index e4631175..893b5792 100644 --- a/tests/test_keras.py +++ b/tests/test_keras.py @@ -38,12 +38,14 @@ def simple_seq(): """A simple sequential model for images.""" model = Sequential([ - Activation('linear', input_shape=(32, 32, 1)), # index 0, input + Input((32, 32, 1)), + Activation('linear'), # index 0, input conv_layer, # index 1, conv Conv2D(20, (3, 3)), # index 2, conv2 GlobalAveragePooling2D(), # index 3, gap # output shape is (None, 20) ]) + model(Input((32, 32, 1))) print('Summary of model:') model.summary() # rename layers @@ -101,7 +103,7 @@ def test_validate_doc(simple_seq): def test_validate_doc_custom(): # model with custom (not rank 4) input shape - model = Sequential([Dense(1, input_shape=(2, 3))]) + model = Sequential(Input((2, 3)), [Dense(1)]) # not matching shape with pytest.raises(ValueError): _validate_doc(model, np.zeros((5, 3)))