123 votes

Que fait la fonction de collecte en pytorch en termes simples?

J'ai été dans la doc officielle et ce mais il est difficile de comprendre ce qui se passe.

J'essaie de comprendre un DQN code source et il utilise la rassembler la fonction sur la ligne 197.

Quelqu'un pourrait-il expliquer en termes simples ce que l'rassembler fonction n'? Quel est le but de cette fonction?

311voto

Ritesh Points 881

torch.gather crée un nouveau tenseur de l'entrée du tenseur en prenant les valeurs de chaque rangée le long de la cote dim. Les valeurs en torch.LongTensor, passée en index, spécifier la valeur à prendre de chaque 'ligne'. La dimension de la sortie du tenseur est la même que la dimension de l'indice de tenseur. Illustration suivante à partir de l'officiel docs explique plus clairement: Pictoral representation from the docs

(Note: Dans l'illustration, l'indexation commence à partir de 1 et non de 0).

Dans le premier exemple, la dimension donnée est le long de lignes (de haut en bas), donc pour coordonnées (1,1) de la position de l' result, il prend la ligne de la valeur de l' index de la src c'est - 1. De (1,1) dans la source de la valeur est 1 donc, sorties 1 de (1,1) en result. De même pour (2,2) la ligne de la valeur de l'indice pour l' src est 3. Au (3,2) la valeur en src est 8 , et donc les sorties 8 et ainsi de suite.

De la même façon pour le deuxième exemple, l'indexation est le long des colonnes, et donc a (2,2) position de l' result, la valeur de la colonne de l'index pour src est 3, donc à (2,3) à partir de src ,6 sont prises et sorties d' result à (2,2)

94voto

cleros Points 991

L' torch.gather de la fonction (ou torch.Tensor.gather) est un multi-indice de sélection de la méthode. Regardez l'exemple suivant de l'officiel docs:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

Commençons par passer par la sémantique des arguments différents: Le premier argument, input, est la source du tenseur que nous voulons sélectionner les éléments à partir. La deuxième, dim, est la dimension (ou de l'axe dans tensorflow/numpy) que nous souhaitons recueillir le long. Et enfin, index sont les indices de l'indice input. Comme pour la sémantique de l'opération, c'est comment les docs officielles de l'expliquer:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Donc, nous allons passer par l'exemple.

l'entrée tenseur est - [[1, 2], [3, 4]], et le dim argument est - 1, c'est à dire que nous voulons recueillir à partir de la deuxième dimension. Les indices de la deuxième dimension sont donnés en [0, 0] et [1, 0].

Comme nous "sauter" la première dimension (la dimension que nous voulons recueillir est - 1), la première dimension de la résultat est implicitement donné que la première dimension de l' index. Cela signifie que les indices de la seconde dimension, ou la colonne des indices, mais pas la ligne des indices. Ceux qui sont donnés par les indices de l' index du tenseur de lui-même. Pour l'exemple, cela signifie que la sortie va être en première ligne d'une sélection des éléments de l' input du tenseur de première ligne en tant que bien, comme donné par la première ligne de l' index tenseur de première ligne. Comme la colonne des indices sont donnés par l' [0, 0], nous avons donc sélectionner le premier élément de la première ligne de l'entrée deux fois, résultant en [1, 1]. De même, les éléments de la deuxième ligne de résultat sont le résultat de l'indexation de la deuxième ligne de l' input tenseur par les éléments de la deuxième ligne de l' index tenseur, résultant en [4, 3].

Pour illustrer encore plus loin, nous allons swap de la dimension dans l'exemple:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

Comme vous pouvez le voir, les indices sont maintenant recueillies le long de la première dimension.

Pour l'exemple que vous avez référé,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather indexe les lignes de la q-valeurs (par exemple q-valeurs dans un lot de q-valeurs) par le lot-liste des actions. Le résultat sera le même que si vous aviez fait ce qui suit (mais il sera beaucoup plus rapide qu'une boucle):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X