9 votes

TensorFlow tf.data.Dataset et mise en buckets

Pour un réseau LSTM, j'ai vu de grandes améliorations avec le bucketing.

J'ai trouvé la section sur le bucketing dans la documentation de TensorFlow dans (tf.contrib).

Cependant, dans mon réseau, j'utilise l'API tf.data.Dataset, en particulier je travaille avec des TFRecords, donc mon pipeline d'entrée ressemble à ceci

dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})

Comment puis-je incorporer la méthode de bucketing dans le pipeline de tf.data.Dataset ?

Si cela est important, dans chaque enregistrement du fichier TFRecords j'ai la longueur de séquence enregistrée sous forme d'entier.

5voto

vijay m Points 10299

Divers cas d'utilisation du `bucketing` en utilisant `Dataset API` sont bien expliqués [ici](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py).

`**bucket_by_sequence_length()` exemple:**

def elements_gen():
   texte = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
   label = [1, 2, 1, 2]
   pour x, y in zip(texte, label):
       yield (x, y)

def element_length_fn(x, y):
   retourne tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                     output_shapes=([None],[]),
                                     output_types=(tf.int32, tf.int32))

dataset =   dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                              bucket_batch_sizes=[2, 2, 2],
                                                              bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

avec tf.Session() as sess:

   pour _ in range(2):
      print('Get_next:')
      print(sess.run(batch))

Sortie:

Get_next:
(array([[1, 2, 3, 0, 0],
   [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
   [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))``

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X