# Discrete Fourier transform of a length-2N real sequence computed with a single length-N complex FFT.
# Python 2.7.3
# John Bryan, 2017
import numpy as np

def p(q, r):
    # Print a labelled value: a line "<q>=", then r itself, then a blank line.
    # Used by test() to display the computed and reference transforms.
    # Uses the print(...) call form with a single argument, which produces the
    # same output under Python 2.7 (as the original print statements did) and
    # also runs under Python 3.
    print("%s=" % q)
    print(r)
    print("")
    return None

def swap(x, i, j):
    # Exchange the items at positions i and j of the mutable, indexable
    # container x (e.g. a list or a numpy array) in place.  Returns None.
    x[i], x[j] = x[j], x[i]

def bitreversal(x):
    # Reorder x in place into bit-reversed index order and return it.
    # Called from fft0 before the butterfly passes.
    # x: numpy array (x.size is used as its length); the length must be 0, 1
    #    or a power of two.
    # Returns x itself (the same object): afterwards x[n] holds the old
    #    x[r[n]], where r[n] is n with its log2(len) bits reversed.
    # Raises ValueError if the length is not a power of two (previously this
    #    surfaced as an obscure IndexError while building the table).
    N = int(x.size)
    if N < 2:
        # Lengths 0 and 1 have nothing to permute.
        return x
    if N & (N - 1):
        raise ValueError("bitreversal: length %d is not a power of two" % N)
    # Build r[] by doubling: if r holds the bit reversals of 0..M-1, then
    # [2*r, 2*r + 1] holds the bit reversals of 0..2M-1.
    r = np.zeros(1, dtype=int)
    while r.size < N:
        r = np.concatenate((2 * r, 2 * r + 1))
    # Bit reversal is its own inverse, so gathering through r yields the same
    # permutation as swapping x[n] <-> x[r[n]] for every r[n] > n.
    # x[r] is a copy, so writing it back into x[:] is safe.
    x[:] = x[r]
    return x

def fft0(z):
    # Radix-2 decimation-in-time (DIT) FFT; called from the steps function.
    # z: 1-D sequence/array whose length is a power of two (lengths 0 and 1
    #    are returned unchanged as complex).  z itself is NOT modified: the
    #    in-place bit reversal is applied to a private complex copy.
    # Returns: new complex128 numpy array holding the DFT of z.
    # Lengths that are not a power of two are rejected by bitreversal, which
    # raises an exception.
    # np.complex128 is used because the np.complex_ alias was removed in
    # NumPy 2.0; // keeps all block/step counts integral under Python 3.
    x = np.array(z, dtype=np.complex128)
    N = x.size
    if N < 2:
        return x
    bitreversal(x)
    # twiddle[k] = exp(-2*pi*i*k/N) for k = 0 .. N/2-1.  The float dtype
    # avoids Python 2 integer division in the k/N ratio.
    twiddle = np.exp(-2j * np.pi * np.arange(N // 2, dtype=float) / N)
    n_passes = N.bit_length() - 1      # exact log2(N) for a power of two
    n_blocks = N // 2                  # number of blocks in the current pass
    block_size = 2                     # butterflies span this many points
    twiddle_step = N // 2              # stride through the twiddle table
    for _ in range(n_passes):
        half = block_size // 2
        top_base = 0
        for _block in range(n_blocks):
            bot_base = top_base + half
            for n in range(half):
                # One butterfly: combine x[top] and twiddled x[bot].
                top = x[top_base + n]
                bot = x[bot_base + n] * twiddle[n * twiddle_step]
                x[top_base + n] = top + bot
                x[bot_base + n] = top - bot
            top_base += block_size
        # Next pass: half as many blocks, each twice as wide.
        n_blocks //= 2
        block_size *= 2
        twiddle_step //= 2
    return x

def steps(g):
    # Compute the transform of a real sequence g of length N1 = 2*N2 using a
    # single length-N2 complex FFT (fft0).  Equation numbers in the comments
    # below (eq.6 .. eq.14) follow the original derivation.
    # g: real sequence (list or 1-D numpy array); its length must be a power
    #    of two >= 2, because fft0 needs a power-of-two length N2 = N1/2.
    # Returns: complex128 numpy array G of length N1, the transform of g.
    # Raises ValueError for any other length; an odd length would otherwise
    #    silently produce a wrong result.
    # Side effect: sets NumPy's global print options (test() output relies
    #    on this formatting).
    np.set_printoptions(precision=3, suppress=1)
    g = np.asarray(g)
    N1 = len(g)
    if N1 < 2 or N1 & (N1 - 1):
        raise ValueError("steps: length %d is not a power of two >= 2" % N1)
    N2 = N1 // 2
    # Even-indexed values of g (eq.6) form the real part and odd-indexed
    # values (eq.7) the imaginary part of one complex sequence (eq.8).
    x = (g[::2] + 1j * g[1::2]).astype(np.complex128)
    X = fft0(x)                                              # (eq.9)
    k = np.arange(N2)
    # Xr[k] = conj(X[(N2 - k) mod N2]) for all k at once.
    Xr = np.conjugate(X[(N2 - k) % N2])
    X1 = 0.5 * (X + Xr)                                      # (eq.10)
    X2 = -0.5j * (X - Xr)                                    # (eq.11)
    G = np.empty(N1, dtype=np.complex128)
    G[:N2] = X1 + np.exp(-1j * np.pi * k / float(N2)) * X2   # (eq.12)
    Xc0 = np.conjugate(X[0])
    G[N2] = 0.5 * ((X[0] + Xc0) + 1j * (X[0] - Xc0))         # (eq.13)
    # Real input gives G[N1-k] = conj(G[k]) for k = 1 .. N2-1, i.e. the
    # upper half is the conjugated, reversed G[1:N2].
    G[N2 + 1:] = np.conjugate(G[1:N2][::-1])                 # (eq.14)
    return G

def test(min_exp=3, max_exp=8):
    # Check steps() against numpy.fft.fft on random real sequences of length
    # 2**r for r in range(min_exp, max_exp) (defaults: lengths 8 .. 128).
    # For each length, prints both transforms and the comparison result.
    # Raises AssertionError on the first mismatch.  An explicit raise is used
    # rather than `assert` so the check is not stripped under `python -O`.
    # print(...) with a single argument gives the same output on 2.7 and 3.
    for r in range(min_exp, max_exp):
        s = 2 ** r
        # Random real sequence g whose length is a power of two.
        g = np.random.rand(s)
        G = steps(g)
        # Reference transform from numpy for comparison.
        Gpy = np.fft.fft(g)
        p("G", G)
        p("Gpy", Gpy)
        tf = np.allclose(G, Gpy)
        if not tf:
            raise AssertionError(
                "steps() disagrees with numpy.fft.fft for length %d" % s)
        print(tf)
    print("All results were correct.")
    print("")

# Run the self-test only when executed as a script, not when imported.
if __name__ == "__main__":
    test()