Skip to content

Error in stats.kendalls_tau as Keras Metric #1417

Open
@gonzalesMK

Description

@gonzalesMK

I am trying to use TensorFlow Probability's `tfp.stats.kendalls_tau` as a metric in Keras, but I get the following error:

import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np 

def kendalls_tau(y_true, y_pred):
    """Keras metric wrapper around `tfp.stats.kendalls_tau`.

    Flattens both `y_true` and `y_pred` to rank-1 tensors and returns the
    Kendall's tau computed by TensorFlow Probability between them.

    NOTE(review): as the traceback below shows, this raises a TypeError while
    Keras traces the train step. TFP's `lexicographical_indirect_sort` calls
    `tensorshape_util.set_shape(x, [n])` with `n` being a Tensor
    (`.../size0/strided_slice_1`), which `TensorShape` rejects. Presumably `n`
    is dynamic because the batch dimension, and therefore the flattened
    length, is unknown at trace time (TODO confirm). A statically known batch
    size might avoid it; this is untested.
    """
    # Flatten to 1-D so both arguments passed to TFP are vectors.
    a = tf.reshape(y_true, shape=(-1,))
    b = tf.reshape(y_pred, shape=(-1,))
    kendall = tfp.stats.kendalls_tau(a, b)
    return kendall

# Minimal model: 3 input features -> a single Dense layer with 2 outputs.
inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# Register the custom function as a metric alongside the MSE loss.
# NOTE(review): compiling with run_eagerly=True may sidestep the graph-mode
# failure shown in the traceback below, at the cost of speed. Untested; confirm.
model.compile(optimizer="Adam", loss="mse", metrics=kendalls_tau)

# Toy data: 2 samples of 3 random features; targets are random 0/1 integers
# with shape (2, 2), matching the model's 2 outputs.
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2) )
# fit() traces Keras's train_function, which invokes the metric; the
# TypeError in the traceback is raised from here.
model.fit(x, y)

TypeError: in user code:

    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function  *
        return step_function(self, iterator)
    <ipython-input-4-14a2210abe73>:5 kendalls_tau  *
        kendall = tfp.stats.kendalls_tau(a, b)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau  **
        lexa = lexicographical_indirect_sort(y_true, y_pred)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
        left, _, lexicographic = tf.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
        return while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
        return while_v2.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
        body_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
        tf.cond(not_equal, secondary_sort, lambda: lexicographic))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
        return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
        true_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
        tensorshape_util.set_shape(x, [n])
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
        tensor.set_shape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
        shape = tensor_shape.TensorShape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
        six.raise_from(
    <string>:3 raise_from
        

    TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/data/Mestrado/Ensaios/drbc_tf.py in 
     257 x = np.random.random((2, 3))
     258 y = np.random.randint(0, 2, (2, 2) )
---> 259 model.fit(x, y)
     260 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1190                 _r=1):
   1191               callbacks.on_train_batch_begin(step)
-> 1192               tmp_logs = self.train_function(iterator)
   1193               if data_handler.should_sync:
   1194                 context.async_wait()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    927       # This is the first call of __call__, so we have to initialize.
    928       initializers = []
--> 929       self._initialize(args, kwds, add_initializers_to=initializers)
    930     finally:
    931       # At this point we know that the initialization is complete (or less

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3057       args, kwargs = None, None
   3058     with self._lock:
-> 3059       graph_function, _ = self._maybe_define_function(args, kwargs)
   3060     return graph_function
   3061 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3454 
   3455           self._function_cache.missed.add(call_context_key)
-> 3456           graph_function = self._create_graph_function(args, kwargs)
   3457           self._function_cache.primary[cache_key] = graph_function
   3458 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3289     arg_names = base_arg_names + missing_arg_names
   3290     graph_function = ConcreteFunction(
-> 3291         func_graph_module.func_graph_from_py_func(
   3292             self._name,
   3293             self._python_function,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    992           except Exception as e:  # pylint:disable=broad-except
    993             if hasattr(e, "ag_error_metadata"):
--> 994               raise e.ag_error_metadata.to_exception(e)
    995             else:
    996               raise

TypeError: in user code:

    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function  *
        return step_function(self, iterator)
    <ipython-input-4-14a2210abe73>:5 kendalls_tau  *
        kendall = tfp.stats.kendalls_tau(a, b)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau  **
        lexa = lexicographical_indirect_sort(y_true, y_pred)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
        left, _, lexicographic = tf.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
        return while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
        return while_v2.while_loop(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
        body_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
        tf.cond(not_equal, secondary_sort, lambda: lexicographic))
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
        return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
        return func(*args, **kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
        true_graph = func_graph_module.func_graph_from_py_func(
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
        tensorshape_util.set_shape(x, [n])
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
        tensor.set_shape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
        shape = tensor_shape.TensorShape(shape)
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
        self._dims = [Dimension(d) for d in dims]
    /data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
        six.raise_from(
    <string>:3 raise_from
        

    TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'

How can I fix this?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions