fork download
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <stdint.h>
  4. #define _USE_MATH_DEFINES
  5. #include <math.h>
  6. #include <assert.h>
  7. #include <stdarg.h>
  8. #include <vector>
  9. #include <utility>
  10. #include <numeric>
  11.  
  12. // [CONF] 中間結果表示
  13. #define SHOW_ITM
  14.  
  15. #ifdef _WIN32
  16. int fopen_s_wrp(FILE** const p_fp, const char* const fpath, const char* const mode)
  17. {
  18. return (int)fopen_s(p_fp, fpath, mode);
  19. }
  20. int sprintf_s_wrp(char buf[], const size_t bufsz, const char* const fmt, ...)
  21. {
  22. va_list ap;
  23. int nRet;
  24.  
  25. va_start(ap, fmt);
  26. nRet = vsprintf_s(buf, bufsz, fmt, ap);
  27. va_end(ap);
  28.  
  29. return nRet;
  30. }
  31. #else
  32. int fopen_s_wrp(FILE** const p_fp, const char* const fpath, const char* const mode)
  33. {
  34. *p_fp = fopen(fpath, mode);
  35. return (*p_fp == NULL);
  36. }
  37. int sprintf_s_wrp(char buf[], const size_t bufsz, const char* const fmt, ...)
  38. {
  39. va_list ap;
  40. int nRet;
  41.  
  42. va_start(ap, fmt);
  43. nRet = vsprintf(buf, fmt, ap);
  44. va_end(ap);
  45.  
  46. return nRet;
  47. }
  48. #define _countof(arr) (sizeof(arr) / sizeof(arr[0]))
  49. #endif
  50.  
  51. namespace priroot
  52. {
  53. typedef int64_t pow_t;
  54.  
  55. /// 1の原始n乗根を表す。
  56. /// nの剰余計算を高速化するため、nは2^pに制限する。
  57. class nth_root
  58. {
  59. const pow_t t90;
  60. const pow_t m360;
  61. std::vector<double> m_sin;
  62.  
  63. public:
  64. nth_root(const int p)
  65. : t90(((pow_t)1 << (p - 2))), m360((t90 << 2) - 1), m_sin((size_t)t90 + 1)
  66. {
  67. const double k = 0.5 * M_PI;
  68. for (pow_t i = 1; i < t90; i++) {
  69. m_sin[i] = sin((k * (double)i) / (double)t90);
  70. }
  71. m_sin[0] = 0.0;
  72. m_sin[t90] = 1.0;
  73. }
  74.  
  75. pow_t n() const { return 4 * t90; }
  76.  
  77. /// 1の原始n乗根のk乗を取得する。
  78. void pow(const pow_t k_, double& re, double& im) const {
  79. pow_t k = (k_ < 0) ? 4 * t90 - k_ : k_;
  80. k &= m360; // nの剰余
  81. assert(0 <= k && k < 4 * t90);
  82. if (k <= t90) {
  83. re = m_sin[t90 - k];
  84. im = m_sin[k];
  85. }
  86. else if (k <= 2 * t90) {
  87. re = -m_sin[k - t90]; // = t90 - (2 * t90 - k)
  88. im = m_sin[2 * t90 - k];
  89. }
  90. else if (k <= 3 * t90) {
  91. re = -m_sin[3 * t90 - k]; // = t90 - (k - 2 * t90)
  92. im = -m_sin[k - 2 * t90];
  93. }
  94. else {
  95. re = m_sin[k - 3 * t90]; // = t90 - (4 * t90 - k)
  96. im = -m_sin[4 * t90 - k];
  97. }
  98. }
  99.  
  100. };
  101.  
  102. } // namespace priroot
  103.  
  104. namespace dftn
  105. {
  106. using namespace priroot;
  107.  
  108. /// フーリエ変換
  109. /// DFTの原理式
  110. /// X(k) = Σ[j=0..t360-1]{ x(j) * w^(j*k) } (k=0..t360-1)
  111. /// に従いDFTを計算し、X(k)の実部をfbgn[k..]、虚部をfbgn[(k + t360)..]に格納する。
  112. void fouri(
  113. std::vector<double>::const_iterator xbgn,
  114. std::vector<double>::const_iterator xend,
  115. std::vector<double>::iterator fbgn,
  116. std::vector<double>::iterator fend,
  117. const int p,
  118. const nth_root& w)
  119. {
  120. const pow_t t360 = (pow_t)1 << p;
  121. assert(xend - xbgn == (ptrdiff_t)(2 * t360));
  122. assert(fend - fbgn == (ptrdiff_t)(2 * t360));
  123. assert(w.n() == t360);
  124.  
  125. for (pow_t k = 0; k < t360; k++) {
  126. double sum_re = 0.0;
  127. double sum_im = 0.0;
  128. for (pow_t j = 0; j < t360; j++) {
  129. const double a = xbgn[j];
  130. const double b = xbgn[j + t360];
  131.  
  132. double wjk_re, wjk_im;
  133. w.pow(j * k, wjk_re, wjk_im);
  134. sum_re += a * wjk_re - b * wjk_im;
  135. sum_im += a * wjk_im + b * wjk_re;
  136. }
  137. fbgn[k] = sum_re;
  138. fbgn[k + t360] = sum_im;
  139. }
  140. }
  141.  
  142. /// フーリエ逆変換
  143. /// IDFTの原理式
  144. /// x(j) = Σ[k=0..t360-1]{ X(k) * w^(-j*k) } (j=0..t360-1)
  145. /// に従いIDFTを計算し、x(k)の実部をxbgn[k..]、虚部をxbgn[(k + t360)..]に格納する。
  146. void fouriinv(
  147. std::vector<double>::const_iterator fbgn,
  148. std::vector<double>::const_iterator fend,
  149. std::vector<double>::iterator xbgn,
  150. std::vector<double>::iterator xend,
  151. const int p,
  152. const nth_root& w)
  153. {
  154. const pow_t t360 = (pow_t)1 << p;
  155. assert(fend - fbgn == (ptrdiff_t)(2 * t360));
  156. assert(xend - xbgn == (ptrdiff_t)(2 * t360));
  157. assert(w.n() == t360);
  158.  
  159. for (pow_t j = 0; j < t360; j++) {
  160. double sum_re = 0.0;
  161. double sum_im = 0.0;
  162. for (pow_t k = 0; k < t360; k++) {
  163. const double a = fbgn[k];
  164. const double b = fbgn[k + t360];
  165.  
  166. double wjk_re, wjk_im;
  167. w.pow(j * k, wjk_re, wjk_im);
  168. sum_re += a * wjk_re + b * wjk_im;
  169. sum_im += b * wjk_re - a * wjk_im;
  170. }
  171. xbgn[j] = sum_re / (double)t360;
  172. xbgn[j + t360] = sum_im / (double)t360;
  173. }
  174. }
  175.  
  176. } // namespace dftn
  177.  
  178. namespace dft4n
  179. {
  180. using namespace priroot;
  181.  
  182. /// 4n乗根によるフーリエ変換
  183. /// x(j) (j=0..(2*t360)-1)から
  184. /// x'(j) = (x(j) + i * x(j+t360)) * u^(j/4) (j=0..t360-1)
  185. /// を構成し、x'(j)をDFTする。ここで、uは1の原始t360乗根。
  186. void fouri4n(
  187. std::vector<double>::iterator xbgn,
  188. std::vector<double>::iterator xend,
  189. std::vector<double>::iterator fbgn,
  190. std::vector<double>::iterator fend,
  191. const int p,
  192. const nth_root& w)
  193. {
  194. const pow_t t360 = (pow_t)1 << p;
  195. assert(xend - xbgn == (ptrdiff_t)(2 * t360));
  196. assert(fend - fbgn == (ptrdiff_t)(2 * t360));
  197. assert(w.n() == (t360 << 2));
  198.  
  199. // { x'(j) }構成
  200. // { x(j) + i * x(j+t360) } にu^(j/4)を掛ける。
  201. for (pow_t j = 0; j < t360; j++) {
  202. const double a = xbgn[j];
  203. const double b = xbgn[j + t360];
  204.  
  205. // u = w^4としてu^(j/4) = w^(j)
  206. double wj_re, wj_im;
  207. w.pow(j, wj_re, wj_im);
  208. xbgn[j] = a * wj_re - b * wj_im;
  209. xbgn[j + t360] = a * wj_im + b * wj_re;
  210. }
  211.  
  212. // { x'(j) }のフーリエ変換
  213. std::vector<double> fr((size_t)t360);
  214. for (pow_t k = 0; k < t360; k++) {
  215. double sum_re = 0.0;
  216. double sum_im = 0.0;
  217. for (pow_t j = 0; j < t360; j++) {
  218. const double a = xbgn[j];
  219. const double b = xbgn[j + t360];
  220.  
  221. // u = w^4としてu^(j*k)を掛けて足す
  222. // u^(j*k) = w^(4*j*k)
  223. double wjk_re, wjk_im;
  224. w.pow(4 * j * k, wjk_re, wjk_im);
  225. sum_re += a * wjk_re - b * wjk_im;
  226. sum_im += a * wjk_im + b * wjk_re;
  227. }
  228. fbgn[k] = sum_re;
  229. fbgn[k + t360] = sum_im;
  230. }
  231. }
  232.  
  233. /// 4n乗根による逆フーリエ変換
  234. void fouri4ninv(
  235. std::vector<double>::const_iterator fbgn,
  236. std::vector<double>::const_iterator fend,
  237. std::vector<double>::iterator xbgn,
  238. std::vector<double>::iterator xend,
  239. const int p,
  240. const nth_root& w)
  241. {
  242. const pow_t t360 = (pow_t)1 << p;
  243. assert(fend - fbgn == (ptrdiff_t)(2 * t360));
  244. assert(xend - xbgn == (ptrdiff_t)(2 * t360));
  245. assert(w.n() == (t360 << 2));
  246.  
  247. // { x'(j) }の逆フーリエ変換
  248. for (pow_t j = 0; j < t360; j++) {
  249. double sum_re = 0.0;
  250. double sum_im = 0.0;
  251. for (pow_t k = 0; k < t360; k++) {
  252. const double a = fbgn[k];
  253. const double b = fbgn[k + t360];
  254.  
  255. // u = w^4としてu^(j*(-k))を掛けて足す
  256. // u^(j*(-k)) = w^(-4*j*k)
  257. double wjk_re, wjk_im;
  258. w.pow(4 * j * k, wjk_re, wjk_im);
  259. sum_re += a * wjk_re + b * wjk_im;
  260. sum_im += b * wjk_re - a * wjk_im;
  261. }
  262. xbgn[j] = sum_re / (double)t360;
  263. xbgn[j + t360] = sum_im / (double)t360;
  264. }
  265.  
  266. // { x(j) }復元
  267. for (pow_t j = 0; j < t360; j++) {
  268. const double a = xbgn[j];
  269. const double b = xbgn[j + t360];
  270.  
  271. // u = w^4としてu^(-j/4) = w^(-j)を掛ける
  272. double wj_re, wj_im;
  273. w.pow(j, wj_re, wj_im);
  274. xbgn[j] = a * wj_re + b * wj_im;
  275. xbgn[j + t360] = b * wj_re - a * wj_im;
  276. }
  277. }
  278.  
  279. } // namespace dft4n
  280.  
  281. namespace utest
  282. {
  283. using namespace priroot;
  284.  
  285. void nth_rootTest()
  286. {
  287. const int p = 4;
  288.  
  289. nth_root w(p);
  290. const pow_t t360 = (pow_t)1 << p;
  291. for (pow_t k = 0; k <= 2 * t360; k++) {
  292. double re, im;
  293. w.pow(k, re, im);
  294. const double expn = (k % t360 == 0) ? 0.0 : (2 * M_PI) / atan2(im, re);
  295. printf("rt.pow(%2lld) = % f + i * % f (w^% 10f=1)\n", k, re, im, expn);
  296. }
  297. }
  298.  
  299. void show_complex_seq(const std::vector<double>& s)
  300. {
  301. const size_t hsz = s.size() / 2;
  302. assert(2 * hsz == s.size());
  303.  
  304. printf("Re:");
  305. for (size_t i = 0; i < hsz; i++) {
  306. printf(" % 12f", s[i]);
  307. }
  308. printf("\n");
  309. printf("Im:");
  310. for (size_t i = 0; i < hsz; i++) {
  311. printf(" % 12f", s[i + hsz]);
  312. }
  313. printf("\n");
  314. }
  315.  
  316. void fouriTest()
  317. {
  318. const int p = 4;
  319.  
  320. const pow_t t360 = (pow_t)1 << p;
  321. nth_root w(p);
  322.  
  323. // 入力列準備
  324. std::vector<double> x((size_t)(2 * t360));
  325. std::iota(x.begin(), x.begin() + t360, 0.0);
  326.  
  327. printf("input seq:\n");
  328. show_complex_seq(x);
  329. printf("\n");
  330.  
  331. // フーリエ変換
  332. std::vector<double> f((size_t)(2 * t360));
  333. dftn::fouri(x.begin(), x.end(), f.begin(), f.end(), p, w);
  334. x.clear();
  335.  
  336. printf("dftn::fouri() result:\n");
  337. show_complex_seq(f);
  338. printf("\n");
  339.  
  340. // 逆フーリエ変換
  341. std::vector<double> r((size_t)(2 * t360));
  342. dftn::fouriinv(f.begin(), f.end(), r.begin(), r.end(), p, w);
  343. f.clear();
  344.  
  345. printf("dftn::fouriinv() result:\n");
  346. show_complex_seq(r);
  347. printf("\n");
  348.  
  349. // 照合
  350. const size_t sz = r.size();
  351. if (sz != (size_t)(2 * t360)) { abort(); }
  352. for (size_t i = 0; i < sz / 2; i++) {
  353. const double actv = r[i];
  354. const double expv = (double)i;
  355. const double err = std::abs(actv - expv);
  356. if (err >= 0.1) { abort(); }
  357. }
  358. for (size_t i = sz / 2; i < sz; i++) {
  359. const double expv = 0.0;
  360. const double actv = r[i];
  361. const double err = std::abs(actv - expv);
  362. if (err >= 0.1) { abort(); }
  363. }
  364. }
  365.  
  366. void fouri4nTest()
  367. {
  368. const int p = 4;
  369.  
  370. const pow_t t360 = (pow_t)1 << (p - 1);
  371. nth_root w(p + 1);
  372.  
  373. // 入力列準備
  374. std::vector<double> x((size_t)2 * t360);
  375. std::iota(x.begin(), x.end(), 0.0);
  376.  
  377. printf("input seq:\n");
  378. show_complex_seq(x);
  379. printf("\n");
  380.  
  381. // フーリエ変換
  382. std::vector<double> f((size_t)2 * t360);
  383. dft4n::fouri4n(x.begin(), x.end(), f.begin(), f.end(), p - 1, w);
  384. x.clear();
  385.  
  386. printf("dft4n::fouri4n() result:\n");
  387. show_complex_seq(f);
  388. printf("\n");
  389.  
  390. // 逆フーリエ変換
  391. std::vector<double> r((size_t)2 * t360);
  392. dft4n::fouri4ninv(f.begin(), f.end(), r.begin(), r.end(), p - 1, w);
  393. f.clear();
  394.  
  395. printf("dft4n::fouri4ninv() result:\n");
  396. show_complex_seq(r);
  397. printf("\n");
  398.  
  399. // 照合
  400. const size_t sz = r.size();
  401. if (sz != (size_t)(2 * t360)) { abort(); }
  402. for (size_t i = 0; i < sz; i++) {
  403. const double actv = r[i];
  404. const double expv = (double)i;
  405. const double err = std::abs(actv - expv);
  406. if (err >= 0.1) { abort(); }
  407. }
  408. }
  409.  
  410. } // namespace utest
  411.  
  412. int main()
  413. {
  414. #ifdef SHOW_ITM
  415. utest::nth_rootTest();
  416. utest::fouriTest();
  417. utest::fouri4nTest();
  418. #endif
  419. }
  420.  
