Mise à jour : Ici est un petit cahier de collaboration pour la démonstration de cette réponse.
Imaginez que vous ayez un ensemble de données : [1, 2, 3, 4, 5, 6]
alors :
Fonctionnement de ds.shuffle()
dataset.shuffle(buffer_size=3)
allouera un tampon de taille 3 pour choisir des entrées aléatoires. Ce tampon sera connecté à l'ensemble de données source. Nous pourrions l'imaginer comme ceci :
Random buffer
|
| Source dataset where all other elements live
| |
[1,2,3] <= [4,5,6]
Supposons que l'entrée 2
a été prise dans le tampon aléatoire. L'espace libre est rempli par l'élément suivant du tampon source, à savoir 4
:
2 <= [1,3,4] <= [5,6]
Nous continuons à lire jusqu'à ce qu'il ne reste plus rien :
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
Comment fonctionne ds.repeat()
Dès que toutes les entrées sont lues dans l'ensemble de données et que vous essayez de lire l'élément suivant, l'ensemble de données émet une erreur. C'est là que ds.repeat()
entre en jeu. Il va réinitialiser l'ensemble de données, le rendant à nouveau comme ceci :
[1,2,3] <= [4,5,6]
Que produira ds.batch()
El ds.batch()
prendra le premier batch_size
et en faire un lot. Ainsi, une taille de lot de 3 pour notre ensemble de données d'exemple produira deux enregistrements de lot :
[2,1,5]
[3,6,4]
Comme nous avons un ds.repeat()
avant le lot, la génération des données se poursuivra. Mais l'ordre des éléments sera différent, du fait de la ds.random()
. Ce qu'il faut prendre en compte, c'est que 6
ne sera jamais présent dans le premier lot, en raison de la taille du tampon aléatoire.