2 votes

torch trouver les indices des lignes correspondantes dans 2 tenseurs 2D

J'ai deux tenseurs 2D, de longueurs différentes, qui sont tous deux des sous-ensembles différents du même tenseur 2D original et j'aimerais trouver toutes les "lignes" correspondantes
Par exemple

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

Je n'ai vu que des solutions numpy, qui utilisent dtype comme dicts, et cela ne fonctionne pas pour pytorch.

Voici comment je procède en numpy

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)

5voto

Berriel Points 2567

Cette réponse a été postée avant que l'OP ne mette à jour la question avec d'autres restrictions qui ont changé le problème de manière significative.

TL;DR Vous pouvez faire quelque chose comme ça :

torch.where((A == B).all(dim=1))[0]

Tout d'abord, en supposant que vous ayez :

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

Nous pouvons vérifier que A == B les retours :

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

Ce que nous voulons, c'est donc : les lignes dans lesquelles ils sont tous True . Pour cela, nous pouvons utiliser la fonction .all() et spécifier la dimension qui nous intéresse, dans notre cas 1 :

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

Ce que vous voulez savoir en réalité, c'est où se trouve le True sont. Pour cela, nous pouvons obtenir la première sortie de la fonction torch.where() fonction :

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])

2voto

okh Points 460

Si A et B sont des tenseurs 2D, le code suivant trouve les indices tels que A[indices] == B . Si plusieurs indices satisfont à cette condition, le premier indice trouvé est renvoyé. Si tous les éléments de B ne sont pas présents dans A, l'indice correspondant est ignoré.

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X