Success #stdin #stdout 0.01s 5272KB
stdin
Standard input is empty
stdout
rt.pow( 0) =  1.000000 + i *  0.000000 (w^  0.000000=1)
rt.pow( 1) =  0.923880 + i *  0.382683 (w^ 16.000000=1)
rt.pow( 2) =  0.707107 + i *  0.707107 (w^  8.000000=1)
rt.pow( 3) =  0.382683 + i *  0.923880 (w^  5.333333=1)
rt.pow( 4) =  0.000000 + i *  1.000000 (w^  4.000000=1)
rt.pow( 5) = -0.382683 + i *  0.923880 (w^  3.200000=1)
rt.pow( 6) = -0.707107 + i *  0.707107 (w^  2.666667=1)
rt.pow( 7) = -0.923880 + i *  0.382683 (w^  2.285714=1)
rt.pow( 8) = -1.000000 + i *  0.000000 (w^  2.000000=1)
rt.pow( 9) = -0.923880 + i * -0.382683 (w^ -2.285714=1)
rt.pow(10) = -0.707107 + i * -0.707107 (w^ -2.666667=1)
rt.pow(11) = -0.382683 + i * -0.923880 (w^ -3.200000=1)
rt.pow(12) = -0.000000 + i * -1.000000 (w^ -4.000000=1)
rt.pow(13) =  0.382683 + i * -0.923880 (w^ -5.333333=1)
rt.pow(14) =  0.707107 + i * -0.707107 (w^ -8.000000=1)
rt.pow(15) =  0.923880 + i * -0.382683 (w^-16.000000=1)
rt.pow(16) =  1.000000 + i *  0.000000 (w^  0.000000=1)
rt.pow(17) =  0.923880 + i *  0.382683 (w^ 16.000000=1)
rt.pow(18) =  0.707107 + i *  0.707107 (w^  8.000000=1)
rt.pow(19) =  0.382683 + i *  0.923880 (w^  5.333333=1)
rt.pow(20) =  0.000000 + i *  1.000000 (w^  4.000000=1)
rt.pow(21) = -0.382683 + i *  0.923880 (w^  3.200000=1)
rt.pow(22) = -0.707107 + i *  0.707107 (w^  2.666667=1)
rt.pow(23) = -0.923880 + i *  0.382683 (w^  2.285714=1)
rt.pow(24) = -1.000000 + i *  0.000000 (w^  2.000000=1)
rt.pow(25) = -0.923880 + i * -0.382683 (w^ -2.285714=1)
rt.pow(26) = -0.707107 + i * -0.707107 (w^ -2.666667=1)
rt.pow(27) = -0.382683 + i * -0.923880 (w^ -3.200000=1)
rt.pow(28) = -0.000000 + i * -1.000000 (w^ -4.000000=1)
rt.pow(29) =  0.382683 + i * -0.923880 (w^ -5.333333=1)
rt.pow(30) =  0.707107 + i * -0.707107 (w^ -8.000000=1)
rt.pow(31) =  0.923880 + i * -0.382683 (w^-16.000000=1)
rt.pow(32) =  1.000000 + i *  0.000000 (w^  0.000000=1)
input seq:
Re:     0.000000     1.000000     2.000000     3.000000     4.000000     5.000000     6.000000     7.000000     8.000000     9.000000    10.000000    11.000000    12.000000    13.000000    14.000000    15.000000
Im:     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000

