42 votes

Comment obtenir la valeur d'un tenseur dans PyTorch?

Imprimer le tenseur donne :

>>> x = torch.tensor([3])
>>> print(x)
tensor([3])

De même, l'indexation de son .data donne :

>>> x.data[0]
tensor(3)

Comment obtenir juste la valeur 3 ?

99voto

Vimal Thilak Points 1168

Vous pouvez utiliser x.item() pour obtenir un nombre Python à partir d'un tenseur qui a un seul élément.

15voto

prosti Points 4630

Pour obtenir une valeur à partir d'un tenseur à élément unique, x.item() fonctionne toujours :

Exemple : Tenseur à élément unique sur CPU

x = torch.tensor([3])
x.item()

Sortie :

3

Exemple : Tenseur à élément unique sur CPU avec AD

x = torch.tensor([3.], requires_grad=True)
x.item()

Sortie :

3.0

REMARQUE : Nous devions utiliser l'arithmétique à virgule flottante pour l'AD

Exemple : Tenseur à élément unique sur CUDA

x = torch.tensor([3], device='cuda')
x.item()

Sortie :

3

Exemple : Tenseur à élément unique sur CUDA avec AD

x = torch.tensor([3.], device='cuda', requires_grad=True)
x.item()

Sortie :

3.0

Exemple : Tenseur à élément unique sur CUDA avec AD à nouveau

x = torch.ones((1,1), device='cuda', requires_grad=True)
x.item()

Sortie :

1.0

Pour obtenir une valeur à partir d'un tenseur à plusieurs éléments, nous devons faire attention :

L'exemple suivant montrera que le tenseur PyTorch résidant sur le CPU partage le même stockage qu'un tableau numpy na

Exemple : Stockage partagé

import torch
a = torch.ones((1,2))
print(a)
na = a.numpy()
na[0][0]=10
print(na)
print(a)

Sortie :

tensor([[1., 1.]])
[[10.  1.]]
tensor([[10.,  1.]])

Exemple : Éliminer l'effet du stockage partagé, copier d'abord le tableau numpy

Pour éviter l'effet du stockage partagé, nous devons copier() le tableau numpy na vers un nouveau tableau numpy nac. La méthode copy() de Numpy crée un nouveau stockage séparé.

import torch
a = torch.ones((1,2))
print(a)
na = a.numpy()
nac = na.copy()
nac[0][0]=10
print(nac)
print(na)
print(a)

Sortie :

tensor([[1., 1.]])
[[10.  1.]]
[[1. 1.]]
tensor([[1., 1.]])

Maintenant, seul le tableau numpy nac sera modifié avec la ligne nac[0][0]=10, na et a resteront tels quels.

Exemple : Tenseur CPU requires_grad=True

import torch
a = torch.ones((1,2), requires_grad=True)
print(a)
na = a.detach().numpy()
na[0][0]=10
print(na)
print(a)

Sortie :

tensor([[1., 1.]], requires_grad=True)
[[10.  1.]]
tensor([[10.,  1.]], requires_grad=True)

Ici, nous appelons :

na = a.numpy() 

Cela provoquerait : RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead., car les tenseurs avec requires_grad=True sont enregistrés par l'AD de PyTorch.

C'est pourquoi nous devons d'abord les detach() avant de les convertir en utilisant numpy().

Exemple : Tenseur CUDA requires_grad=False

a = torch.ones((1,2), device='cuda')
print(a)
na = a.to('cpu').numpy()
na[0][0]=10
print(na)
print(a)

Sortie :

tensor([[1., 1.]], device='cuda:0')
[[10.  1.]]
tensor([[1., 1.]], device='cuda:0')

Ici, nous ne convertissons tout simplement pas le tenseur CUDA en CPU. Il n'y a aucun effet de partage de stockage ici.

Exemple : Tenseur CUDA requires_grad=True

a = torch.ones((1,2), device='cuda', requires_grad=True)
print(a)
na = a.detach().to('cpu').numpy()
na[0][0]=10
print(na)
print(a)

Sortie :

tensor([[1., 1.]], device='cuda:0', requires_grad=True)
[[10.  1.]]
tensor([[1., 1.]], device='cuda:0', requires_grad=True)

Sans la méthode detach(), l'erreur RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead. sera générée.

Sans la méthode .to('cpu'), l'erreur TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. sera générée.

15voto

Ioannis Nasios Points 3273

Convertir le tenseur en numpy:

x.numpy()[0]

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X