Pour un réseau LSTM, j'ai vu de grandes améliorations avec le bucketing.
J'ai trouvé la section sur le bucketing dans la documentation de TensorFlow dans (tf.contrib).
Cependant, dans mon réseau, j'utilise l'API tf.data.Dataset
, en particulier je travaille avec des TFRecords, donc mon pipeline d'entrée ressemble à ceci
dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})
Comment puis-je incorporer la méthode de bucketing dans le pipeline de tf.data.Dataset
?
Si cela est important, dans chaque enregistrement du fichier TFRecords j'ai la longueur de séquence enregistrée sous forme d'entier.