# Transform of a length-2N real sequence using a length-N complex FFT.
# Python 2.7.3
# John Bryan, 2017

import numpy as np
import time
import matplotlib.pyplot as plt
import math
import warnings
# Print whole arrays (no "..." summarisation), 3 decimals, and fixed-point
# notation for small floats.  Newer NumPy releases raise ValueError for a NaN
# threshold; an infinite threshold means the same "never summarise".
np.set_printoptions(threshold=np.inf, precision=3, suppress=True)
# Silence every warning for the whole process.
warnings.filterwarnings("ignore")


def swap(x, i, j):
    """Exchange the items at positions i and j of the mutable sequence x.

    x is modified in place; the function returns None.
    """
    x[i], x[j] = x[j], x[i]


def bitreversal(x):
    """Permute the array x into bit-reversed index order, in place.

    After the call, x[n] holds the element that was at position rev(n), where
    rev reverses the log2(N)-bit binary representation of n.  This is the
    input ordering that ``fft0`` expects.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array whose length N is a power of two (0 and 1 are accepted).
        It is modified in place.

    Returns
    -------
    numpy.ndarray
        The same object x, now permuted.

    Raises
    ------
    ValueError
        If N is not a power of two.
    """
    N = int(x.size)
    if N & (N - 1):
        raise ValueError("bitreversal requires a power-of-two length, got %d" % N)
    # Build the table r[n] = rev(n) by doubling: going from M to 2M entries,
    # every existing entry is shifted left one bit and the new upper half
    # gets the same values with the low bit set.
    r = np.zeros(N, dtype=int)
    M = 1
    while M < N:
        r[M:2 * M] = 2 * r[:M] + 1
        r[:M] *= 2
        M *= 2
    # Bit reversal is its own inverse, so gathering x[r] is equivalent to
    # swapping each pair (n, r[n]); fancy indexing copies before write-back.
    x[:] = x[r]
    return x


def fft0(Zo, twiddle):
    """Radix-2 decimation-in-time FFT of a sequence given in bit-reversed order.

    Parameters
    ----------
    Zo : numpy.ndarray
        Input samples, already permuted into bit-reversed index order (see
        ``bitreversal``).  Its length N must be a power of two (0 and 1 are
        accepted).  Zo itself is not modified.
    twiddle : indexable of complex
        Twiddle factors twiddle[k] == exp(-2j*pi*k/N) for k < N/2.

    Returns
    -------
    numpy.ndarray of complex128
        The length-N DFT of the natural-order sequence.

    Raises
    ------
    ValueError
        If N is not a power of two.
    """
    # Work on a complex copy so the caller's array is left untouched.
    x = Zo.astype(np.complex128)
    N = int(x.size)
    if N & (N - 1):
        raise ValueError("fft0 requires a power-of-two length, got %d" % N)
    # Number of passes, log2(N), computed exactly on integers; an empty input
    # gives -1, i.e. no passes, instead of int(log2(0)) overflowing.
    p = N.bit_length() - 1
    # Floor division keeps the counters integral on Python 2 and Python 3.
    Bp = N // 2                 # number of blocks in the current pass
    Np = 2                      # size of each block in the current pass
    twiddle_step_size = N // 2  # stride through `twiddle` for this pass
    for _pass in range(p):
        Npp = Np // 2           # butterflies per block (half the block size)
        baseT = 0               # index of the block's first "top" element
        for _block in range(Bp):
            baseB = baseT + Npp  # index of the block's first "bottom" element
            for n in range(Npp):
                top = x[baseT + n]
                bot = x[baseB + n] * twiddle[n * twiddle_step_size]
                x[baseT + n] = top + bot
                x[baseB + n] = top - bot
            baseT += Np
        # Each pass halves the number of blocks, doubles their size and
        # halves the twiddle stride.
        Bp //= 2
        Np *= 2
        twiddle_step_size //= 2
    return x


def steps(g, wk, W):
    """Return the DFT of the real sequence g using one half-length complex FFT.

    The even- and odd-indexed samples of g are packed into the real and
    imaginary parts of one complex sequence of length N2 = len(g)/2.  That
    sequence is transformed with ``fft0``, and the spectra of the two halves
    are separated and recombined into the full length-len(g) transform.
    Equation numbers in the comments refer to the write-up this code follows.

    Parameters
    ----------
    g : sequence of real numbers
        Input sequence; len(g) must be a power of two >= 2.  Not modified.
    wk : indexable of complex
        Twiddle factors for the length-N2 FFT, wk[k] == exp(-2j*pi*k/N2),
        with at least N2/2 entries.
    W : indexable of complex
        Recombination factors W[k] == exp(-2j*pi*k/len(g)) for k < N2.

    Returns
    -------
    numpy.ndarray of complex128, length len(g)
        The DFT of g.

    Raises
    ------
    ValueError
        If len(g) is not a power of two >= 2.
    """
    g = np.asarray(g)
    N1 = len(g)
    if N1 < 2 or N1 & (N1 - 1):
        raise ValueError("steps requires a power-of-two length >= 2, got %d" % N1)
    N2 = N1 // 2
    # Even samples (eq.6) become the real part and odd samples (eq.7) the
    # imaginary part of a single complex sequence (eq.8).  This is a fresh
    # array, so bitreversal may permute it in place.
    x = (g[::2] + 1j * g[1::2]).astype(np.complex128)
    X = fft0(bitreversal(x), wk)                          # (eq.9)
    # Xr[k] = conj(X[(N2 - k) mod N2]); the mod sends k = 0 to X[0] because
    # the length-N2 DFT is N2-periodic.
    Xr = np.conjugate(X[-np.arange(N2) % N2])
    X1 = 0.5 * (X + Xr)                                   # (eq.10) even-sample spectrum
    X2 = -0.5j * (X - Xr)                                 # (eq.11) odd-sample spectrum
    G = np.zeros(N1, dtype=np.complex128)
    G[:N2] = X1 + np.asarray(W)[:N2] * X2                 # (eq.12)
    G[N2] = 0.5 * ((X[0] + Xr[0]) + 1j * (X[0] - Xr[0]))  # (eq.13)
    # Real input has a Hermitian spectrum: G[N1-k] = conj(G[k]).  (eq.14)
    G[N2 + 1:] = np.conjugate(G[N2 - 1:0:-1])
    return G



