# Radix-4, N=64 decimation-in-time (DIT) FFT using a reduced twiddle-factor table of 9 entries
# John Bryan, 2017
# Python 2.7.3
import math
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
# Print whole arrays without "..." summarization, with 3 decimals, in fixed-point
# notation. numpy rejects threshold=nan; sys.maxsize is the documented way to
# disable summarization.
np.set_printoptions(threshold=sys.maxsize, precision=3, suppress=True)


def tobinary(dec):
    """Return the base-2 digits of the non-negative integer *dec*, MSB first.

    No leading zeros are produced, so ``tobinary(0)`` returns ``''``; callers
    pad with ``str.rjust`` to the width they need.

    Raises ValueError for a negative argument.
    """
    value = int(dec)
    if value < 0:
        raise ValueError("tobinary() requires a non-negative integer")
    # format(..., 'b') behaves the same on Python 2.7 and 3. The former
    # '/'-based digit loop relied on Python 2 integer division and produced
    # float "digits" on Python 3.
    return format(value, 'b') if value else ''


def todec(c):
    """Return the integer encoded by the bits c[0], c[1], c[2], c[3].

    c[0] is the most significant bit; entries past index 3 are not read.
    """
    value = 0
    for pos in range(4):
        # Horner's rule: shift the running value left one bit, add the next.
        value = value * 2 + c[pos]
    return value


def adder(a,q0):
    """Model a 3-bit conditional complement-and-increment circuit.

    Each of the bits a[0..2] (MSB first) is XOR-ed with q0. Then q0 is added
    through a ripple-carry chain into a 4-bit result, which ``todec`` converts
    to an integer. For 0/1 inputs, the result is the 3-bit value of *a* when
    q0 == 0 and ``8 - value`` when q0 == 1.
    """
    flipped = np.array([a[pos] ^ q0 for pos in range(3)], dtype=int)
    total = np.zeros(4, dtype=int)
    carry = q0
    # Ripple the carry from the least significant bit upward.
    for pos in range(3, 0, -1):
        bit = flipped[pos - 1]
        total[pos] = bit ^ carry
        carry = bit & carry
    total[0] = carry  # the final carry-out becomes the MSB
    return todec(total)


def swap(x,i,j):
    """Exchange the items at positions *i* and *j* of *x* in place; return None."""
    x[i], x[j] = x[j], x[i]


