101 votes

Que font le batch, la répétition et le shuffle avec les données TensorFlow ?

Je suis en train d'apprendre TensorFlow mais je suis tombé sur une confusion dans le code ci-dessous :

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

Je sais que l'ensemble de données contiendra d'abord toutes les données, mais qu'en est-il ? shuffle() , repeat() et batch() à l'ensemble de données ? Veuillez m'aider avec un exemple et une explication.

159voto

Vlad-HC Points 540

Mise à jour : Ici est un petit cahier de collaboration pour la démonstration de cette réponse.


Imaginez que vous ayez un ensemble de données : [1, 2, 3, 4, 5, 6] alors :

Fonctionnement de ds.shuffle()

dataset.shuffle(buffer_size=3) allouera un tampon de taille 3 pour choisir des entrées aléatoires. Ce tampon sera connecté à l'ensemble de données source. Nous pourrions l'imaginer comme ceci :

Random buffer
   |
   |   Source dataset where all other elements live
   |         |

[1,2,3] <= [4,5,6]

Supposons que l'entrée 2 a été prise dans le tampon aléatoire. L'espace libre est rempli par l'élément suivant du tampon source, à savoir 4 :

2 <= [1,3,4] <= [5,6]

Nous continuons à lire jusqu'à ce qu'il ne reste plus rien :

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]     <= []
4 <= []      <= []

Comment fonctionne ds.repeat()

Dès que toutes les entrées sont lues dans l'ensemble de données et que vous essayez de lire l'élément suivant, l'ensemble de données émet une erreur. C'est là que ds.repeat() entre en jeu. Il va réinitialiser l'ensemble de données, le rendant à nouveau comme ceci :

[1,2,3] <= [4,5,6]

Que produira ds.batch()

El ds.batch() prendra le premier batch_size et en faire un lot. Ainsi, une taille de lot de 3 pour notre ensemble de données d'exemple produira deux enregistrements de lot :

[2,1,5]
[3,6,4]

Comme nous avons un ds.repeat() avant le lot, la génération des données se poursuivra. Mais l'ordre des éléments sera différent, du fait de la ds.random() . Ce qu'il faut prendre en compte, c'est que 6 ne sera jamais présent dans le premier lot, en raison de la taille du tampon aléatoire.

11voto

Les méthodes suivantes dans tf.Dataset :

  1. repeat( count=0 ) La méthode répète l'ensemble des données count nombre de fois.
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) La méthode mélange les échantillons dans l'ensemble de données. Le site buffer_size est le nombre d'échantillons qui sont randomisés et renvoyés en tant que tf.Dataset .
  3. batch(batch_size,drop_remainder=False) Crée des lots de l'ensemble de données avec une taille de lot donnée comme suit batch_size qui est aussi la longueur des lots.

0voto

pmod Points 21

Un exemple qui montre le bouclage sur les époques. Lors de l'exécution de ce script, remarquez la différence en

  • dataset_gen1 - L'opération de brassage produit des sorties plus aléatoires ( cela peut être plus utile lors de l'exécution d'expériences d'apprentissage automatique. )
  • dataset_gen2 - l'absence d'opération de brassage produit des éléments en séquence

Autres ajouts dans ce script

  • tf.data.experimental.sample_from_datasets - utilisé pour combiner deux ensembles de données. Notez que l'opération de brassage dans ce cas doit créer un tampon qui échantillonne de manière égale les deux ensembles de données.

    import os

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"

    import tensorflow as tf if len(tf.config.list_physical_devices('GPU')): tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

    class Augmentations:

    def __init__(self):
        pass
    
    @tf.function
    def filter_even(self, x):
        if x % 2 == 0:
            return False
        else:
            return True

    class Dataset:

    def __init__(self, aug, range_min=0, range_max=100):
        self.range_min = range_min
        self.range_max = range_max
        self.aug = aug
    
    def generator(self):
        dataset = tf.data.Dataset.from_generator(self._generator
                        , output_types=(tf.float32), args=())
    
        dataset = dataset.filter(self.aug.filter_even)
    
        return dataset
    
    def _generator(self):
        for item in range(self.range_min, self.range_max):
            yield(item)

    Can be used when you have multiple datasets that you wish to combine

    class ZipDataset:

    def __init__(self, datasets):
        self.datasets = datasets
        self.datasets_generators = []
    
    def generator(self):
        for dataset in self.datasets:
            self.datasets_generators.append(dataset.generator())
        return tf.data.experimental.sample_from_datasets(self.datasets_generators)

    if name == "main": aug = Augmentations() dataset1 = Dataset(aug, 0, 100) dataset2 = Dataset(aug, 100, 200) dataset = ZipDataset([dataset1, dataset2])

    epochs = 2
    shuffle_buffer = 10
    batch_size = 4
    prefetch_buffer = 5
    
    dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
    # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 
    
    for epoch in range(epochs):
        print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
        for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
            print (X)
    
        # Do some stuff at end of loop

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X