dftn::fouri() result:
Re:   120.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000    -8.000000
Im:     0.000000   -40.218716   -19.313708   -11.972846    -8.000000    -5.345429    -3.313708    -1.591299     0.000000     1.591299     3.313708     5.345429     8.000000    11.972846    19.313708    40.218716

dftn::fouriinv() result:
Re:    -0.000000     1.000000     2.000000     3.000000     4.000000     5.000000     6.000000     7.000000     8.000000     9.000000    10.000000    11.000000    12.000000    13.000000    14.000000    15.000000
Im:     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000     0.000000    -0.000000     0.000000     0.000000     0.000000     0.000000

input seq:
Re:     0.000000     1.000000     2.000000     3.000000     4.000000     5.000000     6.000000     7.000000
Im:     8.000000     9.000000    10.000000    11.000000    12.000000    13.000000    14.000000    15.000000

dft4n::fouri4n() result:
Re:   -44.043434     5.749926     7.163243     7.453990     7.495150     7.357149     6.757625     2.066352
Im:    81.225363    14.966947     6.565430     2.426773    -0.787931    -4.276089    -9.748028   -26.372466

dft4n::fouri4ninv() result:
Re:     0.000000     1.000000     2.000000     3.000000     4.000000     5.000000     6.000000     7.000000
Im:     8.000000     9.000000    10.000000    11.000000    12.000000    13.000000    14.000000    15.000000