サブロウ丸

サブロウ丸

主にプログラミングと数学

フラクショナルカスケーディング応用

k-th number

数列a_1, a_2, ..., a_nとクエリを表す数の3つ組みがm個与えられる。各クエリ(i, j, k)に対しa_i, ..., a_jを昇順にソートした際のk番目の数を出力せよ。

  • n ≤ 1000000
  • m ≤ 5000
  • |a_i| ≤ 109

2104 -- K-th Number

単純にクエリごとに部分配列の取り出し→ソートを行うと計算量はm * O(n log n) = O(m n log n)です。問題のnの最大値が大きいので計算量をもう少し小さくしたいところ。

蟻本、第2版第7刷、Sec3-3, p174にあるようにこの問題は領域探索木を用いて解くことができます。本稿ではfractional cascadingを用いて問題を解く方法とコードを紹介します。


方法

k-th number問題について、あるクエリ(i, j, k)は

(1) {(i, a_i); i = 1, ..., n} 点群があるときに
(2) i ≤ x ≤ j, y ≤ a_v の領域にある点がちょうどkであるようなi ≤ v ≤ jを見つける

ことと同値になります。元のクエリにはa_vを返せすことになります。 (2)のi ≤ x ≤ jに当てはまるのはa_i,...,a_jの点のみで、さらにy ≤ a_vの領域にちょうど点がkであることはちょうどa_v以下の点が自分を含めてk個存在することになり、これが問題が求める数に相当するからです。

さらにこの問題は2つ目の条件を

(2') i ≤ x ≤ j, y ≤ a_v の領域にある点がちょうどkであるような最小のa_v (1 ≤ v ≤ n)を見つける

と変換することができます(vの範囲に注目)。証明は(2')を満たすvをv'としたときにa_v'が(a_i, ..., a_j)の中でk番目に大きな数Kと等しいことを言えばよく、これはv' < Kとすると(2')で述べた領域内の点がk個であることと矛盾することから分かります。

(2')を満たすvは配列aをsortした配列をbとしたときに、配列bについて二分探索を行えばよいですね。これが大まかな方針です。配列のsortについては複数のクエリがあっても初めに1回のみ行えば良い(n log n)ので目をつむることにして、あとはi ≤ x ≤ j, y ≤ a_v の領域にある点を効率よく見つけられればよいですが、ここでfractional cascading(フラクショナルカスケーディング)を使います。

fractional cascading

私の記事から引用します。

ラクショナルカスケーディングは2次元領域において, 層状領域木を用いて指定した長方形領域に含まれる点を高速に探索する技術です。問い合わせ時間はO(log n + k)O(log⁡ n + k), nはデータ点数, kは報告される点の個数です。

実装

fractional cascading について、pipのパッケージとしてfractional_cascadingを利用するなら

git clone https://github.com/nariaki3551/fractional_cascading.git
cd fractional_cascading
python -m pip install .

単体のスクリプトが欲しいなら、

fractional_cascading.py · GitHub

からダウンロードします。

扱う問題は

  • a = [1, 5, 2, 6, 3, 7, 4]
  • m = [(2, 5, 3)]

とします。

import fractional_cascading as fc

a = [1, 5, 2, 6, 3, 7, 4]

点群の作成

points = [fc.Point(name=i, loc=(i, a[i])) for i in range(len(a))]

描写するとこんな感じ。

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.scatter(
    [point.loc()[0] for point in points],
    [point.loc()[1] for point in points],
)
ax.grid("--")

配列aのsort

b = sorted(a)

2分探索(python2.10公式から)

def bisect_left(a, x, lo=0, hi=None, *, key=None):
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if key(a[mid]) < x:
            lo = mid + 1
        else:
            hi = mid
    return lo

2分探索のkey。これはfractional cascadingを用いて、i ≤ x ≤ j, y ≤ b[index] にある点の個数を返す関数です。探索領域の長方形をfc.Rectangle(x_min=i, x_max=j, y_min=float("-inf"), y_max=b[index])と定義します。

def key(index):
    R = fc.Rectangle(x_min=i, x_max=j, y_min=float("-inf"), y_max=b[index])
    num_points = sum(1 for _ in tree.query(R))
    return num_points

2分探索呼び出し。このケースの正解と一致します。

i, j, k = 2, 5, 3
index = bisect_left(b, k+1, key=key)
print(b[index])  # answer
>>> 6

プログラムのまとめ

import fractional_cascading as fc

a = [1, 5, 2, 6, 3, 7, 4]
points = [fc.Point(name=i, loc=(i, a[i])) for i in range(len(a))]

b = sorted(a)

def bisect_left(aa, x, key):
    lo = 0
    hi = len(aa)
    while lo < hi:
        mid = (lo + hi) // 2
        if key(aa[mid]) < x:
            lo = mid + 1
        else:
            hi = mid
    return lo

def key(index):
    R = fc.Rectangle(x_min=i, x_max=j, y_min=float("-inf"), y_max=b[index])
    num_points = sum(1 for _ in tree.query(R))
    return num_points

i, j, k = 2, 5, 3
index = bisect_left(b, k+1, key=key)
print(b[index])  # answer

まとめ

以前プログラムとしてまとめたフラクショナルカスケーディングが使える問題に出会ったので記事にしてみました。