L' torch.gather
de la fonction (ou torch.Tensor.gather
) est un multi-indice de sélection de la méthode. Regardez l'exemple suivant de l'officiel docs:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 1],
# [ 4, 3]])
Commençons par passer par la sémantique des arguments différents: Le premier argument, input
, est la source du tenseur que nous voulons sélectionner les éléments à partir. La deuxième, dim
, est la dimension (ou de l'axe dans tensorflow/numpy) que nous souhaitons recueillir le long. Et enfin, index
sont les indices de l'indice input
.
Comme pour la sémantique de l'opération, c'est comment les docs officielles de l'expliquer:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Donc, nous allons passer par l'exemple.
l'entrée tenseur est - [[1, 2], [3, 4]]
, et le dim argument est - 1
, c'est à dire que nous voulons recueillir à partir de la deuxième dimension. Les indices de la deuxième dimension sont donnés en [0, 0]
et [1, 0]
.
Comme nous "sauter" la première dimension (la dimension que nous voulons recueillir est - 1
), la première dimension de la résultat est implicitement donné que la première dimension de l' index
. Cela signifie que les indices de la seconde dimension, ou la colonne des indices, mais pas la ligne des indices. Ceux qui sont donnés par les indices de l' index
du tenseur de lui-même.
Pour l'exemple, cela signifie que la sortie va être en première ligne d'une sélection des éléments de l' input
du tenseur de première ligne en tant que bien, comme donné par la première ligne de l' index
tenseur de première ligne. Comme la colonne des indices sont donnés par l' [0, 0]
, nous avons donc sélectionner le premier élément de la première ligne de l'entrée deux fois, résultant en [1, 1]
. De même, les éléments de la deuxième ligne de résultat sont le résultat de l'indexation de la deuxième ligne de l' input
tenseur par les éléments de la deuxième ligne de l' index
tenseur, résultant en [4, 3]
.
Pour illustrer encore plus loin, nous allons swap de la dimension dans l'exemple:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 2],
# [ 3, 2]])
Comme vous pouvez le voir, les indices sont maintenant recueillies le long de la première dimension.
Pour l'exemple que vous avez référé,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
gather
indexe les lignes de la q-valeurs (par exemple q-valeurs dans un lot de q-valeurs) par le lot-liste des actions. Le résultat sera le même que si vous aviez fait ce qui suit (mais il sera beaucoup plus rapide qu'une boucle):
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)