def digitreversal(g,b,p,n):
    """Permute *g* in place into base-*b* digit-reversed order and return it.

    After the call, element k holds the item that was at the index whose p
    base-b digits are those of k in reverse order.

    Parameters
    ----------
    g : mutable 1-D sequence (e.g. numpy array) of length n; modified in place.
    b : radix of the digits.
    p : number of base-b digits; n == b**p is assumed (not checked).
    n : length of g.

    Algorithm 2 builds a seed table r. Algorithm 1 then swaps index pairs.
    For odd p, the middle digit is unchanged by the reversal and is
    enumerated by the z loop.
    """
    if p % 2 == 0:
        n1 = int(np.sqrt(n))
    else:
        n1 = int(np.sqrt(n // b))
    if n1 < 2:
        # p <= 1: at most one digit, so the permutation is the identity.
        # (The seed table would also be too short to index r[1].)
        return g
    # Algorithm 2: seed table. r[k] is the digit reversal of k, shifted into
    # the high digits of the full index.
    r = np.zeros(n1, dtype=int)
    r[1] = n // b
    for j in range(2, b):
        r[j] = r[j - 1] + r[1]
    # This loop used to be nested inside the loop above, which recomputed the
    # whole table b-1 times and re-bound the outer loop variable j.
    for i in range(1, n1 // b):
        r[b * i] = r[i] // b
        for j in range(1, b):
            r[b * i + j] = r[b * i] + r[j]
    # Algorithm 1: swap each index with its digit-reversed partner, once.
    for i in range(n1 - 1):
        for j in range(i + 1, n1):
            u = i + r[j]
            v = j + r[i]
            g[u], g[v] = g[v], g[u]
            if p % 2 == 1:
                for z in range(1, b):
                    uu = u + z * n1
                    vv = v + z * n1
                    g[uu], g[vv] = g[vv], g[uu]
    return g


def tf(ind,mw):
    """Return the twiddle factor W_64**ind == exp(-2j*pi*ind/64).

    Only one octant of the unit circle is stored. ``mw`` must hold
    ``mw[k] == exp(-2j*pi*k/64)`` for k = 0..8 (9 entries). The top three
    bits (q2, q1, q0) of the 6-bit exponent select the octant, and the low
    three bits index within it. The other octants are rebuilt by mirroring
    the table index and swapping/negating the real and imaginary parts.

    ``ind`` may be any integer. It is reduced modulo 64 because W_64 has
    period 64; previously, out-of-range indices decoded to wrong values.
    """
    ind = int(ind) % 64
    q2 = (ind >> 5) & 1
    q1 = (ind >> 4) & 1
    q0 = (ind >> 3) & 1
    low = ind & 7
    # Odd octants run backwards through the table, using index 8 - low (the
    # value adder(low_bits, 1) computes); even octants use low directly.
    m = mw[(8 - low) if q0 else low]
    rm = np.real(m)
    im = np.imag(m)
    pq0q2 = -1 if q0 ^ q2 else 1          # sign of the result's real part
    pq0q1q2 = -1 if q0 ^ q1 ^ q2 else 1   # sign of the result's imaginary part
    if q0 ^ q1:
        # Octants 1, 2, 5, 6: the real and imaginary parts of m trade places.
        return pq0q2 * im + 1j * pq0q1q2 * rm
    return pq0q2 * rm + 1j * pq0q1q2 * im


def fft4(x,mw,s):
    """Run the s radix-4 butterfly stages of a decimation-in-time FFT in place.

    Parameters
    ----------
    x : 1-D numpy array of length 4**s. It holds the input already permuted
        into base-4 digit-reversed order (see ``digitreversal``) and is
        overwritten with the result.
    mw : twiddle table passed through to ``tf``.
    s : number of radix-4 stages.

    Returns
    -------
    The same array object ``x``.

    Raises
    ------
    ValueError if ``x.size != 4**s``.

    NOTE(review): ``tf`` decodes a 6-bit exponent (powers of W_64), so the
    output is only a correct DFT for s == 3 (N == 64).
    """
    n = x.size
    if n != 4 ** s:
        raise ValueError("fft4 expects x.size == 4**s")
    span = 1          # distance between the four legs of a butterfly
    stride = n // 4   # twiddle-exponent step for the current stage
    for _ in range(s):
        # Each stage combines consecutive groups of 4*span elements.
        for base in range(0, n, 4 * span):
            for k in range(span):
                a = base + k
                b = a + span
                c = b + span
                d = c + span
                r1 = tf(k * stride, mw)
                r2 = tf(2 * k * stride, mw)
                r3 = tf(3 * k * stride, mw)
                xbr1 = x[b] * r1
                xcr2 = x[c] * r2
                xdr3 = x[d] * r3
                e = x[a] + xcr2
                f = x[a] - xcr2
                g = xbr1 + xdr3
                jh = 1j * (xbr1 - xdr3)
                x[a] = e + g
                x[b] = f - jh
                x[c] = e - g
                x[d] = f + jh
        span *= 4
        stride //= 4
    return x


def test():
    """Self-check: compare fft4 on a random 64-point input with numpy.fft.fft.

    Prints a confirmation when the results agree. Raises AssertionError when
    they do not; this is an explicit raise, so it also fires under
    ``python -O``.
    """
    s = 3
    b = 4
    N = 4 ** s
    # 2*N random floats reinterpreted as N complex128 samples.
    x = np.random.rand(2 * N).view(np.complex128)
    # The reference must be computed before digitreversal permutes x in place.
    Xpy = np.fft.fft(x)
    x = digitreversal(x, b, s, N)
    # 9-entry twiddle table: mw[k] = exp(-2j*pi*k/N) for k = 0..8.
    hh = np.linspace(0, 8. / float(N), 9)
    mw = np.exp(-2j * np.pi * hh)
    x = fft4(x, mw, s)
    if not np.allclose(x, Xpy):
        raise AssertionError(
            "radix-4 FFT mismatch (max abs error %g)" % np.max(np.abs(x - Xpy)))
    print ("All radix-4 results were correct.")
    return None


# Run the self-test only when executed as a script, not when imported.
if __name__ == "__main__":
    test()