def plot_times(t2, t3):
    """Plot run time versus sequence length for the two FFT variants.

    Parameters
    ----------
    t2 : sequence of float
        Timings (seconds) of the real-input ``steps`` transform, drawn as red
        circles labelled 'Real'.
    t3 : sequence of float
        Timings (seconds) of the plain complex ``fft0`` transform, drawn as
        green triangles labelled 'Complex'.

    Both must have 15 entries, for lengths 2**4 .. 2**18 (the sizes used by
    ``test``).  The figure is saved to 't22.png' and then shown.
    """
    # Sequence lengths matching test(): 2**4 .. 2**18.
    u = np.power(2, np.arange(4, 19))
    plt.figure(figsize=(7, 5))
    plt.rc("font", size=9)
    style = dict(ms=5, markerfacecolor="None", markeredgewidth=1)
    plt.plot(u, t2, 'o', markeredgecolor='red', label='Real', **style)
    plt.plot(u, t3, '^', markeredgecolor='green', label='Complex', **style)
    ax = plt.gca()
    # Log axes: base 2 for length, base 10 for time.  matplotlib renamed the
    # `basex` keyword to `base` and later removed `basex` (TypeError), so try
    # the old spelling first for old releases and fall back to the new one.
    try:
        ax.set_xscale('log', basex=2)
    except TypeError:
        ax.set_xscale('log', base=2)
    ax.set_yscale('log')
    plt.legend(loc=2)
    plt.grid()
    plt.xlim([9, 530000])
    plt.ylim([.00001, 8])
    plt.ylabel("time (seconds)")
    plt.xlabel("sequence length")
    plt.title("Time vs Length")
    plt.savefig('t22.png', bbox_inches='tight')
    plt.show()


def test():
    """Verify ``steps`` and ``fft0`` against numpy.fft.fft and time them.

    For every length s = 2**4 .. 2**18, a random real sequence g is
    transformed twice:

    * with ``steps``: one length-s/2 complex FFT plus recombination, and
    * with ``bitreversal`` + ``fft0``: a plain length-s complex FFT.

    Both results must match numpy.fft.fft(g); a mismatch raises
    AssertionError.  Per-length timings (seconds) are printed and passed to
    ``plot_times``.
    """
    # default_timer is the best wall clock available on both Python 2
    # (time.clock on Windows, time.time elsewhere) and Python 3.
    from timeit import default_timer as timer

    exponents = range(4, 19)
    times = np.zeros(len(exponents))
    times2 = np.zeros(len(exponents))
    for i, r in enumerate(exponents):
        s = 2 ** r
        # Random real sequence whose length is a power of two.
        g = np.random.rand(s)
        # Reference transform from numpy.
        Gpy = np.fft.fft(g)
        N2 = s // 2
        # Recombination factors for steps: W[h] = exp(-2j*pi*h/s), h < N2.
        W = np.exp(-1j * np.pi * np.arange(N2) / N2)
        # Twiddle tables: length-N2 FFT inside steps, length-s FFT for fft0.
        wk = np.exp(-2j * np.pi * np.arange(N2 // 2) / N2)
        wk2 = np.exp(-2j * np.pi * np.arange(N2) / s)

        t0 = timer()
        G = steps(g, wk, W)
        times[i] = timer() - t0

        t0 = timer()
        # bitreversal permutes g in place; g is not needed after this point.
        G2 = fft0(bitreversal(g), wk2)
        times2[i] = timer() - t0

        # Explicit raises (rather than assert) so the checks survive `python -O`.
        tf = np.allclose(G, Gpy)
        if not tf:
            raise AssertionError("steps != numpy.fft.fft for length %d" % s)
        if not np.allclose(G2, Gpy):
            raise AssertionError("fft0 != numpy.fft.fft for length %d" % s)
        print(tf)
    print("All results were correct.")
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("real steps fft time is %s" % (times,))
    print("regular fft time is %s" % (times2,))
    plot_times(times, times2)

# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    test()