diff --git a/eli5/keras/explain_prediction.py b/eli5/keras/explain_prediction.py
index 73deb25b..7f64379b 100644
--- a/eli5/keras/explain_prediction.py
+++ b/eli5/keras/explain_prediction.py
@@ -53,7 +53,7 @@ def explain_prediction_keras(model, # type: Model
 
         The tensor must be of suitable shape for the ``model``.
 
-        Check ``model.input_shape`` to confirm the required dimensions of the input tensor.
+        Check ``model.input.shape`` to confirm the required dimensions of the input tensor.
 
 
         :raises TypeError: if ``doc`` is not a numpy array.
@@ -260,7 +260,7 @@ def _validate_doc(model, doc):
     """
     if not isinstance(doc, np.ndarray):
         raise TypeError('doc must be a numpy.ndarray, got: {}'.format(doc))
-    input_sh = model.input_shape
+    input_sh = model.input.shape
     doc_sh = doc.shape
     if len(input_sh) == 4:
         # rank 4 with (batch, ...) shape
@@ -337,6 +337,6 @@ def _is_suitable_activation_layer(model, layer):
     # check layer name
 
     # a check that asks "can we resize this activation layer over the image?"
-    rank = len(layer.output_shape)
-    required_rank = len(model.input_shape)
+    rank = len(layer.output.shape)
+    required_rank = len(model.input.shape)
     return rank == required_rank
diff --git a/tests/test_keras.py b/tests/test_keras.py
index e4631175..893b5792 100644
--- a/tests/test_keras.py
+++ b/tests/test_keras.py
@@ -38,12 +38,14 @@
 def simple_seq():
     """A simple sequential model for images."""
     model = Sequential([
-        Activation('linear', input_shape=(32, 32, 1)), # index 0, input
+        Input((32, 32, 1)),
+        Activation('linear'),                          # index 0, input
         conv_layer,                                    # index 1, conv
         Conv2D(20, (3, 3)),                            # index 2, conv2
         GlobalAveragePooling2D(),                      # index 3, gap
         # output shape is (None, 20)
     ])
+    model(Input((32, 32, 1)))
     print('Summary of model:')
     model.summary()
     # rename layers
@@ -101,7 +103,7 @@ def test_validate_doc(simple_seq):
 
 def test_validate_doc_custom():
     # model with custom (not rank 4) input shape
-    model = Sequential([Dense(1, input_shape=(2, 3))])
+    model = Sequential(Input((2, 3)), [Dense(1)])
     # not matching shape
     with pytest.raises(ValueError):
         _validate_doc(model, np.zeros((5, 3)))