Open
Description
I am trying to use TensorFlow probability as a metric in Keras. With respect to kendalls_tau, I get the following error:
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
def kendalls_tau(y_true, y_pred):
a = tf.reshape(y_true, shape=(-1,))
b = tf.reshape(y_pred, shape=(-1,))
kendall = tfp.stats.kendalls_tau(a, b)
return kendall
inputs = tf.keras.layers.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="Adam", loss="mse", metrics=kendalls_tau)
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2) )
model.fit(x, y)
TypeError: in user code:
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function *
return step_function(self, iterator)
<ipython-input-4-14a2210abe73>:5 kendalls_tau *
kendall = tfp.stats.kendalls_tau(a, b)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau **
lexa = lexicographical_indirect_sort(y_true, y_pred)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
left, _, lexicographic = tf.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
return while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
return while_v2.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
body_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
tf.cond(not_equal, secondary_sort, lambda: lexicographic))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
true_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
tensorshape_util.set_shape(x, [n])
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
tensor.set_shape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
shape = tensor_shape.TensorShape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
six.raise_from(
<string>:3 raise_from
TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/data/Mestrado/Ensaios/drbc_tf.py in
257 x = np.random.random((2, 3))
258 y = np.random.randint(0, 2, (2, 2) )
---> 259 model.fit(x, y)
260
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1190 _r=1):
1191 callbacks.on_train_batch_begin(step)
-> 1192 tmp_logs = self.train_function(iterator)
1193 if data_handler.should_sync:
1194 context.async_wait()
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
927 # This is the first call of __call__, so we have to initialize.
928 initializers = []
--> 929 self._initialize(args, kwds, add_initializers_to=initializers)
930 finally:
931 # At this point we know that the initialization is complete (or less
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
757 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
758 self._concrete_stateful_fn = (
--> 759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
760 *args, **kwds))
761
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3057 args, kwargs = None, None
3058 with self._lock:
-> 3059 graph_function, _ = self._maybe_define_function(args, kwargs)
3060 return graph_function
3061
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3454
3455 self._function_cache.missed.add(call_context_key)
-> 3456 graph_function = self._create_graph_function(args, kwargs)
3457 self._function_cache.primary[cache_key] = graph_function
3458
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3289 arg_names = base_arg_names + missing_arg_names
3290 graph_function = ConcreteFunction(
-> 3291 func_graph_module.func_graph_from_py_func(
3292 self._name,
3293 self._python_function,
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
1005 _, original_func = tf_decorator.unwrap(python_func)
1006
-> 1007 func_outputs = python_func(*func_args, **func_kwargs)
1008
1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
666 # the function a weak reference to itself to avoid a reference cycle.
667 with OptionalXlaContext(compile_with_xla):
--> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
669 return out
670
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
TypeError: in user code:
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:864 train_function *
return step_function(self, iterator)
<ipython-input-4-14a2210abe73>:5 kendalls_tau *
kendall = tfp.stats.kendalls_tau(a, b)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:196 kendalls_tau **
lexa = lexicographical_indirect_sort(y_true, y_pred)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:154 lexicographical_indirect_sort
left, _, lexicographic = tf.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:614 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2531 while_loop_v2
return while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:2729 while_loop
return while_v2.while_loop(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:214 while_loop
body_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py:200 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:152 body
tf.cond(not_equal, secondary_sort, lambda: lexicographic))
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:546 new_func
return func(*args, **kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
true_graph = func_graph_module.func_graph_from_py_func(
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/stats/kendalls_tau.py:148 secondary_sort
tensorshape_util.set_shape(x, [n])
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py:328 set_shape
tensor.set_shape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:758 set_shape
shape = tensor_shape.TensorShape(shape)
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 __init__
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:765 <listcomp>
self._dims = [Dimension(d) for d in dims]
/data/Mestrado/py_env_gnn/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:206 __init__
six.raise_from(
<string>:3 raise_from
TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'kendalls_tau/lexicographical_indirect_sort/size0/strided_slice_1:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
How can I fix this?
Metadata
Metadata
Assignees
Labels
No labels