# Radix-4 DIT, Radix-4 DIF, Radix-2 DIT, Radix-2 DIF FFTs
# John Bryan, 2017
# Python 2.7.3

import math
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
# Print whole arrays without "..." summarisation, 3 decimals, and no
# scientific notation for tiny values.  NumPy rejects threshold=np.nan
# (it raises ValueError); sys.maxsize is the documented way to disable
# summarisation.
np.set_printoptions(threshold=sys.maxsize, precision=3, suppress=True)
# NOTE(review): this silences *every* warning process-wide (including
# NumPy's ComplexWarning and deprecation warnings); consider narrowing it.
warnings.filterwarnings("ignore")


def swap(x, i, j):
    """Exchange the items at positions ``i`` and ``j`` of ``x`` in place.

    Works on any container supporting item get/set (lists, 1-D arrays).
    Always returns None.
    """
    x[i], x[j] = x[j], x[i]


def digitreversal(g, b, p, n):
    """Permute ``g`` in place into base-``b`` digit-reversed order.

    Element ``k`` is exchanged with the element whose index has the same
    ``p`` base-``b`` digits written in reverse order.  The DIT transforms
    in this file expect their input in this order, and the DIF transforms
    leave their output in it.

    Parameters
    ----------
    g : sequence supporting item assignment (typically a 1-D ndarray)
        Data of length ``n``; modified in place.
    b : int
        Digit radix (this file uses 2 and 4).
    p : int
        Number of base-``b`` digits, so ``n == b**p``.
    n : int
        Length of ``g``.

    Returns
    -------
    ``g`` itself, after the permutation.
    """
    if p < 2:
        # With zero or one digit every index is its own reversal.
        return g
    # Seed-table size b**(p//2).  For odd p the middle digit is not part of
    # the table; it is handled by the z loop in algorithm 1.
    if p % 2 == 0:
        n1 = int(round(np.sqrt(n)))
    else:
        n1 = int(round(np.sqrt(n // b)))

    # Algorithm 2: seed table, r[j] = digit reversal of j as a p-digit number.
    r = np.zeros(n1, dtype=int)
    r[1] = n // b
    for j in range(2, b):
        r[j] = r[j - 1] + r[1]
    # This pass needs r[1..b-1] fully filled in, so it runs after the loop
    # above (not nested inside it).
    for i in range(1, n1 // b):
        r[b * i] = r[i] // b
        for j in range(1, b):
            r[b * i + j] = r[b * i] + r[j]

    # Algorithm 1: i supplies the low digits and r[j] the reversed high
    # digits; swapping i+r[j] with j+r[i] exchanges each index with its
    # reversal exactly once.
    for i in range(n1 - 1):
        for j in range(i + 1, n1):
            u = i + r[j]
            v = j + r[i]
            g[u], g[v] = g[v], g[u]
            if p % 2 == 1:
                # Repeat the same swap for every nonzero middle digit z.
                for z in range(1, b):
                    uu = u + z * n1
                    vv = v + z * n1
                    g[uu], g[vv] = g[vv], g[uu]
    return g


def dif_fft4(x, t, s):
    """Radix-4 decimation-in-frequency FFT of length ``4**s``, in place.

    The result is left in base-4 digit-reversed order; apply
    ``digitreversal(x, 4, s, 4**s)`` to bring it to natural order.

    Parameters
    ----------
    x : sequence of length ``4**s``
        Input samples.  A complex ndarray (or a list) is overwritten in
        place.  A real or integer ndarray is first copied to complex,
        because the butterflies produce complex values and storing them
        into a real array would keep only their real parts.
    t : 1-D complex array
        Twiddle table ``t[k] = exp(-2j*pi*k/4**s)`` for
        ``k = 0 .. 3*(4**s//4 - 1)``; no larger index is read.
    s : int
        Number of radix-4 stages (log4 of the length).

    Returns
    -------
    The transformed sequence: ``x`` itself, unless it had to be promoted
    to complex.
    """
    if isinstance(x, np.ndarray) and not np.iscomplexobj(x):
        x = x.astype(complex)
    span = 4 ** s      # length of each sub-transform in the current stage
    quarter = span // 4  # butterfly leg spacing == butterflies per block
    tss = 1            # twiddle stride: W_span**k == t[k*tss]
    block = 1          # number of independent sub-transforms this stage
    for w in range(s):
        for h in range(block):
            base = h * span
            for k in range(quarter):
                a = base + k
                b = a + quarter
                c = b + quarter
                d = c + quarter
                apc = x[a] + x[c]
                bpd = x[b] + x[d]
                amc = x[a] - x[c]
                bmd = x[b] - x[d]
                x[a] = apc + bpd
                if k == 0:
                    # W**0 == 1, so the twiddle multiplies are skipped.
                    x[b] = amc - (1j * bmd)
                    x[c] = apc - bpd
                    x[d] = amc + (1j * bmd)
                else:
                    x[b] = (amc - (1j * bmd)) * t[k * tss]
                    x[c] = (apc - bpd) * t[2 * k * tss]
                    x[d] = (amc + (1j * bmd)) * t[3 * k * tss]
        # The next stage works on 4x as many blocks, each a quarter as long.
        block *= 4
        span //= 4
        quarter //= 4
        tss *= 4
    return x



def fft4(x, t, s):
    """Radix-4 decimation-in-time FFT of length ``4**s``, in place.

    The input must already be in base-4 digit-reversed order (see
    ``digitreversal(x, 4, s, 4**s)``); the output is in natural order.

    Parameters
    ----------
    x : 1-D ndarray of length ``4**s``
        Digit-reversed input.  A complex array is overwritten in place; a
        real or integer array is first copied to complex, because storing
        complex butterfly outputs into a real array would keep only their
        real parts.
    t : 1-D complex array
        Twiddle table ``t[k] = exp(-2j*pi*k/4**s)`` for
        ``k = 0 .. 3*(4**s//4 - 1)``.
    s : int
        Number of radix-4 stages (log4 of the length).

    Returns
    -------
    The transformed array: ``x`` itself, unless it had to be promoted to
    complex.
    """
    if not np.iscomplexobj(x):
        x = x.astype(complex)
    span = 4            # length of the transforms built by this stage
    quarter = 1         # butterfly leg spacing == butterflies per block
    tss = 4 ** s // 4   # integer twiddle stride: W_span**k == t[k*tss]
    block = x.size // 4  # number of span-length blocks this stage
    for w in range(s):
        for z in range(block):
            base = z * span
            for k in range(quarter):
                a = base + k
                b = a + quarter
                c = b + quarter
                d = c + quarter
                if k == 0:
                    # W**0 == 1, so the twiddle multiplies are skipped.
                    xbr1 = x[b]
                    xcr2 = x[c]
                    xdr3 = x[d]
                else:
                    xbr1 = x[b] * t[k * tss]
                    xcr2 = x[c] * t[2 * k * tss]
                    xdr3 = x[d] * t[3 * k * tss]
                e = x[a] + xcr2
                f = x[a] - xcr2
                g = xbr1 + xdr3
                h = xbr1 - xdr3
                jh = 1j * h
                x[a] = e + g
                x[b] = f - jh
                x[c] = e - g
                x[d] = f + jh
        # The next stage merges four blocks into one 4x longer block.
        block //= 4
        span *= 4
        quarter *= 4
        tss //= 4
    return x


def dif_fft0(x, twiddle, p):
    """Radix-2 decimation-in-frequency FFT over ``p`` passes, in place.

    The output is left in bit-reversed order; apply
    ``digitreversal(x, 2, p, x.size)`` to bring it to natural order.

    Parameters
    ----------
    x : 1-D ndarray of length ``2**p``
        Input samples.  A complex array is overwritten in place; a real or
        integer array is first copied to complex, because storing complex
        results into a real array would keep only their real parts.
    twiddle : 1-D complex array
        ``twiddle[k] = exp(-2j*pi*k/x.size)`` for ``k = 0 .. x.size//2 - 1``.
    p : int
        Number of radix-2 passes (log2 of the length).

    Returns
    -------
    The transformed array: ``x`` itself, unless it had to be promoted to
    complex.
    """
    if not np.iscomplexobj(x):
        x = x.astype(complex)
    Bp = 1                  # number of blocks in this pass
    Np = x.size             # length of each block in this pass
    twiddle_step_size = 1
    for P in range(p):
        Npp = Np // 2       # half-block: distance between butterfly legs
        for blk in range(Bp):
            baseE = blk * Np
            baseO = baseE + Npp
            for n in range(Npp):
                top = x[baseE + n]
                bot = x[baseO + n]
                x[baseE + n] = top + bot
                if n == 0:
                    # twiddle[0] == 1, so the multiply is skipped.
                    x[baseO + n] = top - bot
                else:
                    x[baseO + n] = (top - bot) * twiddle[n * twiddle_step_size]
        Bp *= 2
        Np //= 2
        twiddle_step_size *= 2
    return x


def fft2(x, twiddle, s0):
    """Radix-2 decimation-in-time FFT over ``s0`` passes, in place.

    The input must already be in bit-reversed order (see
    ``digitreversal(x, 2, s0, x.size)``); the output is in natural order.

    Parameters
    ----------
    x : 1-D ndarray of length ``2**s0``
        Bit-reversed input.  A complex array is overwritten in place; a
        real or integer array is first copied to complex, because storing
        complex results into a real array would keep only their real parts.
    twiddle : 1-D complex array
        ``twiddle[k] = exp(-2j*pi*k/x.size)`` for ``k = 0 .. x.size//2 - 1``.
    s0 : int
        Number of radix-2 passes (log2 of the length).

    Returns
    -------
    The transformed array: ``x`` itself, unless it had to be promoted to
    complex.
    """
    if not np.iscomplexobj(x):
        x = x.astype(complex)
    N = x.size
    Bp = N // 2             # number of blocks in this pass
    Np = 2                  # length of each block in this pass
    twiddle_step_size = N // 2
    for P in range(s0):
        Npp = Np // 2       # half-block: distance between butterfly legs
        for blk in range(Bp):
            baseT = blk * Np
            baseB = baseT + Npp
            for n in range(Npp):
                if n == 0:
                    # twiddle[0] == 1, so the multiply is skipped.
                    bot = x[baseB + n]
                else:
                    bot = x[baseB + n] * twiddle[n * twiddle_step_size]
                top = x[baseT + n]
                x[baseT + n] = top + bot
                x[baseB + n] = top - bot
        Bp //= 2
        Np *= 2
        twiddle_step_size //= 2
    return x


def testr4dif():
    """Check and time dif_fft4 on random sequences of length 4**2 .. 4**7.

    Each transform is put back into natural order with digitreversal and
    compared with numpy.fft.fft.  Only the dif_fft4 call itself is timed.

    Returns
    -------
    ndarray of 6 wall-clock times in seconds, one per length.

    Raises
    ------
    AssertionError
        If any result disagrees with numpy.fft.fft.
    """
    b = 4
    r4diftimes = np.zeros(6)
    for i, s in enumerate(range(2, 8)):
        N = 4 ** s
        # 2N random reals reinterpreted as N complex samples
        x = np.random.rand(2 * N).view(np.complex128)
        Xpy = np.fft.fft(x)
        # twiddles exp(-2j*pi*k/N) for k = 0 .. 3*(N/4 - 1); an integer
        # arange replaces linspace, whose `num` must be an int
        k = np.arange(3 * (N // 4 - 1) + 1)
        t = np.exp(-2j * np.pi * k / N)
        t0 = time.time()
        x = dif_fft4(x, t, s)
        r4diftimes[i] = time.time() - t0
        x = digitreversal(x, b, s, N)
        # explicit raise: a bare assert would be skipped under python -O
        if not np.allclose(x, Xpy):
            raise AssertionError("radix-4 dif FFT mismatch for N=%d" % N)
    print ("All radix-4 dif results were correct.")
    return r4diftimes


def testr4():
    """Check and time fft4 (radix-4 DIT) on random sequences of length 4**2 .. 4**7.

    The input is digit-reversed (untimed) before the transform, and the
    natural-order output is compared with numpy.fft.fft.  Only the fft4
    call itself is timed.

    Returns
    -------
    ndarray of 6 wall-clock times in seconds, one per length.

    Raises
    ------
    AssertionError
        If any result disagrees with numpy.fft.fft.
    """
    b = 4
    r4times = np.zeros(6)
    for i, s in enumerate(range(2, 8)):
        N = 4 ** s
        # 2N random reals reinterpreted as N complex samples
        x = np.random.rand(2 * N).view(np.complex128)
        Xpy = np.fft.fft(x)
        x = digitreversal(x, b, s, N)
        # twiddles exp(-2j*pi*k/N) for k = 0 .. 3*(N/4 - 1); an integer
        # arange replaces linspace, whose `num` must be an int
        k = np.arange(3 * (N // 4 - 1) + 1)
        t = np.exp(-2j * np.pi * k / N)
        t0 = time.time()
        x = fft4(x, t, s)
        r4times[i] = time.time() - t0
        # explicit raise: a bare assert would be skipped under python -O
        if not np.allclose(x, Xpy):
            raise AssertionError("radix-4 dit FFT mismatch for N=%d" % N)
    print ("All radix-4 dit results were correct.")
    return r4times


def testr2dif():
    """Check and time dif_fft0 (radix-2 DIF) on random sequences of length 4**2 .. 4**7.

    Each transform is put back into natural order with digitreversal and
    compared with numpy.fft.fft.  Only the dif_fft0 call itself is timed.

    Returns
    -------
    ndarray of 6 wall-clock times in seconds, one per length.

    Raises
    ------
    AssertionError
        If any result disagrees with numpy.fft.fft.
    """
    b = 2
    r2diftimes = np.zeros(6)
    for i, r in enumerate(range(2, 8)):
        N = 4 ** r      # == 2**(2r)
        p = 2 * r       # number of radix-2 passes
        # 2N random reals reinterpreted as N complex samples
        # (np.complex_ was removed in NumPy 2.0; complex128 is the same type)
        c0 = np.random.rand(2 * N).view(np.complex128)
        Gpy = np.fft.fft(c0)
        # twiddles exp(-2j*pi*k/N) for k = 0 .. N/2 - 1
        k = np.arange(N // 2)
        t = np.exp(-2j * np.pi * k / N)
        t1 = time.time()
        Gg = dif_fft0(c0, t, p)
        r2diftimes[i] = time.time() - t1
        Zo = digitreversal(Gg, b, p, N)
        # explicit raise: a bare assert would be skipped under python -O
        if not np.allclose(Zo, Gpy):
            raise AssertionError("radix-2 dif FFT mismatch for N=%d" % N)
    print ("All radix-2 dif results were correct.")
    return r2diftimes


def testr2():
    """Check and time fft2 (radix-2 DIT) on random sequences of length 4**2 .. 4**7.

    The input is bit-reversed (untimed) before the transform, and the
    natural-order output is compared with numpy.fft.fft.  Only the fft2
    call itself is timed.

    Returns
    -------
    ndarray of 6 wall-clock times in seconds, one per length.

    Raises
    ------
    AssertionError
        If any result disagrees with numpy.fft.fft.
    """
    b = 2
    r2times = np.zeros(6)
    for i, r in enumerate(range(2, 8)):
        N = 4 ** r      # == 2**(2r)
        p = 2 * r       # number of radix-2 passes
        # 2N random reals reinterpreted as N complex samples
        # (np.complex_ was removed in NumPy 2.0; complex128 is the same type)
        c0 = np.random.rand(2 * N).view(np.complex128)
        Gpy = np.fft.fft(c0)
        # twiddles exp(-2j*pi*k/N) for k = 0 .. N/2 - 1
        k = np.arange(N // 2)
        t = np.exp(-2j * np.pi * k / N)
        Zo = digitreversal(c0, b, p, N)
        t1 = time.time()
        G = fft2(Zo, t, p)
        r2times[i] = time.time() - t1
        # explicit raise: a bare assert would be skipped under python -O
        if not np.allclose(G, Gpy):
            raise AssertionError("radix-2 dit FFT mismatch for N=%d" % N)
    print ("All radix-2 dit results were correct.")
    return r2times


def plot_times(tr2, trdif2, tr4, trdif4):
    """Plot the four timing series against sequence length on log-log axes.

    Each argument holds 6 times, for lengths 4**2 .. 4**7 (the lengths
    the test* functions use).  The x axis is log base 4 and the y axis is
    log base 10.  The figure is saved to 'tvl2.png' and then shown.
    """
    u = 4 ** np.arange(2, 8)   # sequence lengths 16 .. 16384
    plt.figure(figsize=(7, 5))
    plt.rc("font", size=9)
    series = (
        (trdif2, 'o', 'red', 'radix-2 DIF'),
        (tr2, '^', 'green', 'radix-2 DIT'),
        (trdif4, 'D', 'blue', 'radix-4 DIF'),
        (tr4, 's', 'black', 'radix-4 DIT'),
    )
    for times, marker, color, label in series:
        plt.plot(u, times, marker, ms=5, markerfacecolor="None",
                 markeredgecolor=color, markeredgewidth=1, label=label)
    ax = plt.gca()
    # loglog(basex=..., basey=...) was removed in matplotlib 3.5; set the
    # scales directly, falling back to the pre-3.3 keyword spelling.
    try:
        ax.set_xscale('log', base=4)
        ax.set_yscale('log', base=10)
    except (TypeError, ValueError):
        ax.set_xscale('log', basex=4)
        ax.set_yscale('log', basey=10)
    plt.legend(loc=2)
    plt.grid()
    plt.xlim([12, 18500])
    plt.ylim([.00004, 1])
    plt.ylabel("time (seconds)")
    plt.xlabel("sequence length")
    plt.title("Time vs Length")
    plt.savefig('tvl2.png', bbox_inches='tight')
    plt.show()
    return None


def test():
    """Run the four FFT self-checks in turn, then plot their timings."""
    results = [check() for check in (testr4dif, testr4, testr2dif, testr2)]
    trdif4, tr4, trdif2, tr2 = results
    plot_times(tr2, trdif2, tr4, trdif4)


if __name__ == "__main__":
    # Run the self-test/benchmark only when executed as a script, so that
    # importing this module does not run it and open a plot window.
    test()