diff --git a/tensorflow_probability/python/stats/kendalls_tau.py b/tensorflow_probability/python/stats/kendalls_tau.py
index 9818cba4d1..75b7535fc9 100644
--- a/tensorflow_probability/python/stats/kendalls_tau.py
+++ b/tensorflow_probability/python/stats/kendalls_tau.py
@@ -14,9 +14,9 @@
 # ==============================================================================
 """Implements Kendall's Tau metric and loss."""
 
+import numpy as np
 import tensorflow as tf
 
-from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import dtype_util
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import tensorshape_util
@@ -46,7 +46,7 @@ def iterative_mergesort(y, permutation, name=None):
         permutation, name='permutation', dtype=tf.int32)
     shape = permutation.shape
     tensorshape_util.assert_is_compatible_with(y.shape, shape)
-    n = ps.size(y)
+    y_size = ps.size(y)
 
     def outer_body(k, exchanges, permutation):
       # The outer body progressively merges lists as k grows by powers of 2,
@@ -58,46 +58,47 @@ def middle_body(left, exchanges, permutation):
         # the middle body advances through the sublists of size k, advancing
         # the left edge until the end of the input is reached.
         right = left + k
-        end = tf.minimum(right + k, n)
+        end = tf.minimum(right + k, y_size)
 
         # See explanation here
         # https://www.geeksforgeeks.org/counting-inversions/.
 
-        def inner_body(i, j, x, np, p):
+        def inner_body(i, j, x, n, p):
           # The [left, right) and [right, end) lists are merged sorted, with
           # i and j tracking the advance through each range. x records the
           # number of order (bubble-sort equivalent) swaps that are happening
-          # with each insertion, and np represents the size of the output
+          # with each insertion, and n represents the size of the output
           # permutation that's been filled in using the p tensor.
           y_less = y_ordered[i] <= y_ordered[j]
           element = tf.where(y_less, [permutation[i]], [permutation[j]])
-          new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0)
+          new_p = tf.concat([p[0:n], element, p[n + 1:y_size]], axis=0)
           tensorshape_util.set_shape(new_p, p.shape)
           return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1),
-                  tf.where(y_less, x, x + right - i), np + 1, new_p)
+                  tf.where(y_less, x, x + right - i), n + 1, new_p)
 
-        i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32))
-        (i, j, exchanges, np, p) = tf.while_loop(
-            cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end),
+        i_j_x_n_p = (left, right, exchanges, 0,
+                     tf.zeros([y_size], dtype=tf.int32))
+        (i, j, exchanges, n, p) = tf.while_loop(
+            cond=lambda i, j, x, n, p: tf.math.logical_and(i < right, j < end),
             body=inner_body,
-            loop_vars=i_j_x_np_p)
+            loop_vars=i_j_x_n_p)
         permutation = tf.concat([
-            permutation[0:left], p[0:np], permutation[i:right],
-            permutation[j:end], permutation[end:n]
+            permutation[0:left], p[0:n], permutation[i:right],
+            permutation[j:end], permutation[end:y_size]
         ],
                                 axis=0)
         tensorshape_util.set_shape(permutation, shape)
         return left + 2 * k, exchanges, permutation
 
       _, exchanges, permutation = tf.while_loop(
-          cond=lambda left, exchanges, permutation: left < n - k,
+          cond=lambda left, exchanges, permutation: left < y_size - k,
           body=middle_body,
           loop_vars=(0, exchanges, permutation))
       k *= 2
       return k, exchanges, permutation
 
     _, exchanges, permutation = tf.while_loop(
-        cond=lambda k, exchanges, permutation: k < n,
+        cond=lambda k, exchanges, permutation: k < y_size,
         body=outer_body,
         loop_vars=(1, 0, permutation))
     return exchanges, permutation
@@ -159,37 +160,9 @@ def secondary_sort():
                      axis=0)
 
 
-def kendalls_tau(y_true, y_pred, name=None):
-  """Computes Kendall's Tau for two ordered lists.
-
-  Kendall's Tau measures the correlation between ordinal rankings. This
-  implementation is similar to the one used in scipy.stats.kendalltau.
-  The provided values may be of any type that is sortable, with the
-  argsort indices indicating the true or proposed ordinal sequence.
-
-  Args:
-    y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
-    y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
-      same N items.
-    name: Optional Python `str` name for ops created by this method.
-      Default value: `None` (i.e., 'kendalls_tau').
-
-  Returns:
-    kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
-      ordering of ties, as a `float32` scalar Tensor.
-  """
-  with tf.name_scope(name or 'kendalls_tau'):
-    in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32)
-    y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
-    y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
-    tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
-    assertions = [
-        assert_util.assert_rank(y_true, 1),
-        assert_util.assert_greater(
-            ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
-    ]
-    with tf.control_dependencies(assertions):
-      lexa = lexicographical_indirect_sort(y_true, y_pred)
+def _compute_kendalls_tau(y_true, y_pred):
+    """Kendall's Tau Implementation."""
+    lexa = lexicographical_indirect_sort(y_true, y_pred)
 
     # See A Computer Method for Calculating Kendall's Tau with Ungrouped Data
     # by William Night, Journal of the American Statistical Association,
@@ -238,13 +211,43 @@ def ties_in_y_pred_body(first, u, i):
         loop_vars=(0, 0, 1))
     u += ((n - first) * (n - first - 1)) // 2
     n0 = (n * (n - 1)) // 2
