2 votes

Comment créer une liste de tenseurs et utiliser tf.stack dans tensorflow2 dans une boucle for

Je suis en train d'essayer de créer une liste de tenseurs et de les empiler ensemble en utilisant une boucle for dans tensorflow2. J'ai créé un exemple de test et j'ai essayé comme suit.

import tensorflow as tf

@tf.function
def test(x):
    tensor_list = []
    for i in tf.range(x):
        tensor_list.append(tf.ones(4)*tf.cast(i, tf.float32))
    return  tf.stack(tensor_list)

result = test(5)
print(result)

mais je reçois l'erreur suivante comme ceci:

Traceback (most recent call last):
  File "test.py", line 10, in 
    result = test(5)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
    *args, **kwds))
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.InaccessibleTensorError: in converted code:

    test.py:8 test  *
        return  tf.stack(tensor_list)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/util/dispatch.py:180 wrapper
        return target(*args, **kwargs)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1165 stack
        return gen_array_ops.pack(values, axis=axis, name=name)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_array_ops.py:6304 pack
        "Pack", values=values, axis=axis, name=name)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:793 _apply_op_helper
        op_def=op_def)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:544 create_op
        inp = self.capture(inp)
    /root/.pyenv/versions/summarization-abstractive/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:603 capture
        % (tensor, tensor.graph, self))

    InaccessibleTensorError: The tensor 'Tensor("mul:0", shape=(4,), dtype=float32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_8, id=139870442952744); accessed from: FuncGraph(name=test, id=139870626510608).

Est-ce que quelqu'un sait ce que je fais de mal? Comment puis-je créer une liste de tenseurs et les empiler ensemble avec une boucle for dans tensorflow 2?

2voto

Lau Points 807

Parcourir un tenseur devrait généralement être fait en utilisant ´tf.map_fn´. Voici une solution qui fonctionne :

import tensorflow as tf
import numpy as np

@tf.function
def test(x):

    tensor_list = tf.map_fn(lambda inp: tf.ones(4)*tf.cast(inp, tf.float32), x, dtype=tf.dtypes.float32)

    return  tf.stack(tensor_list)

result = test(np.arange(5))
print(result)

Cependant, vous devez fournir un véritable tableau dans votre fonction test(), mais alternativement, vous pouvez appeler tf.range() à l'intérieur de la fonction tf.function pour convertir un scalaire en un tenseur.

2voto

hello_world Points 21

Utilisez tf.TensorArray au lieu de list

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X