From c5ed7c44c92035086cc0733745414b63746f3618 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20Gillet?= Date: Sat, 29 Sep 2018 12:57:29 +0900 Subject: [PATCH] Added FFT based convolution in Haskell --- .../convolutions/code/haskell/convolution.hs | 18 +++++++++++++++++- contents/convolutions/convolutions.md | 5 ++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/contents/convolutions/code/haskell/convolution.hs b/contents/convolutions/code/haskell/convolution.hs index 59bc1a135..109035871 100644 --- a/contents/convolutions/code/haskell/convolution.hs +++ b/contents/convolutions/code/haskell/convolution.hs @@ -1,5 +1,21 @@ +import Data.Array.CArray +import Data.Complex import Data.List (tails) +import Math.FFT (dft, idft) convolution :: (Num a) => [a] -> [a] -> [a] convolution x = map (sum . zipWith (*) (reverse x)) . spread - where spread = init . tails . (replicate (length x - 1) 0 ++) + where + spread = init . tails . (replicate (length x - 1) 0 ++) + +convolutionFFT :: [Complex Double] -> [Complex Double] -> [Complex Double] +convolutionFFT x y = elems $ idft $ liftArray2 (*) (fft x) (fft y) + where + fft a = dft $ listArray (1, length a) a + +main :: IO () +main = do + let x = [1, 2, 1, 2, 1] + y = [2, 1, 2, 1, 2] + print $ convolution x y + print $ convolutionFFT x y diff --git a/contents/convolutions/convolutions.md b/contents/convolutions/convolutions.md index f74f7f458..4fed7565d 100644 --- a/contents/convolutions/convolutions.md +++ b/contents/convolutions/convolutions.md @@ -39,7 +39,7 @@ In code, this looks something like: {% sample lang="jl" %} [import:1-17, lang:"julia"](code/julia/conv.jl) {% sample lang="hs" %} -[import:1-5, lang:"haskell"](code/haskell/convolution.hs) +[import:6-9, lang:"haskell"](code/haskell/convolution.hs) {% sample lang="c"%} [import:5-18, lang:"c_cpp"](code/c/convolutions.c) {% sample lang="cpp"%} @@ -86,8 +86,7 @@ That said, Julia has an in-built fft routine, so the code for this method could [import:19-22, lang:"julia"](code/julia/conv.jl) Where the `.*` operator is an element-wise multiplication. {% sample lang="hs" %} -The FFT-based convolution in Haskell is complicated, so here is some simple julia code: -[import:19-22, lang:"julia"](code/julia/conv.jl) +[import:11-14, lang:"haskell"](code/haskell/convolution.hs) Where the `.*` operator is an element-wise multiplication. {% sample lang="c"%} [import:20-30, lang:"c_cpp"](code/c/convolutions.c)