Oui, c'est possible ! J'ai modifié votre exemple de code pour le faire.
Ma réponse suppose que votre question porte sur l'algorithme - si vous voulez le code le plus rapide en utilisant set
voir d'autres réponses.
Cela permet de maintenir le O(n log(k))
complexité temporelle : tout le code entre if lowest != elem or ary != times_seen:
y unbench_all = False
es O(log(k))
. Il y a une boucle imbriquée à l'intérieur de la boucle principale ( for unbenched in range(times_seen):
) mais cela ne fonctionne que times_seen
temps, et times_seen
est initialement égal à 0 et est remis à 0 à chaque fois que cette boucle interne est exécutée, et ne peut être incrémenté qu'une fois par itération de la boucle principale, de sorte que la boucle interne ne peut pas effectuer plus d'itérations au total que la boucle principale. Ainsi, puisque le code à l'intérieur de la boucle interne est O(log(k))
et s'exécute au maximum autant de fois que la boucle externe, et la boucle externe est O(log(k))
et exécute n
fois, l'algorithme est O(n log(k))
.
Cet algorithme repose sur la façon dont les tuples sont comparés en Python. Il compare les premiers éléments des tuples, et s'ils sont égaux, il compare les deuxièmes éléments (par ex. (x, a) < (x, b)
est vrai si et seulement si a < b
). Dans cet algorithme, contrairement à l'exemple de code de la question, lorsqu'un élément est retiré du tas, il n'est pas nécessairement poussé à nouveau dans la même itération. Puisque nous devons vérifier si toutes les sous-listes contiennent le même nombre, après qu'un nombre soit sorti du tas, sa sous-liste est ce que j'appelle "benched", ce qui signifie qu'elle n'est pas ajoutée au tas. C'est parce que nous devons vérifier si d'autres sous-listes contiennent le même élément, donc l'ajout du prochain élément de cette sous-list n'est pas nécessaire pour le moment.
Si un nombre est effectivement dans toutes les sous-listes, alors le tas ressemblera à quelque chose comme [(2,0),(2,1),(2,2),(2,3)]
avec tous les premiers éléments des tuples identiques, donc heappop
sélectionnera celui dont l'indice de sous-liste est le plus faible. Cela signifie que le premier index 0 sera supprimé et que times_seen
sera incrémenté jusqu'à 1, puis l'index 1 sera supprimé et times_seen
sera incrémenté à 2 - si ary
n'est pas égal à times_seen
alors le nombre n'est pas dans l'intersection de toutes les sous-listes. Cela conduit à la condition if lowest != elem or ary != times_seen:
qui décide quand un nombre ne doit pas figurer dans le résultat. Le site else
branche de cette if
La déclaration est destinée aux cas où elle pourrait encore figurer dans le résultat.
Le site unbench_all
booléen est pour quand toutes les sous-listes doivent être retirées du banc - cela pourrait être parce que :
- On sait que le numéro actuel ne se trouve pas dans l'intersection des sous-listes.
- On sait qu'il se trouve dans l'intersection des sous-listes suivantes
Lorsque unbench_all
es True
toutes les sous-listes qui ont été retirées du tas sont ajoutées à nouveau. On sait que ce sont celles qui ont des indices dans range(times_seen)
puisque l'algorithme ne retire des éléments du tas que s'ils ont le même numéro, ils doivent donc avoir été retirés dans l'ordre de l'indice, de manière contiguë et en commençant par l'indice 0, et il doit y avoir times_seen
d'entre eux. Cela signifie que nous n'avons pas besoin de stocker les indices des sous-listes mises en attente, seulement le nombre de celles qui ont été mises en attente.
import heapq
def mergeArys(srtd_arys):
heap = []
srtd_iters = [iter(x) for x in srtd_arys]
# put the first element from each srtd array onto the heap
for idx, it in enumerate(srtd_iters):
elem = next(it, None)
if elem:
heapq.heappush(heap, (elem, idx))
res = []
# the number of tims that the current number has been seen
times_seen = 0
# the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
lowest = heap[0][0] if heap else None
# collect results in nlogK time
while heap:
elem, ary = heap[0]
unbench_all = True
if lowest != elem or ary != times_seen:
if lowest == elem:
heapq.heappop(heap)
it = srtd_iters[ary]
nxt = next(it, None)
if nxt:
heapq.heappush(heap, (nxt, ary))
else:
heapq.heappop(heap)
times_seen += 1
if times_seen == len(srtd_arys):
res.append(elem)
else:
unbench_all = False
if unbench_all:
for unbenched in range(times_seen):
unbenched_it = srtd_iters[unbenched]
nxt = next(unbenched_it, None)
if nxt:
heapq.heappush(heap, (nxt, unbenched))
times_seen = 0
if heap:
lowest = heap[0][0]
return res
if __name__ == '__main__':
a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
a2 = [[1, 1], [1, 1, 2, 2, 3]]
for arys in [a1, a2]:
print(mergeArys(arys))
Un algorithme équivalent peut être écrit comme ceci, si vous préférez :
def mergeArys(srtd_arys):
heap = []
srtd_iters = [iter(x) for x in srtd_arys]
# put the first element from each srtd array onto the heap
for idx, it in enumerate(srtd_iters):
elem = next(it, None)
if elem:
heapq.heappush(heap, (elem, idx))
res = []
# collect results in nlogK time
while heap:
elem, ary = heap[0]
lowest = elem
keep_elem = True
for i in range(len(srtd_arys)):
elem, ary = heap[0]
if lowest != elem or ary != i:
if ary != i:
heapq.heappop(heap)
it = srtd_iters[ary]
nxt = next(it, None)
if nxt:
heapq.heappush(heap, (nxt, ary))
keep_elem = False
i -= 1
break
heapq.heappop(heap)
if keep_elem:
res.append(elem)
for unbenched in range(i+1):
unbenched_it = srtd_iters[unbenched]
nxt = next(unbenched_it, None)
if nxt:
heapq.heappush(heap, (nxt, unbenched))
if len(heap) < len(srtd_arys):
heap = []
return res