// Radix-2 DIF recursion, loop, FFTW3, and non-FFT DFT time comparison 
// John Bryan 2017

// Link with: -lrt -lm -lfftw3

#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <fftw3.h>
#include <complex.h>
#include <math.h>
#include <stdint.h>


// nu is the number of timed complex multiplies averaged in get_op_time()
#define nu 1000000
// BILLION scales a tv_sec difference (seconds) to nanoseconds
#define BILLION 1000000000L
// rows is the number of lengths
#define rows 8    
// columns is the number of iterations averaged
#define columns 1000 


double randf(void)
{
    // Return a pseudo-random value in the closed interval [-1, 1],
    // drawn from rand() (so the sequence is controlled by srand()).
    const double low = -1.0, high = 1.0;
    // Scale the unit-interval sample by the width of the range.  The width is
    // written as (high - low) so no absolute value is needed; the integer
    // abs() must not be applied to a double because it truncates its argument.
    return (rand() / (double)RAND_MAX) * (high - low) + low;
}


double Avg(long int array[]){
    // Arithmetic mean of the first nu entries of array, accumulated in double.
    double total = 0;
    const long int *cursor = array;
    const long int *const stop = array + nu;
    while (cursor != stop) {
        total += *cursor;
        ++cursor;
    }
    return total / nu;
}


long int get_op_time()
     // compute complex multiply time average
     //
     // Times nu individual complex-double multiplications of random operands
     // with CLOCK_MONOTONIC and returns the mean elapsed time in nanoseconds,
     // truncated to a long.  Each sample includes the cost of the two
     // clock_gettime() calls that bracket the multiply.
     // Terminates the program if the sample buffer cannot be allocated.
{
    complex double a,b;
    // The product is stored to a volatile object so the compiler cannot
    // discard the multiplication that is being timed.
    volatile complex double product;
    struct timespec start,end;
    long int ts,average;
    long int *times;
    int i;

    // nu samples of long are several megabytes; keep them on the heap rather
    // than in automatic storage, where they could exhaust the stack.
    times=malloc(sizeof *times * nu);
    if (times==NULL)
    {
         perror("get_op_time");
         exit(EXIT_FAILURE);
    }
    // nu is the number of iterations averaged over
    for (i=0; i<nu; i++)
    {
         a=randf()+randf()*I;
         b=randf()+randf()*I;
         clock_gettime(CLOCK_MONOTONIC,&start);
         product=a*b;
         clock_gettime(CLOCK_MONOTONIC,&end);
         ts = BILLION * (end.tv_sec - start.tv_sec) + end.tv_nsec - start.tv_nsec;
         times[i]=ts;
    }
    (void)product;
    average=(long int)Avg(times);
    free(times);
    return average;
}

int equality_check(complex long double *a, complex long double *b, int N)
{
    // check for closeness
    //
    // Compares the first N elements of a and b.
    // Returns 0 when every pair satisfies |a[i]-b[i]| <= 0.1*|a[i]|
    // (10% relative tolerance, measured with the complex magnitude),
    // and 1 as soon as any pair violates it.
    //
    // The comparison is done on squared magnitudes, which is equivalent for
    // non-negative quantities and avoids a square root.  Both the real and
    // imaginary parts take part; passing a complex value to fabs() would
    // silently drop the imaginary part.
    const long double epsilon = 0.1L;
    int i;
    for (i=0; i<N; i++)
    {
        const complex long double diff = a[i]-b[i];
        const long double diff_sq = creall(diff)*creall(diff) + cimagl(diff)*cimagl(diff);
        const long double ref_sq  = creall(a[i])*creall(a[i]) + cimagl(a[i])*cimagl(a[i]);
        if (diff_sq > epsilon*epsilon*ref_sq)
            return 1;
    }
    return 0;
}

void matrix_print(int matrix[rows][columns])
{
    //  print values  
    int row,column;
    for (row=0; row<rows; row++)
    {
        for(column=0; column<columns; column++)
            {
             printf("%d     ", matrix[row][column]);
            }
        printf("\n");
     }
}


void swap (complex long double *x, int i, int j)
{
    // Exchange elements i and j of x in place (helper for bitreversal).
    const complex long double held = x[j];
    x[j] = x[i];
    x[i] = held;
}

void bitreversal(complex long double *x, int N)
{
     // used at end of fft
     //
     // Permutes x[0..N-1] in place so that the element at index n ends up at
     // the index obtained by reversing the log2(N) bits of n.
     // N must be a power of two; for any other N (including N < 2, where
     // there is nothing to permute) the array is left untouched.
     //
     // The reversed index is maintained incrementally as a counter that is
     // incremented from its most significant bit downwards, so no lookup
     // table (and no stack storage proportional to N) is needed.
     int n,rev=0,bit;
     if (N<2 || (N&(N-1))!=0) return;
     // Indices 0 and N-1 are their own reversals, so only 1..N-2 are visited.
     for (n=1;n<(N-1);n++)
     {
         // Add one to rev in bit-reversed order: clear leading ones, then set
         // the first zero bit found.
         bit=N>>1;
         while (rev&bit)
         {
             rev^=bit;
             bit>>=1;
         }
         rev|=bit;
         // Swap each pair once, when visiting its smaller index.
         if (rev>n) swap(x,n,rev);
     }
}

