source of highlighter
plain | download
    #include <byteswap.h>

    #include <cmath>
    #include <complex>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <limits>
    #include <stdexcept>
    #include <utility>
    #include <vector>
    7 
    /**
     * Return the base-2 logarithm of n when n is an exact power of two.
     *
     * @param n  value to inspect.
     * @return   k such that n == 2^k, or 0 when n is zero or not a power of
     *           two.  n == 1 (2^0) also yields 0, so a caller cannot tell it
     *           apart from the "not a power of two" case.
     */
    static unsigned power_of_2 (size_t n)
    {
        // A power of two has exactly one bit set; n & (n - 1) clears the
        // lowest set bit, so the result is zero only for powers of two.
        if (n == 0 || (n & (n - 1)) != 0) {
            return 0;
        }

        // Count how far the single set bit is from bit 0.  A plain shift
        // loop is correct for every width of size_t, whereas a "n % 37"
        // lookup table only distinguishes the 32 powers of two below 2^32.
        unsigned exponent = 0;
        while ((n >>= 1) != 0) {
            ++exponent;
        }
        return exponent;
    }
   22 
   /**
    * Reverse the order of the lowest ln bits of x.
    *
    * Bit 0 of x becomes bit ln-1 of the result, bit 1 becomes bit ln-2, and
    * so on.  Bits of x at position ln and above do not contribute.
    *
    * @param x   value whose low bits are reversed.
    * @param ln  number of low-order bits to reverse.  0 yields 0; values
    *            larger than the bit width of unsigned are clamped to it.
    * @return    the ln-bit reversal of x, right-aligned.
    */
   static unsigned revbin (unsigned x, unsigned ln)
   {
       const unsigned width = std::numeric_limits<unsigned>::digits;

       // Clamp so that the loop never shifts meaningful bits out of the top.
       if (ln > width) {
           ln = width;
       }

       // Peel bits off the bottom of x and push them onto the bottom of the
       // result; after ln steps the first bit taken sits at position ln-1.
       // This is portable (no byte-order or word-size assumptions) and is
       // well defined for ln == 0, where a "shift by width - ln" is not.
       unsigned reversed = 0;
       for (unsigned bit = 0; bit < ln; ++bit) {
           reversed = (reversed << 1) | (x & 1u);
           x >>= 1;
       }

       return reversed;
   }
   53 
   54 template<typename T>
   55 class fft
   56 {
   57         unsigned n;
   58         unsigned * permutecache;
   59         std::complex<T> * power;
   60 
   61         inline void permute (std::complex<T> *xs) const
   62         {
   63                 for (unsigned i = 0; i < n; ++i) {
   64                         if (i < permutecache[i])
   65                                 std::swap (xs[i], xs[permutecache[i]]);
   66                 }
   67         }
   68 
   69         void ward (std::complex<T> *xs, bool backward) const
   70         {
   71                 int back = backward ? -1 : 1;
   72 
   73                 for (size_t b = 1; b < n; b *= 2) {
   74                         for (size_t k = 0; k < b; ++k) {
   75                                 auto p = power [(back * n * k / (2 * b)) % n];
   76                                 for (size_t j = 0; j < n; j += 2 * b) {
   77                                         auto t = p * xs[j + k + b];
   78                                         xs[j + k + b] = xs[j + k] - t;
   79                                         xs[j + k] += t;
   80                                 }
   81                         }
   82                 }
   83         }
   84 
   85 public:
   86         fft (unsigned n) : n(n)
   87         {
   88                 unsigned p = power_of_2 (n);
   89 
   90                 if (!p) {
   91                         throw std::invalid_argument ("argument must be a power of 2");
   92                 }
   93 
   94                 permutecache = new unsigned[n];
   95                 power = new std::complex<T>[n];
   96 
   97                 std::complex<T> x (0, -2.0*M_PI/n);
   98 
   99                 for (unsigned i = 0, j = 0; i < n; ++i) {
  100                         permutecache[i] = revbin (i, p);
  101                         power[i] = std::exp (x*static_cast<T> (i));
  102                 }
  103         }
  104 
  105         ~fft () {
  106                 delete[] power;
  107                 delete[] permutecache;
  108         }
  109 
  110         void forward (std::complex<T> *xs) const
  111         {
  112                 permute (xs);
  113                 ward (xs, false);
  114         }
  115 
  116         void backward (std::complex<T> *xs) const
  117         {
  118                 permute (xs);
  119                 ward (xs, true);
  120 
  121                 for (size_t i = 0; i < n; ++i) {
  122                         xs[i] /= n;
  123                 }
  124         }
  125 };