10 votes

Tensorflow ne peut pas obtenir `image.shape` à partir de la méthode dans `dataset.map( mapFn )`

Je suis en train de faire l'équivalent tensorflow de torch.transforms.Resize(TRAIN_IMAGE_SIZE), qui redimensionne la plus petite dimension de l'image à TRAIN_IMAGE_SIZE. Quelque chose comme ça

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # cela ne fonctionne pas avec Dataset.map() car image.shape=(None,None,3) à partir de Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

La réponse simple est ici: Tensorflow: Recadrer la plus grande région carrée centrale de l'image

Mais lorsque j'utilise la méthode avec tf.data.Dataset.map(transforms), j'obtiens shape=(None,None,3) à partir de l'intérieur de largest_sq_crop(image). La méthode fonctionne bien lorsque je l'appelle normalement.

3voto

michael Points 414

J'ai trouvé la réponse. Cela avait à voir avec le fait que ma méthode de redimensionnement fonctionnait bien avec une exécution rapide, par exemple tf.executing_eagerly()==True mais échouait lorsqu'elle était utilisée dans dataset.map(). Apparemment, dans cet environnement d'exécution, tf.executing_eagerly()==False.

Mon erreur était dans la façon dont j'extrais la forme de l'image pour obtenir les dimensions pour le redimensionnement. L'exécution du graphique Tensorflow ne semble pas prendre en charge l'accès au tuple tensor.shape.

  # incorrect
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # aussi incorrect
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # mais cela fonctionne !!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

J'utilisais les dimensions de la forme en aval dans ma fonction dataset.map() et cela a déclenché l'exception suivante car elle obtenait None au lieu d'une valeur.

TypeError: Impossible de convertir l'objet de type  en Tensor. Contenu : (-1, None, None, 3). Considérez le moulage des éléments en un type pris en charge.

Lorsque j'ai changé pour extraire manuellement la forme de tf.shape(), tout fonctionnait bien.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X