void DIF (int BaseE, int N, complex long double *x, complex long double *twiddle, int tss)
{
    // recursion DIF fft
    //
    // Performs one radix-2 butterfly pass over the N elements starting at
    // x[BaseE], then recurses on each half.  Works in place.
    //   BaseE   - index of the first element of the sub-array to transform
    //   N       - length of the sub-array
    //   x       - data array, overwritten with the result
    //   twiddle - table of twiddle factors, read at index n*tss
    //   tss     - twiddle stride for this level; doubled at each recursion
    // The result is left in the order produced by the butterflies; main()
    // applies bitreversal() afterwards.
    int Nprime, BaseO, n;
    complex long double e,o;
    // N == 1 is the base case.  N <= 0 is rejected as well, because halving
    // it never reaches 1 and the recursion would not terminate.
    if (N<=1) return;
    Nprime=N/2;
    BaseO=BaseE+Nprime;
    for (n=0;n<Nprime;n++)
    {
        // Butterfly: sum goes to the lower half, twiddled difference to the
        // upper half.  Both inputs are read before either is overwritten.
        e=x[BaseE+n]+x[BaseO+n];
        o=(x[BaseE+n]-x[BaseO+n])*twiddle[n*tss];
        x[BaseE+n] =e;
        x[BaseO+n] =o;
    }
    tss=2*tss;
    DIF(BaseE,Nprime,x,twiddle,tss);
    DIF(BaseO,Nprime,x,twiddle,tss);
}

void DIFL(int N, complex long double *x2, complex long double *twiddle,int p)
{
    // loop DIF fft
    //
    // Runs p butterfly stages over x2 in place.  Each stage treats the data
    // as consecutive sub-blocks of the current length; after a stage the
    // sub-block length is halved while the sub-block count and the twiddle
    // stride are doubled.
    int blocks=1;   // sub-blocks processed in the current stage
    int stride=1;   // step through the twiddle table in the current stage
    int stage;
    for (stage=0; stage<p; ++stage, blocks*=2, N/=2, stride*=2)
    {
        const int half=N/2;
        int blk;
        for (blk=0; blk<blocks; ++blk)
        {
            const int lo=blk*N;      // first index of the lower half
            const int hi=lo+half;    // first index of the upper half
            int j;
            for (j=0; j<half; ++j)
            {
                // Read both inputs before overwriting either of them.
                const complex long double upper=x2[lo+j];
                const complex long double lower=x2[hi+j];
                x2[lo+j]=upper+lower;
                x2[hi+j]=(upper-lower)*twiddle[j*stride];
            }
        }
    }
}

double dAvg(int array[], size_t length){
    // Arithmetic mean of the first length entries of array, as a double.
    double total = 0;
    const int *cursor = array;
    size_t remaining = length;
    while (remaining > 0) {
        total += *cursor;
        ++cursor;
        --remaining;
    }
    return total / (double) length;
}


/* Nanoseconds elapsed between two clock_gettime() readings. */
static long int elapsed_ns(const struct timespec *start, const struct timespec *end)
{
    return BILLION * (end->tv_sec - start->tv_sec) + end->tv_nsec - start->tv_nsec;
}

/* Write one "length  normalized-time" line per transform length to f. */
static void write_results(FILE *f, const double mean[], int count)
{
    for (int i = 0; i < count; i++)
        fprintf(f, "  %3d   %llu \n", 1 << (i + 3), (long long unsigned int)mean[i]);
}

/*
 * Benchmark driver.
 *
 * For each of `rows` transform lengths N = 8, 16, ... and `columns`
 * repetitions, fills x with random complex samples and times, with
 * CLOCK_MONOTONIC:
 *   times   - recursive DIF() followed by bitreversal()
 *   times2  - direct O(N^2) DFT
 *   times3  - fftw_execute() on an FFTW_ESTIMATE plan
 *   times4  - iterative DIFL() followed by bitreversal()
 * The per-length means are divided by the value returned by get_op_time()
 * and written to "file.txt" in the order: recursion, loop, direct DFT, FFTW3.
 *
 * Returns 0 on success and EXIT_FAILURE if the output file cannot be
 * opened/closed or a buffer or plan cannot be allocated.
 */
