FFT 学习笔记

发布于 2018-09-05  445 次阅读


快速傅里叶变换 (Fast Fourier Transform, FFT) 是一种 $ O(n\ log_2\ n) $ 的时间复杂度内完成离散傅里叶变换 (DFT) 的算法,OI 中通常用来优化多项式乘法。

常见题目类型链接
给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x) $。求出它们的卷积。

FFT 就可以在 $ n\ log_2\ n $ 的时间复杂度内解决该类问题。

先来学习一些 FFT 前备知识。

Q:我是初中的啊…看不懂怎么办?

A:本人尽量写到初中生理解水平...因为我就是一个没学过数学奥赛的初中生…大可放心食用

多项式

系数表示法

设 $F(x)$ 表示一个 $n-1$ 次的多项式,则 $$ F(x) = \Sigma_{i=0}^{n-1} a_i * x_i $$

利用系数直接运算多项式乘法的时间复杂度是 $ O(n^2) $ 。(即每个系数都要枚举相乘)

点值表示法

将这个多项式 $F(x) $ 看成一个函数,就是说代入 $ n $ 个互不相同的 $x$ ,会得到 $n$ 个不同的取值 $y$。

那么这个多项式就被这 $n$ 个点 $(x_1,y_1),(x_2,y_2) \cdots (x_n,y_n) $ 唯一确定。

当然这样计算也是 $O(n^2)$ 的。

当然这些都需要优化。对于第一种好像不会优化,于是我们考虑第二种表示法的优化。


我们为了照顾初中选手,先介绍一些简单的东西。

向量(矢量)

有方向也有大小的量,与标量(无方向)相对。

在集合中我们用有箭头的直线来表示。

圆的弧度制

我们之前学的角度制的定义是将圆 $ 360 $ 等分。

弧度制的定义是等于半径长的圆弧对应的圆心角叫做 1 弧度的角。用符号 rad 表示,读作弧度,用 rad 作为单位来度量角的制度叫做弧度制(字面意思)

显然有:$ 1 ^{ ∘ } = \frac{\pi}{180}rad $

因为设这个圆的半径为 r ,我们发现周长是 $2\pi r$ ,那么在这个周长上取长度为 r 的弧,圆心角显然是 $ ( \frac{r}{2\pi r} * 360^∘ )rad= \frac{\pi}{180}rad $

平行四边形定则

有一个平行四边形 $ABCD $ ,那么有 $ AB + AD = AC$。

可以表示成向量加法。

等比数列求和

设一个等比数列中:首项为 $ a_1 $,公比是 $q$。

那么前 $n$ 项的和 $S = \frac{a_1(1-q^n)}{1-q}$

证明可以通过错位相减来实现。

复数

定义

设实数 $a,b$,定义虚数单位 $ i $,有 $i^2 = -1 $ ,将形如 $a + bi$ 的数叫做复数。复数域是目前已知最大的域。

在复平面里(x 轴表示实数,y 轴表示虚数)。从 $ O(0,0) $ 到 $A(a,b)$ 的向量能表示复数 $a + bi$。

模长:$OA$ 的长度,即 $\sqrt{a^2 + b^2}$

幅角:假设以逆时针为正方向,从 $x$ 正半轴旋转至向量的转角的有向角叫做幅角。

运算法则

加法

向量相加,满足平行四边形定则。即:$ (a + bi) + (c + di) = (a + c) + (b+d)i $

复数的加法是封闭的。

乘法

向量的乘法几何定义:复数相乘,模长相乘,幅角相加。

代数定义:

$$(a + bi) * (c + di) =ac + adi + bci + bdi^2 = ac + adi + bci - bd = (ac-bd) + (bc + ad)i$$

单位根

下文若不做特别的声明,均默认 $n$ 为 $2$ 的整数次幂。

在复平面上以原点为圆心,1 为半径作出的圆叫做单位圆。以圆点为起点,圆的 n 等分点为终点作出 n 个向量。设幅角为正且最小的向量对应的复数为 $\omega_n$ ,称为 $n$ 次单位根。

其余的 $n-1$ 个复数显然是 $ \omega_n^2,\omega_n^3 \cdots \omega_n^n$

注意到有 $\omega_n^0 = \omega_n^n = 1$ (对应方向为 x 轴正半轴的向量)

单位根的幅角是周角的 $\frac{1}{n}$ (显然 $n$ 等分)。

一些性质:

性质1

$$ \omega_n^k = cos\ k * \frac{2\pi}{n} + i\ sin\ k * \frac{2\pi}{n} \Rightarrow \omega_{2n}^{2k} = \omega_n^k$$

证明:

$$\omega_{2n}^{2k} = \cos 2k * \frac{2 \pi }{2n} + i \sin 2k* \frac{2 \pi }{2n} = \cos k* \frac{2 \pi }{n} + i \sin k * \frac{2\pi}{n} =\omega_n^k$$

性质2

