2021/06/15

mod p での convolution を FFT で実装する(備忘録)

 mod p でのたたみ込みを FFT で実装するコードを自作するための備忘録。

 

フーリエ変換のまとめ 

 

離散フーリエ変換

$h$ をある体$F$の上での$x$の関数

\begin{equation*}h(x) = \sum_{i=0}^{n-1} c_i x^i\end{equation*}

とする。また、$F$には1の原始$n$乗根が存在すると仮定し、それを$\zeta_n$とする。つまり$\zeta_n$は

\begin{equation*} \zeta_n^i \left\{ \begin{array}{} = 1 & \text{ if } i = n \\ \ne 1 & \text{ if } 1 \le i \le n-1 \end{array} \right.\end{equation*} 

なる$F$の元とする。

すると、$h(\zeta_n^i)$はまた$F$の元なので、これを$\hat{c}_i$と置き(つまり $\hat{c}_i = h(\zeta_n^i)$)、関数$\hat{h}$を

\begin{equation*}\hat{h}(x) = \sum_{i=0}^{n-1} \hat{c}_i x^i\end{equation*}

とすると、$h \mapsto \hat{h}$ は$F[x]$から$F[x]$への写像となる。

関数の積と離散フーリエ変換

さて、ここで多項式環 $F[x]$ において、$f(x)$ と $g(x)$ との積を考える。$f, g$ をそれぞれ

\begin{equation*} f(x) = \sum_{i=0}^{n-1}a_i x^i , \quad g(x) = \sum_{j=0}^{n-1}b_j x^j\end{equation*} 

と表すが、簡単のため、$n$は十分大きい数とし、$a_i$, $b_j$ は$i, j$が大きいと$0$であるようにとっておく。特に、$a_i$が$0$でない最大の$i$と、$b_j$が$0$でない最大の$j$との和が$n$未満となるように、$n$を十分大きくしておく。すると、

\begin{equation*} f(x) . g(x) = \left( \sum_{i=0}^{n-1}a_i x^i \right)\left( \sum_{i=0}^{n-1}b_i x^i \right) = \sum_{k=0}^{n-1} \left( \sum_{i+j=k} a_i b_j \right) x^k \end{equation*} 

となる。そこでこの右辺の関数を $f$と$g$との積とし、$f*g$と書くことにする。

では、この関数の積は先の写像$h \mapsto \hat{h}$でどうなるか。

\begin{equation*} (f*g)(x) \mapsto \widehat{(f*g)}(x) = \sum_{k=0}^{n-1} (f*g)(\zeta_n^k) x^k \end{equation*}

だが、$(f*g)(x)$の定義は$(f*g)(x)=f(x).g(x)$だったので、$(f*g)(\zeta_n^k)=f(\zeta_n^k).g(\zeta_n^k)$、よって、

\begin{equation*} (f*g)(x) \mapsto \widehat{(f*g)}(x) = \sum_{k=0}^{n-1} f(\zeta_n^k)g(\zeta_n^k) x^k = \sum_{k=0}^{n-1} \hat{a}_k \hat{b}_k x^k \end{equation*} 

となる。つまり、$F[x]$での関数の積は、$h \mapsto \hat{h}$で移した先では同じ次数の係数の積に変換されていた。

離散フーリエ逆変換

もう一つよく似た写像で、

\begin{equation*} h(x) \mapsto \check{h}(x) = \sum_{i=0}^{n-1} h(\zeta_n^{-i}) x^i \end{equation*} 

を作る。この写像は、$\zeta_n^i$ が$\zeta_n^{-i}$に置き換わったものであるが、$\hat{h}$をこの写像で飛ばすと

\begin{equation*} \hat{h}(x) \mapsto \check{\hat{h}}(x) = \sum_{i=0}^{n-1} \hat{h}(\zeta_n^{-i}) x^i \end{equation*}

ここで、

\begin{equation*} \hat{h}(\zeta_n^{-i}) = \sum_{j=0}^{n-1} h(\zeta_n^j) (\zeta_n^{-i})^j = \sum_{j=0}^{n-1} \sum_{k=0}^{n-1} c_k (\zeta_n^j)^k (\zeta_n^{-i})^j = \sum_{k=0}^{n-1} c_k \sum_{j=0}^{n-1} \zeta_n^{j(k-i)} \end{equation*}
 
すると、最後の$\sum_{j=0}^{n-1}(\zeta_n^{k-i})^j$ は、$k=i$なら$n$、$k\ne i$なら$0$となる。なぜなら、まず$k=i$ならば$\zeta_n^{k-i}=1$で自明なので、$k\ne i$とすると、$-(n-1) \le k - i \le n - 1$であり、$\zeta_n$は1の原始$n$乗根なので、これらはすべて$1$でない$1$の$n$乗根となっている。よって、$\zeta_n^{k-i}=\xi$とし、$\sum_{j=0}^{n-1}\xi^j$ に $\xi - 1$をかけると、

\begin{equation*} (\xi - 1)\sum_{j=0}^{n-1}\xi^j = (\xi - 1)(\xi^{n-1} + \xi^{n-2} + ...+ \xi + 1) = \xi^n - 1 \end{equation*}
 
$\xi$は1の$n$乗根なので右辺は0、また、左辺の$(\xi -1)$は0ではないので、$\sum_{j=0}^{n-1}\xi^j = 0$であることがわかった。
 
以上により、

\begin{equation*} \hat{h}(\zeta_n^{-i}) = n c_i , \quad \check{\hat{h}}(x) = n \sum_{i=0}^{n-1} c_i x^i = n h(x)\end{equation*}
を得た。つまり、$h \mapsto \hat{h} \mapsto \check{\hat{h}} = n h$であり、特にこの二つの写像は全単射であることもわかった。
 

FFT(高速フーリエ変換)

さて、関数$h(x)=\sum_{k=0}^{n-1}c_k x^k$が与えられたとき、この離散フーリエ変換を求めたいとする。つまり、$h(\zeta_n^i)$をすべての$i$について求めたいのだが、計算式通りに求めようとすると、$0 \le i \le n-1$ と $0 \le k \le n-1$との組み合わせ分、$n\times n$回の計算が必要になる。つまり、計算量は$\mathcal{O}(n^2)$であるので、$n > 10^5$などになってくるとかなり厳しい。

そこで、この計算を$\mathcal{O}(n \log n)$程度で済ませるのがFFT。

まず、$n$を適当に大きくとって$2$のべき乗になるようにし、増えた分の$c_k$は0で埋める。(2のべき乗でなくても方法はあるが、2のべき乗が最も高速に処理できる。)

次に、$h$の各項を偶数番目と奇数番目に分けて、

\begin{equation*} h(x) = \sum_{k=0}^{n/2-1}c_{2k}x^{2k} + x \sum_{k=0}^{n/2-1}c_{2k+1}x^{2k} \end{equation*}

とし、それぞれの項を$h_0$、$h_1$とする。つまり、

\begin{eqnarray*} h_0(x) &=& \sum_{k=0}^{n/2-1}c_{2k}x^k \\ h_1(x) &=& \sum_{k=0}^{n/2-1}c_{2k+1}x^k \end{eqnarray*} 

とおけば、

\begin{eqnarray*}h(x) = h_0(x^2) + x h_1(x^2)\end{eqnarray*} 

なので、特に

\begin{eqnarray*}h(\zeta_n^i) = h_0(\zeta_{n/2}^i) + \zeta_n^i h_1(\zeta_{n/2}^i)\end{eqnarray*} 

である。$h_0$, $h_1$の中に入っているのは $(\zeta_n^i)^2 = \zeta_{n/2}^i$、つまり、1の原始$n/2$乗根 $\zeta_n^2$ の$i$乗であることに注意する。

さて、この$h_0$や$h_1$と、$0 \le i \le n-1$について、$h_0(\zeta_{n/2}^i)$および$h_1(\zeta_{n/2}^i)$を求めれば、$h(\zeta_n^i)$がわかるのだが、$\zeta_{n/2}^i$は1の原始$n/2$乗根のべき乗なので、$0 \le i \le n/2 -1$まで計算すれば、あとはそれが再度繰り返されるだけ。つまり、$h_0$および$h_1$からそれぞれ$h_0(\zeta_{n/2}^i)$, $h_1(\zeta_{n/2}^i)$を求める作業は$\mathcal{O}(2\times(n/2)^2)=\mathcal{O}(n^2/2)$かかり、そこから$h$再構成するのに$\mathcal{O}(n)$かかるので、合わせて$\mathcal{O}(n + n^2/2)$かかる。

そして、$h$を$h_0$と$h_1$とに分割した作業をさらに$h_0$と$h_1$にも適用すれば、計算時間は $\mathcal{O}(n + 2 \times n/2 + n^2/4)$に減る。これを繰り返し、2のべき乗であった$n$が1になるまで分割すれば、結局計算量は$\mathcal{O}(n\log n )$程度になる。

Convolutionとの関係

数列$\{a_i\}_{i=0}^{n-1}, \{b_i\}_{i=0}^{n-1}$が与えられたとき、これを係数に持つ関数$f(x), g(x)$をつくって積を取れば、

\begin{equation*} f(x) . g(x) = \left( \sum_{i=0}^{n-1}a_i x^i \right)\left( \sum_{i=0}^{n-1}b_i x^i \right) = \sum_{k=0}^{n-1} \left( \sum_{i+j=k} a_i b_j \right) x^k \end{equation*} 

となるが、この式の右辺の各$x^k$の係数に、数列$\{a_i\}$ と $\{b_j \}$ とのすべての$i+j=k$についての convolution が出現する。つまり、$h = f * g$とすると、$c_k = \sum_{i+j=k}a_i b_j$ となるので、Convolutionの目的は$h$の係数$c_k$を全ての$k$について求めることにある。

ところで、$h = f*g$ を離散フーリエ変換で飛ばすと、$\hat{h}$の係数は

\begin{equation*} h(\zeta_n^i) = f(\zeta_n^i) \times g(\zeta_n^i) \end{equation*}

である。つまり、離散フーリエ変換で飛ばした先では、関数の積は係数の積に置き換わるので、$\hat{h}$の係数は$\hat{f}$ および $\hat{g}$ の係数から$\mathcal{O}(n)$で求まる。

そこで、本来求めたかった$h$の係数$c_i$を、一旦フーリエ変換した先の係数の積で$\hat{h}$を求めておいてから、これをフーリエ逆変換で戻してやれば(逆変換も$\mathcal{O}(n \log n)$で完了するので)、$c_i$の全てを$\mathcal{O}(n\log n)$で求めることができる。

\begin{array}{cccc} f,g &\xrightarrow{\mathcal{O}(n \log n)} & \hat{f},\hat{g} & \\ \downarrow & &\downarrow & \mathcal{O}(n) \\ \check{\hat{h}} = nh &\xleftarrow{\mathcal{O}(n \log n)} & \hat{h} & \end{array} 

 

mod p での Convolution

 以上の操作を、素数$p$に対する体$F=\mathbb{Z}/p\mathbb{Z}$で行う。このときポイントになるのは、なるべく大きい2のべき乗$n$について、1の原始$n$乗根が存在することである。逆に言えば、そのような$n$が存在しない場合には、(簡単な)FFTを使って$\mathbb{Z}/p\mathbb{Z}$上のconvolutionを高速に行うことは難しくなる。では、このような$n$が取れる条件は何か。

 素数$p$に対する$\mathbb{Z}/p\mathbb{Z}$の乗法群$(\mathbb{Z}/p\mathbb{Z})^\times$は、位数が$p-1$の群になる。よって、すべての$a \in \mathbb{Z}/p\mathbb{Z}$ について、$a \ne 0$ならば $a^{p-1} = 1 \mod p$。特に、$(\mathbb{Z}/p\mathbb{Z})^\times$は位数が$p-1$の巡回群にもなるので、必ず位数が$p-1$の元が存在する。それを$r$とすると、$1 \le i < p-1$について$r^i \ne 1 \mod p$ かつ $g^{p-1} = 1 \mod p$である。また、0でない全ての元の位数は$p-1$の約数になっていることもわかるので、特に、1の原始根があるとすれば、その位数は$p-1$の約数である。

さて、競プロで素数の例としてよく使われるのは、$10^9+7$や$998244353$であるが、それぞれについて乗法群の位数を素因数分解すると、前者は$10^9+7 - 1 = 2\times (5 \times 10^8 + 3)$、後者は $998244353 - 1 = 2^{23}\times 7\times 17$である。つまり、前者は乗法群の位数がまた巨大な素数を素因数としてもっており、他方後者は非常に大きい2のべき乗をもっている。上で述べたように、FFTは大きな2のべき乗を$n$として用いたいので、$p = 10^9+7$ では難しく、$p = 998244353$はふさわしいことがわかる。実際、高難易度の大会でない限りは、convolutionの問題では後者の$p$が用いられることがほとんどであろう。

以下、$p=998244353$で考える。次はこの$p$を用いた$\mathbb{Z}/p\mathbb{Z}$にて1の原始根を探す必要があるのだが、実は$3$は1の原始$p-1$乗根の一つであることがわかる。これは、$3^i \mod p$を順に全部計算してみればよい。$3$の位数は$p-1 = 2^{23}\times 7 \times 17 = 2^{23}\times 119$なので、$3^{119} \mod p = 15311432$は1の原始$2^{23}$乗根の一つである。

あとは、$15311432$を2乗していくことで、1の原始$2^{22}$乗根$267099868$、1の原始$2^{21}$乗根$733596141$、・・・を得られる。

以上でFFTを使ったConvolutionを実装する手がかりが全て得られた。


参考にしたサイト

 NTT(数論変換)のやさしい解説 / Senの競技プログラミング備忘録

競技プログラミング だれでもわかる FFT/NTT 入門 (pdf) / monkukui 

 FFTとNTTとFMTと / peria (Qiita)

 FFT(高速フーリエ変換)を完全に理解する話 ageprocpp (Qiita)