5 votes

TensorFlow : `tf.data.Dataset.from_generator()` ne fonctionne pas avec les chaînes de caractères sur Python 3.x

J'ai besoin d'itérer à travers un grand nombre de fichiers d'images et de fournir les données à Tensorflow. J'ai créé un Dataset par une fonction génératrice qui produit les noms de chemin d'accès aux fichiers sous forme de chaînes, puis transforme les chaînes de chemin d'accès en données d'image à l'aide de la fonction map . Mais cela a échoué car la génération de valeurs de chaîne ne fonctionne pas, comme indiqué ci-dessous. Existe-t-il un correctif ou un moyen de contourner ce problème ?

2017-12-07 15:29:05.820708: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
producing data/miniImagenet/val/n01855672/n0185567200001000.jpg
2017-12-07 15:29:06.009141: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
2017-12-07 15:29:06.009215: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,21168]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Les codes de test sont indiqués ci-dessous. Il peut fonctionner correctement avec from_tensor_slices ou en mettant d'abord la liste des noms de fichiers dans un tenseur. Cependant, ces deux solutions épuiseraient la mémoire du GPU.

import tensorflow as tf

if __name__ == "__main__":
    file_names = ['data/miniImagenet/val/n01855672/n0185567200001000.jpg',
                  'data/miniImagenet/val/n01855672/n0185567200001005.jpg']
    # note: converting the file list to tensor and returning an index from generator works
    # path_to_indexes = {p: i for i, p in enumerate(file_names)}
    # file_names_tensor = tf.convert_to_tensor(file_names)

    def dataset_producer():
        for s in file_names:
            print('producing', s)
            yield s
    dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                             output_shapes=(tf.TensorShape([])))

    # note: this would also work
    # dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(file_names))

    def read_image(filename):
        # filename = file_names_tensor[filename_index]
        image_file = tf.read_file(filename, name='read_file')
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((84,84,3))
        image = tf.reshape(image, [21168])
        image = tf.cast(image, tf.float32) / 255.0
        return image

    dataset = dataset.map(read_image)
    dataset = dataset.batch(2)
    data_iterator = dataset.make_one_shot_iterator()
    images = data_iterator.get_next()
    print('images', images)
    max_value = tf.argmax(images)
    with tf.Session() as session:
        result = session.run(max_value)
        print(result)

6voto

mrry Points 1

Il s'agit d'un bogue affectant Python 3.x qui était Correction de après la version 1.4 de TensorFlow. Toutes les versions de TensorFlow à partir de la version 1.5 contiennent la correction.

Si vous n'utilisez qu'une version antérieure, la solution de contournement consiste à convertir les chaînes de caractères au format bytes avant de les renvoyer du générateur. Le code suivant devrait fonctionner :

def dataset_producer():
    for s in file_names:
        print('producing', s)
        yield s.encode('utf-8')  # Convert `s` to `bytes`.

dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                         output_shapes=(tf.TensorShape([])))

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X