学習する天然ニューラルネット

主に機械学習に関する覚書や情報の整理

PriorityQueue Classを作る [Pythonで競プロ]

この問題を解くのにpriority queueを使う方法がある。 atcoder.jp

Pythonでpriority queueを実装するためには2つ方法があるがどちらも欠点がある。

  1. heapqを用いた方法

    • こちらを用いて実装する方が多いと思う。でもめちゃくちゃ使いづらくないですか?
    • これで用意されている関数は、リストに対してin-placeで処理を施す。
    • クラスが用意されていない。
  2. from deque import PriorityQueue を用いた方法

    • クラスが用意されていて1よりも扱いやすいが、2倍ぐらい遅い。
    • 中身が確認できない。(中身でfor を回す等の作業ができない。)

そこで、1をベースにPriorityQueueクラスを用意した。 pushやpopをメソッドとすることで、heapqをそのまま使うよりもスッキリ見やすく実装することができる。 また、インスタンスをそのまま実行するとheapの中身が見られるようにした。

from heapq import heapify, heappop, heappush, heappushpop

class PriorityQueue:
    def __init__(self, heap):
        '''
        heap ... list
        '''
        self.heap = heap
        heapify(self.heap)

    def push(self, item):
        heappush(self.heap, item)

    def pop(self):
        return heappop(self.heap)

    def pushpop(self, item):
        return heappushpop(self.heap, item)

    def __call__(self):
        return self.heap

冒頭に上げた問題で使い方の具体例を示すと、こう。

import sys
read = sys.stdin.readline

def read_ints():
    return list(map(int, read().split()))

X, Y, Z, K = read_ints()
A = read_ints()
B = read_ints()
C = read_ints()

A.sort(reverse=True)
B.sort(reverse=True)
C.sort(reverse=True)


from heapq import heapify, heappop, heappush, heappushpop

class PriorityQueue:
    def __init__(self, heap):
        '''
        heap ... list
        '''
        self.heap = heap
        heapify(self.heap)

    def push(self, item):
        heappush(self.heap, item)

    def pop(self):
        return heappop(self.heap)

    def pushpop(self, item):
        return heappushpop(self.heap, item)

    def __call__(self):
        return self.heap


heap = []  # ヒープといっても順序を工夫したただのリスト

q = PriorityQueue(heap) #ここでインスタンスを作ってます
q.push((-(A[0] + B[0] + C[0]), 0, 0, 0))

considered = set()
ans = []
for k_th in range(1, K+1):
    heap_max, i, j, k = q.pop() #ここで一番小さな要素(先頭が見られる)を取り出してます
    ans.append(-heap_max)
    for di, dj, dk in zip([1, 0, 0], [0, 1, 0], [0, 0, 1]):
        i_new, j_new, k_new = i + di, j + dj, k + dk
        if i_new >= X or j_new >= Y or k_new >= Z:
            continue
        if (i_new, j_new, k_new) in considered:
            continue
        considered.add((i_new, j_new, k_new))
        q.push((-(A[i_new] + B[j_new] + C[k_new]), i_new, j_new, k_new)) #ここで要素の追加を行っています。


print(*ans, sep='\n')