-    assertions = [
-        assert_util.assert_less(v, tf.cast(n0, tf.int32),
-                                'All ranks are ties for y_true.'),
-        assert_util.assert_less(u, tf.cast(n0, tf.int32),
-                                'All ranks are ties for y_pred.')
-    ]
-    with tf.control_dependencies(assertions):
-      return (tf.cast(n0 - (u + v - t), tf.float32) -
-              2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
-                  tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
+    n0i = tf.cast(n0, tf.int32)
+    return tf.where(
+        tf.logical_or(tf.greater_equal(v, n0i), tf.greater_equal(u, n0i)),
+        tf.constant(np.nan),
+        ((tf.cast(n0 - (u + v - t), tf.float32) -
+          2.0 * tf.cast(exchanges, tf.float32)) /
+         tf.math.sqrt(tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))))
+
+
+def kendalls_tau(y_true, y_pred, name=None):
+  """Computes Kendall's Tau for two ordered lists.
+
+  Kendall's Tau measures the correlation between ordinal rankings. This
+  implementation is similar to the one used in scipy.stats.kendalltau.
+  The provided values may be of any type that is sortable, with the
+  argsort indices indicating the true or proposed ordinal sequence.
+
+  Args:
+    y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
+    y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
+      same N items.
+    name: Optional Python `str` name for ops created by this method.
+      Default value: `None` (i.e., 'kendalls_tau').
+
+  Returns:
+    kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
+      ordering of ties, as a `float32` scalar Tensor.
+      Will return np.nan under conditions when the order is undefined, such
+      as when all the elements of y_true or y_pred are the same, or when the
+      number of elements is less than 2.
+  """
+  with tf.name_scope(name or 'kendalls_tau'):
+    in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32)
+    y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
+    y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
+    tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
+    return tf.where(
+        tf.logical_or(
+            tf.not_equal(ps.rank(y_true), 1), tf.less(ps.size(y_true), 2)),
+        tf.constant(np.nan), _compute_kendalls_tau(y_true, y_pred))
diff --git a/tensorflow_probability/python/stats/kendalls_tau_test.py b/tensorflow_probability/python/stats/kendalls_tau_test.py
index 8d01261cf9..fbe986c638 100644
--- a/tensorflow_probability/python/stats/kendalls_tau_test.py
+++ b/tensorflow_probability/python/stats/kendalls_tau_test.py
@@ -15,6 +15,7 @@
 """Tests Kendall's Tau metric."""
 
 import random
+import numpy as np
 
 from scipy import stats
 
@@ -74,21 +75,43 @@ def test_kendall_random_lists(self):
       self.assertAllClose(expected, res, atol=1e-5)
 
   def test_kendall_tau_assert_all_ties_y_true(self):
-    with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
-      self.evaluate(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7]))
+    self.assertTrue(
+          self.evaluate(
+              tf.math.is_nan(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7]))))
 
   def test_kendall_tau_assert_all_ties_y_pred(self):
-    with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
-      self.evaluate(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4]))
+    self.assertTrue(
+          self.evaluate(
+              tf.math.is_nan(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4]))))
 
   def test_kendall_tau_assert_scalar(self):
-    with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
-      tfp.stats.kendalls_tau([1], [4])
+    self.assertTrue(
+        self.evaluate(tf.math.is_nan(tfp.stats.kendalls_tau([1], [4]))))
 
   def test_kendall_tau_assert_unmatched(self):
     with self.assertRaises(ValueError):
       tfp.stats.kendalls_tau([1, 2], [3, 4, 5])
 
+  def test_kendall_tau_edge_case_behavior(self):
+    self.assertTrue(
+        self.evaluate(
+            tf.math.is_nan(
+                tfp.stats.kendalls_tau(
+                    tf.constant([0, 0]), tf.constant([3, 5])))))
+    self.assertTrue(
+        self.evaluate(
+            tf.math.is_nan(
+                tfp.stats.kendalls_tau(
+                    tf.constant([0, 1]), tf.constant([3, 3])))))
+    self.assertTrue(
+        self.evaluate(
+            tf.math.is_nan(
+                tfp.stats.kendalls_tau(tf.constant([0]), tf.constant([3])))))
+    self.assertTrue(
+        self.evaluate(
+            tf.math.is_nan(
+                tfp.stats.kendalls_tau(tf.constant([]), tf.constant([])))))
+
 
 if __name__ == '__main__':
   test_util.main()