int main()
{
   const int uN=rows;        /* number of transform lengths */
   const int it=columns;     /* timed repetitions per length */
   /* Keep pi in double so the twiddle factors are not truncated to float. */
   const double PI=acos(-1.0);
   double mean[rows],mean2[rows],mean3[rows],mean4[rows];
   /* Elapsed times are stored as int because dAvg() takes an int array. */
   int times[rows][columns], times2[rows][columns], times3[rows][columns], times4[rows][columns];
   struct timespec start,end;
   long int stime;

   FILE *f=fopen("file.txt","w");
   if (f==NULL)
   {
      perror("file.txt");
      return EXIT_FAILURE;
   }

   for (int u=0;u<uN;u++)
   {
      /* N = 2^(u+3); a radix-2 FFT of length N needs log2(N) = u+3 stages. */
      const int stages=u+3;
      const int N=1<<stages;

      /* Buffers depend only on N, so allocate them once per length and
         release them all at the end of this pass. */
      complex long double *twiddle=malloc(sizeof *twiddle * N);
      complex long double *x=malloc(sizeof *x * N);
      complex long double *x2=malloc(sizeof *x2 * N);
      complex long double *dft=malloc(sizeof *dft * N);
      fftw_complex *in=fftw_malloc(sizeof *in * N);
      fftw_complex *out=fftw_malloc(sizeof *out * N);
      fftw_plan p=NULL;
      if (twiddle!=NULL && x!=NULL && x2!=NULL && dft!=NULL && in!=NULL && out!=NULL)
         p=fftw_plan_dft_1d(N, in, out, FFTW_FORWARD, FFTW_ESTIMATE);
      if (p==NULL)
      {
         /* The process is about to end, so the partial allocations are left
            for the operating system to reclaim. */
         fprintf(stderr,"allocation or planning failed for N=%d\n",N);
         fclose(f);
         return EXIT_FAILURE;
      }

      for (int i=0;i<N;i++) twiddle[i]=cexp(-2*PI*I*i/N);

      for (int c=0;c<it;c++)
      {
          /* Fresh random input for this repetition, copied for every method. */
          for (int i=0;i<N;i++)
          {
              x[i]=randf()+randf()*I;
              x2[i]=x[i];
              in[i][0]=creal(x[i]);
              in[i][1]=cimag(x[i]);
              dft[i]=0.0+0.0*I;
          }

          /* non-FFT DFT */
          clock_gettime(CLOCK_MONOTONIC,&start);
          for (int k=0;k<N;k++)
             for (int n=0;n<N;n++)
                 dft[k]=dft[k]+x[n]*cexp(-2*PI*I*n*k/N);
          clock_gettime(CLOCK_MONOTONIC,&end);
          times2[u][c]=(int)elapsed_ns(&start,&end);

          /* recursive DIF FFT */
          clock_gettime(CLOCK_MONOTONIC,&start);
          DIF(0,N,x,twiddle,1);
          bitreversal(x,N);
          clock_gettime(CLOCK_MONOTONIC,&end);
          times[u][c]=(int)elapsed_ns(&start,&end);

          /* loop DIF FFT: must run all log2(N) stages, not u of them */
          clock_gettime(CLOCK_MONOTONIC,&start);
          DIFL(N,x2,twiddle,stages);
          bitreversal(x2,N);
          clock_gettime(CLOCK_MONOTONIC,&end);
          times4[u][c]=(int)elapsed_ns(&start,&end);

          /* FFTW3 */
          clock_gettime(CLOCK_MONOTONIC,&start);
          fftw_execute(p);
          clock_gettime(CLOCK_MONOTONIC,&end);
          times3[u][c]=(int)elapsed_ns(&start,&end);
      }

      fftw_destroy_plan(p);
      fftw_free(in);
      fftw_free(out);
      free(dft);
      free(x2);
      free(x);
      free(twiddle);
   }

   stime=get_op_time();
   /* The means are divided by stime; a non-positive value would produce an
      infinite or negative quotient whose conversion to an unsigned integer
      is undefined, so fall back to a divisor of 1 (plain nanoseconds). */
   if (stime<=0) stime=1;
   for (int z=0;z<rows;z++)
   {
        mean[z] =dAvg(times[z], columns)/stime;
        mean2[z]=dAvg(times2[z], columns)/stime;
        mean3[z]=dAvg(times3[z], columns)/stime;
        mean4[z]=dAvg(times4[z], columns)/stime;
   }

   // print time results to file
   //recursion time results
   write_results(f,mean,uN);
   //loop fft time results
   write_results(f,mean4,uN);
   //non-fft dft time results
   write_results(f,mean2,uN);
   //fftw3 time results
   write_results(f,mean3,uN);

   if (fclose(f)!=0)
   {
      perror("file.txt");
      return EXIT_FAILURE;
   }
   return 0;
}