# Radix-2 DIF FFT in Python 2.7.3
# John Bryan, June 2016

import numpy as np
import matplotlib.pyplot as plt


def swap(x, i, j):
    """Exchange x[i] and x[j] in place. Helper for bitreversal; returns None."""
    x[i], x[j] = x[j], x[i]

def bitreversal(x):
    """Reorder x in place from bit-reversed index order to natural order.

    Called at the end of fft0, whose DIF passes leave the spectrum in
    bit-reversed order. A table rev is built where rev[n] is n with its
    log2(N) bits reversed; each pair (n, rev[n]) is then swapped once.

    Assumes x.size is a power of two. For other sizes greater than 1, the
    table build indexes past the end of rev and raises IndexError.

    Returns:
        The same array object x, now reordered.
    """
    size = int(x.size)
    rev = np.zeros(size, dtype=int)
    span = 1
    while span < size:
        # Grow the table from `span` to `2*span` entries. Widening the
        # reversal by one bit doubles every existing entry. Indices
        # k+span have that new top bit set, which reappears as the low bit
        # after reversal, hence the +1.
        for k in range(span):
            doubled = int(2 * rev[k])
            rev[k] = doubled
            rev[k + span] = doubled + 1
        span *= 2
    # Swap each pair only from its smaller index so it is exchanged once.
    # Index 0 and index size-1 are their own reversals, so they are skipped.
    for n in range(1, size - 1):
        if rev[n] > n:
            swap(x, n, rev[n])
    return x


def fft0(Z):
    """Compute the DFT of Z with an in-place radix-2 decimation-in-frequency FFT.

    Called from the ploty function. Each pass splits every block of length
    Np into butterflies: the top half gets a + b and the bottom half gets
    (a - b) * W_N^(n*step). After log2(N) passes the result is in
    bit-reversed order, and bitreversal() restores natural order.

    Args:
        Z: 1-D array-like whose length is a positive power of two. It is
           copied (as complex128), so the caller's data is not modified.

    Returns:
        Complex numpy array z with z[k] = sum_n Z[n] * exp(-2j*pi*k*n/N),
        in natural order.

    Raises:
        ValueError: if the length of Z is not a positive power of two.
    """
    # np.complex128 rather than np.complex_, which NumPy 2.0 removed.
    x = np.array(Z, dtype=np.complex128)
    N = int(x.size)
    if N < 1 or N & (N - 1):
        raise ValueError(
            "fft0 requires a length that is a power of two, got %d" % N)
    # W_N^m = exp(-2j*pi*m/N) for m = 0 .. N/2-1. The integer arange gives
    # exactly N/2 entries, unlike a floating-point step.
    twiddle = np.exp(-2j * np.pi * np.arange(N // 2) / N)
    Np = N      # length of each block in the current pass
    step = 1    # twiddle stride for this pass: W_Np^n == W_N^(n*step)
    while Np > 1:                             # pass loop
        half = Np // 2    # '//': plain '/' yields a float under Python 3
        for base_e in range(0, N, Np):        # block loop
            base_o = base_e + half
            for n in range(half):             # butterfly loop
                a = x[base_e + n]
                b = x[base_o + n]
                x[base_e + n] = a + b
                x[base_o + n] = (a - b) * twiddle[n * step]
        Np = half
        step *= 2
    return bitreversal(x)


def ploty(t, y, Fs, title='Test of DIF FFT with 10 Hz Sine Input'):
    """Plot a signal and its one-sided FFT magnitude spectrum, then show it.

    Called from the sine function. The top subplot draws y against t. The
    bottom subplot is a stem plot of |fft0(y)| / n for the first n/2 bins.

    Args:
        t: time vector, the same length as y.
        y: sample array. Its length must be a power of two, as fft0 requires.
        Fs: sampling rate in Hz.
        title: title for the top subplot. The default keeps the original
            caption used for the 10 Hz test sine.

    Returns:
        None. Displays the figure with plt.show().
    """
    plt.subplot(2, 1, 1)
    plt.title(title)
    plt.plot(t, y, 'k-')
    plt.xlabel('time')
    plt.ylabel('amplitude')
    plt.subplot(2, 1, 2)
    n = len(y)                           # length of the signal
    half = n // 2                        # '//': n/2 is a float in Python 3
    # Bin k lies at k*Fs/n Hz. float() prevents Python 2 integer division
    # from truncating the axis when Fs and n are ints and n != Fs.
    freq = np.arange(half) * float(Fs) / n   # one-sided frequency axis
    Y = fft0(y) / n                      # fft and normalization
    Y = Y[:half]
    markerline, stemlines, baseline = plt.stem(freq, abs(Y), '--')
    plt.xlabel('freq (Hz)')
    plt.ylabel('|Y(freq)|')
    # Fixed range sized for the unit-amplitude test sine built in sine().
    plt.ylim((0.0, 0.55))
    plt.setp(markerline, 'markerfacecolor', 'b')
    plt.setp(baseline, 'color', 'b', 'linewidth', 2)
    plt.show()
    return None


def sine():
    """Build a 1-second 10 Hz sine sampled at 128 Hz and hand it to ploty.

    Called from the test function. Returns None.
    """
    sample_rate = 128                    # samples per second
    signal_freq = 10                     # Hz
    t = np.arange(0, 1, 1.0 / sample_rate)
    ploty(t, np.sin(2 * np.pi * signal_freq * t), sample_rate)
    return None


def test():
    """Entry point for the FFT demo; delegates to sine(). Returns None."""
    sine()


if __name__ == '__main__':
    # Run the plotting demo only when executed as a script, so that
    # importing this module (e.g. to reuse fft0) has no side effects.
    test()