Je suis en train de faire l'équivalent tensorflow
de torch.transforms.Resize(TRAIN_IMAGE_SIZE)
, qui redimensionne la plus petite dimension de l'image à TRAIN_IMAGE_SIZE
. Quelque chose comme ça
def transforms(filename):
parts = tf.strings.split(filename, '/')
label = parts[-2]
image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.image.convert_image_dtype(image, tf.float32)
# cela ne fonctionne pas avec Dataset.map() car image.shape=(None,None,3) à partir de Dataset.map()
image = largest_sq_crop(image)
image = tf.image.resize(image, (256,256))
return image, label
list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)
La réponse simple est ici: Tensorflow: Recadrer la plus grande région carrée centrale de l'image
Mais lorsque j'utilise la méthode avec tf.data.Dataset.map(transforms)
, j'obtiens shape=(None,None,3)
à partir de l'intérieur de largest_sq_crop(image)
. La méthode fonctionne bien lorsque je l'appelle normalement.