#  FHT
#  John Bryan, 2017
#  Python 2.7.3
import math
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
# Print arrays in full (never summarised), with 3 decimals and without
# scientific notation.  A NaN threshold is rejected by current numpy;
# sys.maxsize is the value numpy itself suggests for "never truncate".
np.set_printoptions(threshold=sys.maxsize, precision=3, suppress=True)
# Silences every warning category for the whole process, including numpy
# casting warnings -- keep that in mind when debugging numerical problems.
warnings.filterwarnings("ignore")


def swap(x,i,j):
    """Exchange the elements at positions i and j of x in place.

    x must support item assignment.  Nothing is returned.
    """
    x[i], x[j] = x[j], x[i]


def digitreversal(g,b,p,n):
    """Permute g in place into base-b digit-reversed order and return it.

    g : mutable sequence of length n (list or 1-D numpy array)
    b : radix (2 gives the bit reversal needed by radix-2 transforms)
    p : number of base-b digits in an index; n is expected to equal b**p
    n : length of g

    Only a seed table of about sqrt(n) reversed values is built; every
    swap pair is derived from two seed entries.  Sequences whose indices
    have fewer than two digits (n == 1 or n == b) are already in
    digit-reversed order and are returned untouched.
    """
    odd = p % 2 == 1
    # seed table size: b**(p//2); for odd p the middle digit is left out
    if odd:
        n1 = int(np.sqrt(n // b))
    else:
        n1 = int(np.sqrt(n))
    if n1 < 2:
        # nothing to swap, and the seed table would have no entry r[1]
        return g

    # algorithm 2, compute seed table: r[k] is k with its digits reversed,
    # placed in the most significant digit positions of an n-point index
    r = [0] * n1
    r[1] = n // b
    for j in range(2, b):
        r[j] = r[j - 1] + r[1]
    for i in range(1, n1 // b):
        base = b * i
        r[base] = r[i] // b
        for j in range(1, b):
            r[base + j] = r[base] + r[j]

    # algorithm 1: swap each pair (low digits i / reversed j) with its
    # mirror.  For odd p the middle digit z is its own reversal, so the
    # same swap is repeated for every value of z.
    middle_values = b if odd else 1
    for i in range(n1 - 1):
        for j in range(i + 1, n1):
            u = i + r[j]
            v = j + r[i]
            for z in range(middle_values):
                offset = z * n1
                g[u + offset], g[v + offset] = g[v + offset], g[u + offset]
    return g


def fht(x,c,s,N,p):
    """Radix-2 decimation-in-time fast Hartley transform (unnormalised).

    x : real input samples, already permuted into bit-reversed order
    c : table with c[d] = cos(2*pi*d/N) for d = 0 .. N-1
    s : table with s[d] = sin(2*pi*d/N) for d = 0 .. N-1
    N : transform length, expected to equal 2**p
    p : number of butterfly stages

    Returns an array H with H[k] = sum_n x_n * (cos(2*pi*n*k/N) +
    sin(2*pi*n*k/N)); the caller divides by N if a normalised transform
    is wanted.

    The stages ping-pong between two buffers.  When x is already a
    float64 array it serves as one of those buffers, so its contents are
    overwritten and the returned array may be x itself.  Any other input
    (integers, lists, ...) is first copied to float64 so that butterfly
    results are not truncated when written back.
    """
    src = np.asarray(x, dtype=float)
    dst = np.zeros(N)
    blocks = N // 2      # number of butterfly groups in the current stage
    size = 2             # points per group
    stride = N // 2      # step through the N-point trig tables
    for stage in range(p):
        half = size // 2
        top = 0
        for block in range(blocks):
            bot = top + half
            end = top + size
            # n == 0: cos = 1 and sin = 0, so no multiplication is needed
            upper = src[top]
            xcs = src[bot]
            dst[top] = upper + xcs
            dst[bot] = upper - xcs
            for n in range(1, half):
                tf = n * stride
                # the Hartley butterfly also needs the mirrored sample
                # (index half-n of the lower half, i.e. end-n)
                xcs = src[bot + n] * c[tf] + src[end - n] * s[tf]
                upper = src[top + n]
                dst[top + n] = upper + xcs
                dst[bot + n] = upper - xcs
            top = end
        # this stage's output is the next stage's input
        src, dst = dst, src
        blocks //= 2
        size *= 2
        stride //= 2
    return src


def fft2 (x,twiddle,s0) :
    """Radix-2 decimation-in-time FFT (unnormalised).

    x       : input samples, already permuted into bit-reversed order;
              the length is expected to equal 2**s0
    twiddle : table with twiddle[k] = exp(-2j*pi*k/N) for k = 0 .. N/2-1
    s0      : number of butterfly stages

    Returns the complex spectrum.  A complex128 array is transformed in
    place and returned; any other input is first copied to complex128,
    because writing complex butterfly results into a real array would
    drop their imaginary parts.
    """
    x = np.asarray(x, dtype=complex)
    N = x.size
    blocks = N // 2      # number of butterfly groups in the current stage
    size = 2             # points per group
    step = N // 2        # step through the twiddle table
    for stage in range(s0):
        half = size // 2
        top = 0
        for block in range(blocks):
            bot = top + half
            for n in range(half):
                lower = x[bot + n]
                if n:
                    # twiddle[0] == 1, so n == 0 needs no multiplication
                    lower = lower * twiddle[n * step]
                upper = x[top + n]
                x[top + n] = upper + lower
                x[bot + n] = upper - lower
            top += size
        blocks //= 2
        size *= 2
        step //= 2
    return x


def plot_times(t2,t3):
    """Plot the measured run times against sequence length on log-log axes.

    t2 : array of 10 FHT run times in seconds (black circles)
    t3 : array of 10 FFT run times in seconds (blue triangles)

    The i-th entry of each array is drawn at length 2**(i+4).  The figure
    is written to 'phartley.png' and then shown on screen.
    """
    lengths = np.array([np.power(2, e) for e in range(4, 14)], dtype=int)
    # NOTE(review): newer Matplotlib releases appear to have renamed
    # basex/basey -- confirm against the installed version.
    shared = dict(ms=5, markerfacecolor="None", markeredgewidth=1,
                  basex=2, basey=10)
    plt.figure(figsize=(7, 5))
    plt.rc("font", size=9)
    # seconds -> milliseconds
    plt.loglog(lengths, t3*1000, '^', markeredgecolor='blue', label='FFT',
               **shared)
    plt.loglog(lengths, t2*1000, 'o', markeredgecolor='black', label='FHT',
               **shared)
    plt.legend(loc=2)
    plt.grid()
    plt.xlim([9, 16000])
    plt.ylim([.1, 3000])
    plt.ylabel("time (milliseconds)")
    plt.xlabel("sequence length")
    plt.title("Time vs Length for Python code")
    plt.savefig('phartley.png', bbox_inches='tight')
    plt.show()


def test():
    """Validate fht() and fft2() against numpy's FFT and time both.

    For every length N = 2**4 .. 2**13 the same seeded random sequence is
    bit-reversed and transformed by fft2() and by fht().  Only the
    transform calls are timed; the digit reversal is not.  Both results
    are compared with np.fft.fft (the Hartley transform being real minus
    imaginary part of the DFT).  An AssertionError is raised on the first
    mismatch; otherwise a success message is printed and the timings are
    handed to plot_times().
    """
    exponents = range(4, 14)
    fht_times = np.zeros(len(exponents))
    fft_times = np.zeros(len(exponents))
    radix = 2
    for i, r in enumerate(exponents):
        N = 2 ** r
        np.random.seed(1)
        g = np.random.rand(N)
        # fft2() works in place, so hand it a complex copy of the same
        # samples; a real array cannot hold the complex butterflies.
        gf = g.astype(complex)

        # reference results, normalised by N
        spectrum = np.fft.fft(g)
        gpy = spectrum / N
        hpy = (spectrum.real - spectrum.imag) / N

        # trig tables for fht() and half-length twiddle table for fft2()
        angle = 2 * np.pi * np.arange(N) / N
        s = np.sin(angle)
        c = np.cos(angle)
        t = np.exp(-2j * np.pi * np.arange(N // 2) / N)

        gf = digitreversal(gf, radix, r, N)
        t0 = time.time()
        gf = fft2(gf, t, r)
        fft_times[i] = time.time() - t0
        gf = gf / N

        g = digitreversal(g, radix, r, N)
        t0 = time.time()
        G2 = fht(g, c, s, N, r)
        fht_times[i] = time.time() - t0
        G2 = G2 / N

        # raised explicitly so the checks survive "python -O"
        if not np.allclose(G2, hpy):
            raise AssertionError("FHT result differs from numpy for N=%d" % N)
        if not np.allclose(gf, gpy):
            raise AssertionError("FFT result differs from numpy for N=%d" % N)
    print ("All results were correct.")
    plot_times(fht_times, fft_times)

# Run the validation and benchmark straight away.  The call sits at module
# level without a __main__ guard, so it also executes when the file is imported.
test()