15 votes

Quelle est la manière la plus efficace d'obtenir l'intersection de k tableaux triés ?

Étant donné k tableaux triés, quelle est la manière la plus efficace d'obtenir l'intersection de ces listes ?

Exemple

INPUT :

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

Sortie :

[1,7]

Il existe un moyen d'obtenir l'union de k tableaux triés, basé sur ce que j'ai lu dans le livre Elements of programming interviews, en nlogk temps. Je me demandais s'il était possible de faire quelque chose de similaire pour l'intersection.

## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt:
            heapq.heappush(heap, (nxt, ary))

EDIT : évidemment, il s'agit d'une question d'algorithme que j'essaie de résoudre. Je ne peux donc pas utiliser les fonctions intégrées comme l'intersection d'ensembles, etc.

16voto

Raymond Hettinger Points 231

Exploitation de l'ordre de tri

Voici une approche O(n) qui ne nécessite pas de structures de données spéciales ou de mémoire auxiliaire au-delà de l'exigence fondamentale d'un itérateur et d'une valeur par sous-liste :

from itertools import cycle

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            while (value := next(iterator)) == curr:
                pass
            pair[VALUE] = value
            curr, matches = value, 1
    except StopIteration:
        return result

Voici un exemple de session :

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> intersection(data)
[1, 7]

Algorithme en mots

L'algorithme tourne autour des paires itérateur-valeur. Si une valeur correspond à toutes les paires, elle appartient à l'intersection. Si une valeur est inférieure à toutes les autres vues jusqu'à présent, l'itérateur actuel est avancé. Si une valeur est supérieure à toutes celles vues jusqu'à présent, elle devient la nouvelle cible et le nombre de correspondances est remis à un. Lorsqu'un itérateur est épuisé, l'algorithme est terminé.

Ne dépend pas des fonctions intégrées

L'utilisation de itertools.cycle() est totalement facultatif. On peut facilement l'émuler en incrémentant un index qui s'enroule à la fin.

Au lieu de :

iterator, value = pair = next(pairs)

Vous pourriez écrire :

pairnum += 1
if pairnum == n:
    pairnum = 0
iterator, value = pair = pairs[pairnum]    

Ou de manière plus compacte :

pairnum = (pairnum + 1) % n
iterator, value = pair = pairs[pairnum] 

Valeurs répétées

Si les répétitions doivent être conservées (comme un multiset), il s'agit d'une modification facile, il suffit de changer les quatre lignes après result.append(curr) pour retirer l'élément correspondant de chaque itérateur :

def intersection(data):
    ITERATOR, VALUE = 0, 1
    n = len(data)
    result = []
    try:
        pairs = cycle([(it := iter(sublist)), next(it)] for sublist in data)
        pair = next(pairs)
        curr = pair[VALUE]  # Candidate is the largest value seen so far
        matches = 1         # Number of pairs where the candidate occurs
        while True:
            iterator, value = pair = next(pairs)
            while value < curr:
                value = next(iterator)
            pair[VALUE] = value
            if value > curr:
                curr, matches = value, 1
                continue
            matches += 1
            if matches != n:
                continue
            result.append(curr)
            for i in range(n):
                iterator, value = pair = next(pairs)
                pair[VALUE] = next(iterator)
            curr, matches = pair[VALUE], 1
    except StopIteration:
        return result

5voto

Oli Points 318

Oui, c'est possible ! J'ai modifié votre exemple de code pour le faire.

Ma réponse suppose que votre question porte sur l'algorithme - si vous voulez le code le plus rapide en utilisant set voir d'autres réponses.

Cela permet de maintenir le O(n log(k)) complexité temporelle : tout le code entre if lowest != elem or ary != times_seen: y unbench_all = False es O(log(k)) . Il y a une boucle imbriquée à l'intérieur de la boucle principale ( for unbenched in range(times_seen): ) mais cela ne fonctionne que times_seen temps, et times_seen est initialement égal à 0 et est remis à 0 à chaque fois que cette boucle interne est exécutée, et ne peut être incrémenté qu'une fois par itération de la boucle principale, de sorte que la boucle interne ne peut pas effectuer plus d'itérations au total que la boucle principale. Ainsi, puisque le code à l'intérieur de la boucle interne est O(log(k)) et s'exécute au maximum autant de fois que la boucle externe, et la boucle externe est O(log(k)) et exécute n fois, l'algorithme est O(n log(k)) .

