Skip to content

Commit 41bb45a

Browse files
authored
Merge pull request #276 from antonte/cooley_tukey_cpp
Minor clean up for cooley tukey in C++
2 parents 6fa1761 + fe358f8 commit 41bb45a

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

contents/cooley_tukey/code/c++/fft.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@ using std::swap;
1717

1818
using std::size_t;
1919

20-
using c64 = std::complex<double>;
21-
template <typename T>
22-
constexpr T pi() {
23-
return 3.14159265358979323846264338327950288419716;
20+
using complex = std::complex<double>;
21+
static const double pi = 3.14159265358979323846264338327950288419716;
22+
23+
template <typename Iter>
24+
void dft(Iter X, Iter last) {
25+
const auto N = last - X;
26+
std::vector<complex> tmp(N);
27+
for (auto i = 0; i < N; ++i) {
28+
for (auto j = 0; j < N; ++j) {
29+
tmp[i] += X[j] * exp(complex(0, -2.0 * M_PI * i * j / N));
30+
}
31+
}
32+
std::copy(std::begin(tmp), std::end(tmp), X);
2433
}
2534

2635
// `cooley_tukey` does the cooley-tukey algorithm, recursively
@@ -30,12 +39,12 @@ void cooley_tukey(Iter first, Iter last) {
3039
if (size >= 2) {
3140
// split the range, with even indices going in the first half,
3241
// and odd indices going in the last half.
33-
auto temp = std::vector<c64>(size / 2);
34-
for (size_t i = 0; i < size / 2; ++i) {
42+
auto temp = std::vector<complex>(size / 2);
43+
for (int i = 0; i < size / 2; ++i) {
3544
temp[i] = first[i * 2 + 1];
3645
first[i] = first[i * 2];
3746
}
38-
for (size_t i = 0; i < size / 2; ++i) {
47+
for (int i = 0; i < size / 2; ++i) {
3948
first[i + size / 2] = temp[i];
4049
}
4150

@@ -45,8 +54,8 @@ void cooley_tukey(Iter first, Iter last) {
4554
cooley_tukey(split, last);
4655

4756
// now combine each of those halves with the butterflies
48-
for (size_t k = 0; k < size / 2; ++k) {
49-
auto w = std::exp(c64(0, -2.0 * pi<double>() * k / size));
57+
for (int k = 0; k < size / 2; ++k) {
58+
auto w = std::exp(complex(0, -2.0 * pi * k / size));
5059

5160
auto& bottom = first[k];
5261
auto& top = first[k + size / 2];
@@ -83,11 +92,11 @@ void iterative_cooley_tukey(Iter first, Iter last) {
8392

8493
// perform the butterfly on the range
8594
auto size = last - first;
86-
for (size_t stride = 2; stride <= size; stride *= 2) {
87-
auto w = exp(c64(0, -2.0 * pi<double>() / stride));
88-
for (size_t j = 0; j < size; j += stride) {
89-
auto v = c64(1.0);
90-
for (size_t k = 0; k < stride / 2; k++) {
95+
for (int stride = 2; stride <= size; stride *= 2) {
96+
auto w = exp(complex(0, -2.0 * pi / stride));
97+
for (int j = 0; j < size; j += stride) {
98+
auto v = complex(1.0);
99+
for (int k = 0; k < stride / 2; k++) {
91100
first[k + j + stride / 2] =
92101
first[k + j] - v * first[k + j + stride / 2];
93102
first[k + j] -= (first[k + j + stride / 2] - first[k + j]);
@@ -103,7 +112,7 @@ int main() {
103112
std::mt19937 rng(random_device());
104113
std::uniform_real_distribution<double> distribution(0.0, 1.0);
105114

106-
std::array<c64, 64> initial;
115+
std::array<complex, 64> initial;
107116
std::generate(
108117
begin(initial), end(initial), [&] { return distribution(rng); });
109118

@@ -117,7 +126,7 @@ int main() {
117126
// Check if the arrays are approximately equivalent
118127
std::cout << std::right << std::setw(16) << "idx" << std::setw(16) << "rec"
119128
<< std::setw(16) << "it" << std::setw(16) << "subtracted" << '\n';
120-
for (int i = 0; i < initial.size(); ++i) {
129+
for (size_t i = 0; i < initial.size(); ++i) {
121130
auto rec = recursive[i];
122131
auto it = iterative[i];
123132
std::cout << std::setw(16) << i << std::setw(16) << std::abs(rec)

contents/cooley_tukey/cooley_tukey.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ In the end, the code looks like:
119119
{% sample lang="c" %}
120120
[import:20-39, lang:"c_cpp"](code/c/fft.c)
121121
{% sample lang="cpp" %}
122-
[import:27-57, lang:"c_cpp"](code/c++/fft.cpp)
122+
[import:35-66, lang:"c_cpp"](code/c++/fft.cpp)
123123
{% sample lang="hs" %}
124124
[import:6-19, lang:"haskell"](code/haskell/fft.hs)
125125
{% sample lang="py" %}

0 commit comments

Comments
 (0)