4 votes

Erreur Pytorch, RuntimeError : type scalaire attendu Long mais trouvé Double

J'ai rencontré l'erreur suivante lors de la formation d'un classificateur BERT. Le site

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_masks[i]) = class'torch.Tensor'

Quelle pourrait être l'erreur de type de données possible ici puisque je n'ai pas tapé de variable en long ou en double ?

Merci d'avance ! Error Stack Trace

2voto

Bugface Points 44

Dans une tâche de classification, le type de données pour les étiquettes d'entrée devrait être Long, mais vous les avez assignées en tant que float64.

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1., dtype=torch.long)

0voto

Dishin H Goyani Points 35

Vous pouvez utiliser torch.Tensor.long pour convertir le tenseur en attendu long type.

# Here, you can pass parameter like this in your call
..., labels = b_labels.long())

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X