Cet algorithme repose sur la façon dont les tuples sont comparés en Python. Il compare les premiers éléments des tuples, et s'ils sont égaux, il compare les deuxièmes éléments (par ex. (x, a) < (x, b) est vrai si et seulement si a < b ). Dans cet algorithme, contrairement à l'exemple de code de la question, lorsqu'un élément est retiré du tas, il n'est pas nécessairement poussé à nouveau dans la même itération. Puisque nous devons vérifier si toutes les sous-listes contiennent le même nombre, après qu'un nombre soit sorti du tas, sa sous-liste est ce que j'appelle "benched", ce qui signifie qu'elle n'est pas ajoutée au tas. C'est parce que nous devons vérifier si d'autres sous-listes contiennent le même élément, donc l'ajout du prochain élément de cette sous-list n'est pas nécessaire pour le moment.

Si un nombre est effectivement dans toutes les sous-listes, alors le tas ressemblera à quelque chose comme [(2,0),(2,1),(2,2),(2,3)] avec tous les premiers éléments des tuples identiques, donc heappop sélectionnera celui dont l'indice de sous-liste est le plus faible. Cela signifie que le premier index 0 sera supprimé et que times_seen sera incrémenté jusqu'à 1, puis l'index 1 sera supprimé et times_seen sera incrémenté à 2 - si ary n'est pas égal à times_seen alors le nombre n'est pas dans l'intersection de toutes les sous-listes. Cela conduit à la condition if lowest != elem or ary != times_seen: qui décide quand un nombre ne doit pas figurer dans le résultat. Le site else branche de cette if La déclaration est destinée aux cas où elle pourrait encore figurer dans le résultat.

Le site unbench_all booléen est pour quand toutes les sous-listes doivent être retirées du banc - cela pourrait être parce que :

  1. On sait que le numéro actuel ne se trouve pas dans l'intersection des sous-listes.
  2. On sait qu'il se trouve dans l'intersection des sous-listes suivantes

Lorsque unbench_all es True toutes les sous-listes qui ont été retirées du tas sont ajoutées à nouveau. On sait que ce sont celles qui ont des indices dans range(times_seen) puisque l'algorithme ne retire des éléments du tas que s'ils ont le même numéro, ils doivent donc avoir été retirés dans l'ordre de l'indice, de manière contiguë et en commençant par l'indice 0, et il doit y avoir times_seen d'entre eux. Cela signifie que nous n'avons pas besoin de stocker les indices des sous-listes mises en attente, seulement le nombre de celles qui ont été mises en attente.

import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of tims that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res

if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))

Un algorithme équivalent peut être écrit comme ceci, si vous préférez :

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt:
                        heapq.heappush(heap, (nxt, ary))

                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)

        if keep_elem:
            res.append(elem)

        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt:
                heapq.heappush(heap, (nxt, unbenched))

        if len(heap) < len(srtd_arys):
            heap = []

    return res

3voto

AlexTorx Points 550

Vous pouvez utiliser des ensembles intégrés et des intersections d'ensembles :

d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
result = set(d[0]).intersection(*d[1:])
{1, 7}

2voto

Onyambu Points 16644

Vous pouvez utiliser reduce :

from functools import reduce

a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
 {1, 7}

1voto

egjlmn1 Points 261

J'ai mis au point cet algorithme. Il ne dépasse pas O(n k) Je ne sais pas si c'est assez bon pour vous. L'intérêt de cet algorithme est que vous pouvez avoir k index pour chaque tableau et qu'à chaque itération vous trouvez les index du prochain élément dans l'intersection et augmentez chaque index jusqu'à ce que vous dépassiez les limites d'un tableau et qu'il n'y ait plus d'éléments dans l'intersection. l'astuce est que, puisque les tableaux sont triés, vous pouvez regarder deux éléments dans deux tableaux différents et si l'un est plus grand que l'autre, vous pouvez instantanément jeter l'autre parce que vous savez que vous ne pouvez pas avoir un nombre plus petit que celui que vous regardez. le pire cas de cet algorithme est que chaque index sera augmenté jusqu'à la limite, ce qui prend k n temps puisqu'un indice ne peut pas diminuer sa valeur.

  inter = []

  for n in range(len(arrays[0])):
    if indexes[0] >= len(arrays[0]):
        return inter
    for i in range(1,k):
      if indexes[i] >= len(arrays[i]):
        return inter
      while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < arrays[0][indexes[0]]:
        indexes[i] += 1
      while indexes[i] < len(arrays[i]) and indexes[0] < len(arrays[0]) and arrays[i][indexes[i]] > arrays[0][indexes[0]]:
        indexes[0] += 1
    if indexes[0] < len(arrays[0]):
      inter.append(arrays[0][indexes[0]])
    indexes = [idx+1 for idx in indexes]
  return inter

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X