fork(2) download
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. """Polynomial multiplication using a fast Fourier transform (FFT) algorithm.
  5.  
  6. Author: Fabian Reiter
  7. Tested on Ubuntu 12.04 with Python 2.7.3 and SymPy 0.7.1.rc1
  8. """
  9.  
  10. from __future__ import print_function
  11. from sympy import exp, I, pi, sympify
  12. from math import log, ceil
  13.  
  14. # If verbose = True, vprint is the print function, otherwise it does nothing.
  15. verbose = True
  16. vprint = (print) if verbose else (lambda *a, **k: None)
  17.  
  18. # -----------------------------------------------------------------------------
  19. def normal_form(d):
  20. """Return an array with the same elements as array d (i.e. complex numbers)
  21. but make sure that their representation is in the form a + i*b (or close).
  22. """
  23. return [sympify(d_i).expand(complex=True) for d_i in d]
  24.  
  25. # -----------------------------------------------------------------------------
  26. def fft(n, a, depth=0):
  27. """Compute DFT_n(a) using FFT and print the tree of recursive calls.
  28. Arguments:
  29. n -- index of the DFT (at least length of a)
  30. a -- array with the coefficients of the polynomial
  31. depth -- recursion depth (needed for pretty printing, default: 0)
  32. """
  33. # Check input.
  34. if n < len(a):
  35. raise ValueError("DFT index cannot be smaller than array length.")
  36. if log(n,2) != ceil(log(n,2)):
  37. raise ValueError("DFT index must be a power of 2.")
  38. # Print function call.
  39. vprint(" ", "│ " * depth, "┌─FFT(", n, ", ", a, ")", sep='')
  40. # Base case (constant polynomial):
  41. if len(a) == 1:
  42. d = [a[0] for i in range(n)] # d: array of length n filled with a[0]
  43. # Nontrivial case:
  44. else:
  45. a0 = [a[i] for i in range(0, len(a), 2)] # a0 = [a[0], a[2], ...]
  46. a1 = [a[i] for i in range(1, len(a), 2)] # a1 = [a[1], a[3], ...]
  47. d0 = fft(n/2, a0, depth+1) # d0 = DFT_{n/2}(a0)
  48. d1 = fft(n/2, a1, depth+1) # d1 = DFT_{n/2}(a1)
  49. w_n = exp(I*2*pi/n) # represented symbolically using SymPy
  50. w = 1
  51. d = [0 for i in range(n)] # d: array of length n (initialized with 0)
  52. for k in range(n/2):
  53. x = w * d1[k]
  54. d[k] = d0[k] + x # p(ωₙᵏ) = p₀(ω½ₙᵏ) + ωₙᵏ·p₁(ω½ₙᵏ)
  55. d[k+n/2] = d0[k] - x # p(ωₙᵏ⁺½ⁿ) = p₀(ω½ₙᵏ) − ωₙᵏ·p₁(ω½ₙᵏ)
  56. w *= w_n
  57. d = normal_form(d)
  58. # Print function return.
  59. vprint(" ", "│ " * depth, "└‣", d, sep='')
  60. return d
  61.  
  62. # -----------------------------------------------------------------------------
  63. def polymult(a, b=None):
  64. """Multiply two polynomials using FFT and print intermediate steps.
  65. Arguments:
  66. a -- array with the coefficients of the first polynomial
  67. b -- array with the coefficients of the second polynomial (default: None)
  68. If only one polynomial is given, it is multiplied with itself.
  69. """
  70. vprint("\n", "─" * 35, "\n", " POLYNOMIAL MULTIPLICATION VIA FFT", sep='')
  71. vprint("─" * 35, "\n", " a = ", a, "\n", " b = ", b if b else "a", sep='')
  72. # Determine n, the minimum index for the DFT.
  73. n = (len(a) + len(b) - 1) if b else (2 * len(a) - 1)
  74. n = 2**int(ceil(log(n,2))) # Round up to next power of 2.
  75. # 1. Evaluation:
  76. vprint("\n", "1. Evaluation:", "\n ", "─" * 13, sep='')
  77. d_a = fft(n, a) # d_a = DFT_n(a)
  78. vprint("\n" if b else "", end='')
  79. d_b = fft(n, b) if b else d_a # d_b = DFT_n(b)
  80. # 2. Point-wise multiplication:
  81. vprint("\n", "2. Point-wise multiplication:", "\n ", "─" * 28, sep='')
  82. d = [d_a[i] * d_b[i] for i in range(n)] # d = d_a * d_b
  83. d = normal_form(d)
  84. vprint(d)
  85. # 3. Interpolation:
  86. vprint("\n", "3. Interpolation:", "\n ", "─" * 16, sep='')
  87. f = fft(n, d) # f = DFT_n(d)
  88. # Result: Reorder and divide by n.
  89. r = [f[i]/n for i in range(1) + range(n-1, 0, -1)] # r = f[0,n-1,..,1] / n
  90. vprint("\n", "Result:", "\n", " ", r, "\n", sep='')
  91. return r
  92.  
  93. # -----------------------------------------------------------------------------
  94. # If this script is called directly, compute square of p(x) = 2x³ − x² + 4x + 1
  95. if __name__=="__main__":
  96. n=3
  97. a=[1,2,3]
  98. y=[0]
  99. for i in range(1,n+1):
  100. y[i]=y[i-1]+a[i-1]
  101. p1=[]
  102. p2=[]
  103. for i in range(1,y[n]+1):
  104. p1[i]=0
  105. for i in range(1,y[n]+y[n]+1):
  106. p2[i]=0
  107. for i in range(0,y[n]+1):
  108. p1[y[i]]=1
  109. p2[y[n]+y[i]]=1
  110. polymult(p1,p2)
  111.  
Runtime error #stdin #stdout #stderr 0.15s 10224KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
  File "prog.py", line 97
    a=[1,2,3]
            ^
TabError: inconsistent use of tabs and spaces in indentation