$ \omega_n^{k+\frac{n}{2}} = -\omega_n^k $

证明:

$$ w_n^{ \frac{n}{2} } = \cos \frac{n}{2} * \frac{2 \pi }{n} + i \sin \frac{n}{2} * \frac{2 \pi }{n} = \cos \pi + i \sin \pi = -1$$

我们终于把预备知识讲完了,那么接下来开始正题。

快速傅里叶变换

朴素的求一个多项式的点值表达法的复杂度是 $ O(n^2) $,我们先来探究一些性质。

设多项式 $F(x)$ 的系数为 $$ {a_0,a_{1},a_{2},\cdots,a_{n-1}}$$

那么

$ F(x) = \Sigma_{i=0}^{n-1} a_i * x^i $

将其按照下标分成两个多项式:

$$ F_1 (x) = \Sigma_{i=0}^{ \frac{n}{2} - 1 } a_{i * 2} x^{i * 2} $$
$$F_2(x) = \Sigma_{i=0}^{ \frac{n}{2} - 1 }a_{i * 2 + 1 } x^{i * 2 + 1 }$$

那么可以得到:

$$ F(x) = F_1(x^2) + xF_2(x^2) $$

我们代入 $\omega_n^k(k < \frac{n}{2})$ 得到

$$ F( \omega^k_n ) = F_1( \omega^{2k}_ n ) + \omega^k_n F_2( \omega_n^{2k} ) $$

$$ = F_1(\omega^k_{ \frac{n}{2} }) + \omega^k_nF_2( \omega_{ \frac{n}{2}}^{k})$$

同理,代入 $w_n^{k+\frac{n}{2}}$得到

$$ F( \omega_n^{k + \frac{n}{2} } ) = F_1({ \omega^{2k+n}_ n }) + \omega_n^{k+ \frac{n}{2} }F_2(\omega_n^{2k+n}) $$
$$ = F_1(\omega_n^{2k}) - \omega_n^kF_2(\omega_n^{2k} * \omega_n^n) $$
$$= F_1(\omega_n^{2k}) - \omega_n^kF_2(\omega_n^{2k}) $$
$$ = F_1(\omega^k_{\frac{n}{2}}) - \omega^k_nF_2(\omega_{\frac{n}{2}}^{k})$$

发现这两个式子仅由一个常数项不同(-号)

所以我们只需要算出第一个式子,第二个式子就可以 $O(1)$ 求。

所以说我们可以递归二分去计算这个东西,遇到常数项就返回。

时间复杂度:$ T(n) = 2T(n/2) + O(n) = O(n\ log_2\ n)$

快速傅里叶逆变换

还没有结束...

我们要考虑怎么输出答案啊。

一般来说没有人会让你输出点值表示法的,所以我们需要用逆变换来换回来。

首先我们可以将 FFT 后的结果看做一个向量 $$ (y_0,y_1,y_2,\cdots,y_{n-1}) $$。

我们设另一个向量 $$ (c_0,c_1,c_2,\cdots,c_{n-1}) $$ 令其满足

$$ c_k=\Sigma_{i=0}^{n-1}y_i(\omega^{-k}_n)^i $$

浅显易懂的讲就是将结果看做一个多项式,求出这个多项式的 FFT 的结果。

推一波式子:

$$ c_k=\Sigma_{i=0}^{n-1}y_i(\omega^{-k}_ n)^i $$
$$=\Sigma_{i=0}^{n-1}(\Sigma_{j=0}^{n-1}a_j(\omega^i_n)^j)(\omega_n^{-k})^i $$
$$=\Sigma_{i=0}^{n-1}(\Sigma_{j=0}^{n-1}a_j(\omega_n^j)^i)(\omega_n^{-k})^i $$
$$= \Sigma_{i=0}^{n-1}(\Sigma_{j=0}^{n-1}a_j(\omega_n^j)^i(\omega_n^{-k})^i $$
$$= \Sigma_{i=0}^{n-1}\Sigma_{j=0}^{n-1}a_j(\omega_n^j)^i(\omega_n^{-k})^i $$
$$=\Sigma_{i=0}^{n-1}\Sigma_{j=0}^{n-1}a_j(\omega_n^{j-k})^i $$
$$ = \Sigma_{j=0}^{n-1}a_j \Sigma_{i=0}^{n-1}(\omega_n^{j-k})^i $$

设 $ S(x) = \Sigma_{i=0}^{n-1} x^i$

代入 $ \omega_n^k $ 得:

$$ S(\omega_n^k) = \Sigma_{i=0}^{n-1}w_n^i $$

直接套用等比数列求和(错位相减法),得:

$ S(\omega_n^k) = \frac{1-1}{\omega_n^k-1} $

不难看出分子为 0 ,分母不为 0 。

分类讨论一下:

$ K \neq 0$ :$ S(\omega_n^k) = 0 $

$ K = 0$:$S(\omega_n^0) = n$

我们继续考虑式子 $$c_k = \Sigma_{j=0}^{n-1}a_j\Sigma_{i=0}^{n-1}(\omega_n^{j-k})^i $$

同理:

$ j \neq k \to (j-k) \neq 0 $:$c_k = 0$

$ j = k \to (j - k) = 0$:$c_k = n $

这样我们就得到点值与系数之间的表示方法:

$$ c_k = na_k \Longleftrightarrow a_k = \frac{c_k}{n} $$

递归实现的代码

const double Pi=acos(-1.0);
struct complex
{
    double x,y;
    complex (double xx=0,double yy=0){x=xx,y=yy;}
}a[MAXN],b[MAXN];
complex operator + (complex a,complex b){ return complex(a.x+b.x , a.y+b.y);}
complex operator - (complex a,complex b){ return complex(a.x-b.x , a.y-b.y);}
complex operator * (complex a,complex b){ return complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);} 
void fast_fast_tle(int limit,complex *a,int type)
{
    if(limit==1) return ;// 只有一个常数项
    complex a1[limit>>1],a2[limit>>1];
    for(int i=0;i<=limit;i+=2)// 根据下标的奇偶性分类
        a1[i>>1]=a[i],a2[i>>1]=a[i+1];
    fast_fast_tle(limit>>1,a1,type);
    fast_fast_tle(limit>>1,a2,type);
    complex Wn=complex(cos(2.0*Pi/limit) , type*sin(2.0*Pi/limit)),w=complex(1,0);
    // Wn为单位根,w表示幂
    for(int i=0;i<(limit>>1);i++,w=w*Wn)
        a[i]=a1[i]+w*a2[i],
        a[i+(limit>>1)]=a1[i]-w*a2[i];// O(1)得到另一部分 
}
// 自己懒得写了找了一份别人的代码233以后会补上的

蝴蝶操作

稍微优化一下小常数。

发现 $w*a2[i]$ 被计算了两次,开一个 $t$ 来存储仅需计算一次。

迭代实现

递归的效率过于低下,我们需要更快的解决这个问题。

我们观察原序列的下标和分类后的序列下标。

发现二进制翻转可以就直接解决这个问题。

这样我们就可以 $O(n)$ 的利用某种操作来求得序列,然后依照规律合并即可。

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define R register
#define LL long long
#define U unsigned
#define FOR(i,a,b) for(int i = a;i <= b;++i)
#define RFOR(i,a,b) for(int i = a;i >= b;--i)
#define CLR(i,a) memset(i,a,sizeof(i))
#define BR printf("--------------------\n")
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

namespace fastIO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    #define ll long long
    bool IOerror=0;
    inline char nc(){
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend){
            p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1){IOerror=1;return -1;}
        }
        return *p1++;
    }
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
    inline void read(int &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(ll &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    #undef ll
    #undef OUT_SIZE
    #undef BUF_SIZE
};
using namespace fastIO;
#undef R

const int MAXN = 10000000 + 5;

const double PI = acos(-1.0);
struct complex{
    double x,y;
    complex(double x=0,double y=0) : x(x),y(y) {}
}a[MAXN],b[MAXN];

complex operator + (const complex &a,const complex &b){
    return complex(a.x + b.x,a.y + b.y);
}

complex operator - (const complex &a,const complex &b){
    return complex(a.x-b.x,a.y-b.y);
}

complex operator * (const complex &a,const complex &b){
    return complex(a.x * b.x - a.y * b.y,a.x*b.y+a.y*b.x);
}

int N,M;
int f[MAXN],g[MAXN],r[MAXN];
int limit=1,len;

inline void fft(complex *A,int opt){
    FOR(i,0,limit){
        if(i < r[i]) std::swap(A[i],A[r[i]]);
    }
    for(int mid = 1;mid < limit;mid <<= 1){
        complex W(cos(PI/mid),opt*sin(PI/mid));
        for(int j=0,R = (mid << 1);j < limit;j += R){
            complex w(1,0);
            for(int k = 0;k < mid;k++,w=w*W){
                complex x = A[j+k],y=w*A[j+mid+k];
                A[j+k] = x+y;
                A[j+mid+k] = x-y;
            }
        }
    }
}

int main(){
    read(N);read(M);
    FOR(i,0,N){
        int x;read(x);a[i].x = x;
    }
    FOR(i,0,M){
        int x;read(x);b[i].x = x;
    }
    while(limit <= N + M){
        limit <<= 1;len++;
    }
    FOR(i,0,limit){
        r[i] = (r[i>>1]>>1)|((i&1)<<(len-1));
    }
    fft(a,1);
    fft(b,1);
    FOR(i,0,limit) a[i] = a[i]*b[i];
    fft(a,-1);
    FOR(i,0,N+M) printf("%d%c",(int)(a[i].x/limit+0.5),(i == N + M) ? '\n': ' ');
    return 0;
}

一个OIer。