From 0874b02816f27ad1203053deeafd6dc41e294033 Mon Sep 17 00:00:00 2001 From: Deven Lahoti Date: Sat, 10 Sep 2016 05:05:09 -0400 Subject: [PATCH 1/2] parameterize `Source` and `Target` on the shape of the array --- repa/Data/Array/Repa/Arbitrary.hs | 2 +- repa/Data/Array/Repa/Base.hs | 21 +++++++------------ repa/Data/Array/Repa/Eval.hs | 20 +++++++++--------- repa/Data/Array/Repa/Eval/Load.hs | 12 +++++------ repa/Data/Array/Repa/Eval/Target.hs | 17 +++++++-------- repa/Data/Array/Repa/Operators/IndexSpace.hs | 22 +++++++++----------- repa/Data/Array/Repa/Operators/Interleave.hs | 13 ++++++------ repa/Data/Array/Repa/Operators/Mapping.hs | 6 +++--- repa/Data/Array/Repa/Operators/Reduction.hs | 22 ++++++++++---------- repa/Data/Array/Repa/Operators/Traversal.hs | 14 ++++++------- repa/Data/Array/Repa/Repr/ByteString.hs | 2 +- repa/Data/Array/Repa/Repr/Cursored.hs | 2 +- repa/Data/Array/Repa/Repr/Delayed.hs | 6 +++--- repa/Data/Array/Repa/Repr/ForeignPtr.hs | 6 +++--- repa/Data/Array/Repa/Repr/HintInterleave.hs | 2 +- repa/Data/Array/Repa/Repr/HintSmall.hs | 2 +- repa/Data/Array/Repa/Repr/Partitioned.hs | 2 +- repa/Data/Array/Repa/Repr/Unboxed.hs | 6 +++--- repa/Data/Array/Repa/Repr/Undefined.hs | 2 +- repa/Data/Array/Repa/Repr/Vector.hs | 6 +++--- repa/Data/Array/Repa/Specialised/Dim2.hs | 2 +- repa/Data/Array/Repa/Stencil/Dim2.hs | 10 ++++----- 22 files changed, 95 insertions(+), 102 deletions(-) diff --git a/repa/Data/Array/Repa/Arbitrary.hs b/repa/Data/Array/Repa/Arbitrary.hs index 29e3f3f7..4e5f705b 100644 --- a/repa/Data/Array/Repa/Arbitrary.hs +++ b/repa/Data/Array/Repa/Arbitrary.hs @@ -74,7 +74,7 @@ instance (Shape a, CoArbitrary a) => CoArbitrary (a :. Int) where coarbitrary (a :. b) = coarbitrary a . coarbitrary b -instance (CoArbitrary sh, CoArbitrary a, Source r a, Shape sh) +instance (CoArbitrary sh, CoArbitrary a, Source r sh a) => CoArbitrary (Array r sh a) where coarbitrary arr = (coarbitrary . extent $ arr) . (coarbitrary . toList $ arr) diff --git a/repa/Data/Array/Repa/Base.hs b/repa/Data/Array/Repa/Base.hs index e5e7b1c6..a91e3e5d 100644 --- a/repa/Data/Array/Repa/Base.hs +++ b/repa/Data/Array/Repa/Base.hs @@ -8,18 +8,17 @@ import Data.Array.Repa.Shape -- Source ----------------------------------------------------------------------- -- | Class of array representations that we can read elements from. -class Source r e where +class Shape sh => Source r sh e where -- Arrays with a representation tag, shape, and element type. -- Use one of the type tags like `D`, `U` and so on for @r@, -- one of `DIM1`, `DIM2` ... for @sh@. data Array r sh e -- | O(1). Take the extent (size) of an array. - extent :: Shape sh => Array r sh e -> sh + extent :: Array r sh e -> sh -- | O(1). Shape polymorphic indexing. - index, unsafeIndex - :: Shape sh => Array r sh e -> sh -> e + index, unsafeIndex :: Array r sh e -> sh -> e {-# INLINE index #-} index arr ix = arr `linearIndex` toIndex (extent arr) ix @@ -28,25 +27,22 @@ class Source r e where unsafeIndex arr ix = arr `unsafeLinearIndex` toIndex (extent arr) ix -- | O(1). Linear indexing into underlying, row-major, array representation. - linearIndex, unsafeLinearIndex - :: Shape sh => Array r sh e -> Int -> e + linearIndex, unsafeLinearIndex :: Array r sh e -> Int -> e {-# INLINE unsafeLinearIndex #-} unsafeLinearIndex = linearIndex -- | Ensure an array's data structure is fully evaluated. - deepSeqArray - :: Shape sh =>Array r sh e -> b -> b + deepSeqArray :: Array r sh e -> b -> b -- | O(1). Alias for `index` -(!) :: Shape sh => Source r e => Array r sh e -> sh -> e +(!) :: Source r sh e => Array r sh e -> sh -> e (!) = index -- | O(n). Convert an array to a list. -toList :: Shape sh => Source r e - => Array r sh e -> [e] +toList :: Source r sh e => Array r sh e -> [e] {-# INLINE toList #-} toList arr = go 0 @@ -90,8 +86,7 @@ toList arr -- If you're not sure, then just follow the example code above. -- deepSeqArrays - :: Shape sh => Source r e - => [Array r sh e] -> b -> b + :: Source r sh e => [Array r sh e] -> b -> b {-# INLINE deepSeqArrays #-} deepSeqArrays arrs x = case arrs of diff --git a/repa/Data/Array/Repa/Eval.hs b/repa/Data/Array/Repa/Eval.hs index 06bf8bdd..ae994ee8 100644 --- a/repa/Data/Array/Repa/Eval.hs +++ b/repa/Data/Array/Repa/Eval.hs @@ -62,7 +62,7 @@ import System.IO.Unsafe -- computeP :: ( Load r1 sh e - , Target r2 e, Source r2 e, Monad m) + , Target r2 sh e, Source r2 sh e, Monad m) => Array r1 sh e -> m (Array r2 sh e) computeP arr = now $ suspendedComputeP arr {-# INLINE [4] computeP #-} @@ -70,7 +70,7 @@ computeP arr = now $ suspendedComputeP arr -- | Sequential computation of array elements. computeS - :: (Load r1 sh e, Target r2 e) + :: (Load r1 sh e, Target r2 sh e) => Array r1 sh e -> Array r2 sh e computeS arr1 = arr1 `deepSeqArray` @@ -93,7 +93,7 @@ computeS arr1 -- that each array is fully evaluated before continuing. -- suspendedComputeP - :: (Load r1 sh e, Target r2 e) + :: (Load r1 sh e, Target r2 sh e) => Array r1 sh e -> Array r2 sh e suspendedComputeP arr1 = arr1 `deepSeqArray` @@ -110,8 +110,8 @@ suspendedComputeP arr1 -- -- * You can use it to copy manifest arrays between representations. -- -copyP :: ( Source r1 e, Source r2 e - , Load D sh e, Target r2 e +copyP :: ( Source r1 sh e, Source r2 sh e + , Load D sh e, Target r2 sh e , Monad m) => Array r1 sh e -> m (Array r2 sh e) copyP arr = now $ suspendedCopyP arr @@ -119,8 +119,8 @@ copyP arr = now $ suspendedCopyP arr -- | Sequential copying of arrays. -copyS :: ( Source r1 e - , Load D sh e, Target r2 e) +copyS :: ( Source r1 sh e + , Load D sh e, Target r2 sh e) => Array r1 sh e -> Array r2 sh e copyS arr1 = computeS $ delay arr1 {-# INLINE [4] copyS #-} @@ -128,8 +128,8 @@ copyS arr1 = computeS $ delay arr1 -- | Suspended parallel copy of array elements. suspendedCopyP - :: ( Source r1 e - , Load D sh e, Target r2 e) + :: ( Source r1 sh e + , Load D sh e, Target r2 sh e) => Array r1 sh e -> Array r2 sh e suspendedCopyP arr1 = suspendedComputeP $ delay arr1 {-# INLINE [4] suspendedCopyP #-} @@ -146,7 +146,7 @@ suspendedCopyP arr1 = suspendedComputeP $ delay arr1 -- ... -- @ -- -now :: (Shape sh, Source r e, Monad m) +now :: (Shape sh, Source r sh e, Monad m) => Array r sh e -> m (Array r sh e) now arr = do arr `deepSeqArray` return () diff --git a/repa/Data/Array/Repa/Eval/Load.hs b/repa/Data/Array/Repa/Eval/Load.hs index f60eac18..f5610f43 100644 --- a/repa/Data/Array/Repa/Eval/Load.hs +++ b/repa/Data/Array/Repa/Eval/Load.hs @@ -14,23 +14,23 @@ import Data.Array.Repa.Base -- Note that instances require that the source array to have a delayed -- representation such as `D` or `C`. If you want to use a pre-existing -- manifest array as the source then `delay` it first. -class (Source r1 e, Shape sh) => Load r1 sh e where +class Source r1 sh e => Load r1 sh e where -- | Fill an entire array sequentially. - loadS :: Target r2 e => Array r1 sh e -> MVec r2 e -> IO () + loadS :: Target r2 sh e => Array r1 sh e -> MVec r2 sh e -> IO () -- | Fill an entire array in parallel. - loadP :: Target r2 e => Array r1 sh e -> MVec r2 e -> IO () + loadP :: Target r2 sh e => Array r1 sh e -> MVec r2 sh e -> IO () -- FillRange ------------------------------------------------------------------ -- | Compute a range of elements defined by an array and write them to a fillable -- representation. -class (Source r1 e, Shape sh) => LoadRange r1 sh e where +class Source r1 sh e => LoadRange r1 sh e where -- | Fill a range of an array sequentially. - loadRangeS :: Target r2 e => Array r1 sh e -> MVec r2 e -> sh -> sh -> IO () + loadRangeS :: Target r2 sh e => Array r1 sh e -> MVec r2 sh e -> sh -> sh -> IO () -- | Fill a range of an array in parallel. - loadRangeP :: Target r2 e => Array r1 sh e -> MVec r2 e -> sh -> sh -> IO () + loadRangeP :: Target r2 sh e => Array r1 sh e -> MVec r2 sh e -> sh -> sh -> IO () diff --git a/repa/Data/Array/Repa/Eval/Target.hs b/repa/Data/Array/Repa/Eval/Target.hs index c50876b5..5aff8068 100644 --- a/repa/Data/Array/Repa/Eval/Target.hs +++ b/repa/Data/Array/Repa/Eval/Target.hs @@ -11,31 +11,30 @@ import System.IO.Unsafe -- Target --------------------------------------------------------------------- -- | Class of manifest array representations that can be constructed in parallel. -class Target r e where +class Shape sh => Target r sh e where -- | Mutable version of the representation. - data MVec r e + data MVec r sh e -- | Allocate a new mutable array of the given size. - newMVec :: Int -> IO (MVec r e) + newMVec :: Int -> IO (MVec r sh e) -- | Write an element into the mutable array. - unsafeWriteMVec :: MVec r e -> Int -> e -> IO () + unsafeWriteMVec :: MVec r sh e -> Int -> e -> IO () -- | Freeze the mutable array into an immutable Repa array. - unsafeFreezeMVec :: sh -> MVec r e -> IO (Array r sh e) + unsafeFreezeMVec :: sh -> MVec r sh e -> IO (Array r sh e) -- | Ensure the strucure of a mutable array is fully evaluated. - deepSeqMVec :: MVec r e -> a -> a + deepSeqMVec :: MVec r sh e -> a -> a -- | Ensure the array is still live at this point. -- Needed when the mutable array is a ForeignPtr with a finalizer. - touchMVec :: MVec r e -> IO () + touchMVec :: MVec r sh e -> IO () -- | O(n). Construct a manifest array from a list. -fromList :: (Shape sh, Target r e) - => sh -> [e] -> Array r sh e +fromList :: Target r sh e => sh -> [e] -> Array r sh e fromList sh xx = unsafePerformIO $ do let len = length xx diff --git a/repa/Data/Array/Repa/Operators/IndexSpace.hs b/repa/Data/Array/Repa/Operators/IndexSpace.hs index 0dde3693..72d00e4e 100644 --- a/repa/Data/Array/Repa/Operators/IndexSpace.hs +++ b/repa/Data/Array/Repa/Operators/IndexSpace.hs @@ -24,8 +24,8 @@ stage = "Data.Array.Repa.Operators.IndexSpace" -- Index space transformations ------------------------------------------------ -- | Impose a new shape on the elements of an array. -- The new extent must be the same size as the original, else `error`. -reshape :: ( Shape sh1, Shape sh2 - , Source r1 e) +reshape :: ( Shape sh2 + , Source r1 sh1 e) => sh2 -> Array r1 sh1 e -> Array D sh2 e @@ -43,8 +43,7 @@ reshape sh2 arr -- | Append two arrays. append, (++) - :: ( Shape sh - , Source r1 e, Source r2 e) + :: (Source r1 (sh :. Int) e, Source r2 (sh :. Int) e) => Array r1 (sh :. Int) e -> Array r2 (sh :. Int) e -> Array D (sh :. Int) e @@ -70,7 +69,7 @@ append arr1 arr2 -- | Transpose the lowest two dimensions of an array. -- Transposing an array twice yields the original. transpose - :: (Shape sh, Source r e) + :: (Source r (sh :. Int :. Int) e) => Array r (sh :. Int :. Int) e -> Array D (sh :. Int :. Int) e @@ -82,7 +81,7 @@ transpose arr -- | Extract a sub-range of elements from an array. -extract :: (Shape sh, Source r e) +extract :: (Source r sh e) => sh -- ^ Starting index. -> sh -- ^ Size of result. -> Array r sh e @@ -95,8 +94,8 @@ extract start sz arr -- | Backwards permutation of an array's elements. backpermute, unsafeBackpermute :: forall r sh1 sh2 e - . ( Shape sh1, Shape sh2 - , Source r e) + . ( Shape sh2 + , Source r sh1 e) => sh2 -- ^ Extent of result array. -> (sh2 -> sh1) -- ^ Function mapping each index in the result array -- to an index of the source array. @@ -117,8 +116,7 @@ unsafeBackpermute newExtent perm arr -- from the default array (@arrDft@) backpermuteDft, unsafeBackpermuteDft :: forall r1 r2 sh1 sh2 e - . ( Shape sh1, Shape sh2 - , Source r1 e, Source r2 e) + . (Source r1 sh1 e, Source r2 sh2 e) => Array r2 sh2 e -- ^ Default values (@arrDft@) -> (sh2 -> Maybe sh1) -- ^ Function mapping each index in the result array -- to an index in the source array. @@ -153,7 +151,7 @@ extend, unsafeExtend :: ( Slice sl , Shape (SliceShape sl) , Shape (FullShape sl) - , Source r e) + , Source r (SliceShape sl) e) => sl -> Array r (SliceShape sl) e -> Array D (FullShape sl) e @@ -188,7 +186,7 @@ slice, unsafeSlice :: ( Slice sl , Shape (FullShape sl) , Shape (SliceShape sl) - , Source r e) + , Source r (FullShape sl) e) => Array r (FullShape sl) e -> sl -> Array D (SliceShape sl) e diff --git a/repa/Data/Array/Repa/Operators/Interleave.hs b/repa/Data/Array/Repa/Operators/Interleave.hs index 9d962444..78293073 100644 --- a/repa/Data/Array/Repa/Operators/Interleave.hs +++ b/repa/Data/Array/Repa/Operators/Interleave.hs @@ -24,8 +24,8 @@ import Prelude hiding ((++)) -- @ -- interleave2 - :: ( Shape sh - , Source r1 a, Source r2 a) + :: ( Eq sh + , Source r1 (sh :. Int) a, Source r2 (sh :. Int) a) => Array r1 (sh :. Int) a -> Array r2 (sh :. Int) a -> Array D (sh :. Int) a @@ -51,8 +51,8 @@ interleave2 arr1 arr2 -- | Interleave the elements of three arrays. interleave3 - :: ( Shape sh - , Source r1 a, Source r2 a, Source r3 a) + :: ( Eq sh + , Source r1 (sh :. Int) a, Source r2 (sh :. Int) a, Source r3 (sh :. Int) a) => Array r1 (sh :. Int) a -> Array r2 (sh :. Int) a -> Array r3 (sh :. Int) a @@ -81,8 +81,9 @@ interleave3 arr1 arr2 arr3 -- | Interleave the elements of four arrays. interleave4 - :: ( Shape sh - , Source r1 a, Source r2 a, Source r3 a, Source r4 a) + :: ( Eq sh + , Source r1 (sh :. Int) a, Source r2 (sh :. Int) a + , Source r3 (sh :. Int) a, Source r4 (sh :. Int) a) => Array r1 (sh :. Int) a -> Array r2 (sh :. Int) a -> Array r3 (sh :. Int) a diff --git a/repa/Data/Array/Repa/Operators/Mapping.hs b/repa/Data/Array/Repa/Operators/Mapping.hs index 5006cdf4..5470d2da 100644 --- a/repa/Data/Array/Repa/Operators/Mapping.hs +++ b/repa/Data/Array/Repa/Operators/Mapping.hs @@ -27,7 +27,7 @@ import Data.Word -- | Apply a worker function to each element of an array, -- yielding a new array with the same extent. -- -map :: (Shape sh, Source r a) +map :: Source r sh a => (a -> b) -> Array r sh a -> Array D sh b map f arr = case delay arr of @@ -40,7 +40,7 @@ map f arr -- If the extent of the two array arguments differ, -- then the resulting array's extent is their intersection. -- -zipWith :: (Shape sh, Source r1 a, Source r2 b) +zipWith :: (Source r1 sh a, Source r2 sh b) => (a -> b -> c) -> Array r1 sh a -> Array r2 sh b -> Array D sh c @@ -100,7 +100,7 @@ class Structured r1 a b where -- If you have a cursored or partitioned source array then use that as -- the third argument (corresponding to @r1@ here) szipWith - :: (Shape sh, Source r c) + :: Source r sh c => (c -> a -> b) -> Array r sh c -> Array r1 sh a diff --git a/repa/Data/Array/Repa/Operators/Reduction.hs b/repa/Data/Array/Repa/Operators/Reduction.hs index 4dd4b4b2..b971a8eb 100644 --- a/repa/Data/Array/Repa/Operators/Reduction.hs +++ b/repa/Data/Array/Repa/Operators/Reduction.hs @@ -33,7 +33,7 @@ import GHC.Exts -- >>> foldS c 0 a -- AUnboxed (Z :. 2) (fromList [2,4]) -- -foldS :: (Shape sh, Source r a, Elt a, Unbox a) +foldS :: (Source r sh a, Source r (sh :. Int) a, Elt a, Unbox a) => (a -> a -> a) -> a -> Array r (sh :. Int) a @@ -67,7 +67,7 @@ foldS f z arr -- >>> foldP c 0 a -- AUnboxed (Z :. 2) (fromList [2,4]) -- -foldP :: (Shape sh, Source r a, Elt a, Unbox a, Monad m) +foldP :: (Source r sh a, Source r (sh :. Int) a, Elt a, Unbox a, Monad m) => (a -> a -> a) -> a -> Array r (sh :. Int) a @@ -100,7 +100,7 @@ foldP f z arr -- Elements are reduced in row-major order. Applications of the operator are -- associated arbitrarily. -- -foldAllS :: (Shape sh, Source r a, Elt a, Unbox a) +foldAllS :: (Source r sh a, Elt a, Unbox a) => (a -> a -> a) -> a -> Array r sh a @@ -128,7 +128,7 @@ foldAllS f z arr -- associated arbitrarily. -- foldAllP - :: (Shape sh, Source r a, Elt a, Unbox a, Monad m) + :: (Source r sh a, Elt a, Unbox a, Monad m) => (a -> a -> a) -> a -> Array r sh a @@ -146,7 +146,7 @@ foldAllP f z arr -- sum ------------------------------------------------------------------------ -- | Sequential sum the innermost dimension of an array. -sumS :: (Shape sh, Source r a, Num a, Elt a, Unbox a) +sumS :: (Source r sh a, Source r (sh :. Int) a, Num a, Elt a, Unbox a) => Array r (sh :. Int) a -> Array U sh a sumS = foldS (+) 0 @@ -154,7 +154,7 @@ sumS = foldS (+) 0 -- | Parallel sum the innermost dimension of an array. -sumP :: (Shape sh, Source r a, Num a, Elt a, Unbox a, Monad m) +sumP :: (Source r sh a, Source r (sh :. Int) a, Num a, Elt a, Unbox a, Monad m) => Array r (sh :. Int) a -> m (Array U sh a) sumP = foldP (+) 0 @@ -163,7 +163,7 @@ sumP = foldP (+) 0 -- sumAll --------------------------------------------------------------------- -- | Sequential sum of all the elements of an array. -sumAllS :: (Shape sh, Source r a, Elt a, Unbox a, Num a) +sumAllS :: (Source r sh a, Elt a, Unbox a, Num a) => Array r sh a -> a sumAllS = foldAllS (+) 0 @@ -171,7 +171,7 @@ sumAllS = foldAllS (+) 0 -- | Parallel sum all the elements of an array. -sumAllP :: (Shape sh, Source r a, Elt a, Unbox a, Num a, Monad m) +sumAllP :: (Source r sh a, Elt a, Unbox a, Num a, Monad m) => Array r sh a -> m a sumAllP = foldAllP (+) 0 @@ -179,7 +179,7 @@ sumAllP = foldAllP (+) 0 -- Equality ------------------------------------------------------------------ -instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where +instance (Eq sh, Source r sh a, Eq a) => Eq (Array r sh a) where (==) arr1 arr2 = extent arr1 == extent arr2 && (foldAllS (&&) True (R.zipWith (==) arr1 arr2)) @@ -187,7 +187,7 @@ instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where -- | Check whether two arrays have the same shape and contain equal elements, -- in parallel. -equalsP :: (Shape sh, Eq sh, Source r1 a, Source r2 a, Eq a, Monad m) +equalsP :: (Eq sh, Source r1 sh a, Source r2 sh a, Eq a, Monad m) => Array r1 sh a -> Array r2 sh a -> m Bool @@ -198,7 +198,7 @@ equalsP arr1 arr2 -- | Check whether two arrays have the same shape and contain equal elements, -- sequentially. -equalsS :: (Shape sh, Eq sh, Source r1 a, Source r2 a, Eq a) +equalsS :: (Eq sh, Source r1 sh a, Source r2 sh a, Eq a) => Array r1 sh a -> Array r2 sh a -> Bool diff --git a/repa/Data/Array/Repa/Operators/Traversal.hs b/repa/Data/Array/Repa/Operators/Traversal.hs index 513c7009..f0c1ba98 100644 --- a/repa/Data/Array/Repa/Operators/Traversal.hs +++ b/repa/Data/Array/Repa/Operators/Traversal.hs @@ -14,8 +14,8 @@ import Prelude hiding (traverse) -- | Unstructured traversal. traverse, unsafeTraverse :: forall r sh sh' a b - . ( Source r a - , Shape sh, Shape sh') + . ( Source r sh a + , Shape sh') => Array r sh a -- ^ Source array. -> (sh -> sh') -- ^ Function to produce the extent of the result. -> ((sh -> a) -> sh' -> b) -- ^ Function to produce elements of the result. @@ -34,7 +34,7 @@ unsafeTraverse arr transExtent newElem -- | Unstructured traversal over two arrays at once. traverse2, unsafeTraverse2 :: forall r1 r2 sh sh' sh'' a b c - . ( Source r1 a, Source r2 b + . ( Source r1 sh a, Source r2 sh' b , Shape sh, Shape sh', Shape sh'') => Array r1 sh a -- ^ First source array. -> Array r2 sh' b -- ^ Second source array. @@ -61,8 +61,8 @@ traverse3, unsafeTraverse3 :: forall r1 r2 r3 sh1 sh2 sh3 sh4 a b c d - . ( Source r1 a, Source r2 b, Source r3 c - , Shape sh1, Shape sh2, Shape sh3, Shape sh4) + . ( Source r1 sh1 a, Source r2 sh2 b, Source r3 sh3 c + , Shape sh4) => Array r1 sh1 a -> Array r2 sh2 b -> Array r3 sh3 c @@ -88,8 +88,8 @@ traverse4, unsafeTraverse4 :: forall r1 r2 r3 r4 sh1 sh2 sh3 sh4 sh5 a b c d e - . ( Source r1 a, Source r2 b, Source r3 c, Source r4 d - , Shape sh1, Shape sh2, Shape sh3, Shape sh4, Shape sh5) + . ( Source r1 sh1 a, Source r2 sh2 b, Source r3 sh3 c, Source r4 sh4 d + , Shape sh5) => Array r1 sh1 a -> Array r2 sh2 b -> Array r3 sh3 c diff --git a/repa/Data/Array/Repa/Repr/ByteString.hs b/repa/Data/Array/Repa/Repr/ByteString.hs index 6fa9fbd2..d0e907f8 100644 --- a/repa/Data/Array/Repa/Repr/ByteString.hs +++ b/repa/Data/Array/Repa/Repr/ByteString.hs @@ -16,7 +16,7 @@ import Data.ByteString (ByteString) data B -- | Read elements from a `ByteString`. -instance Source B Word8 where +instance Shape sh => Source B sh Word8 where data Array B sh Word8 = AByteString !sh !ByteString diff --git a/repa/Data/Array/Repa/Repr/Cursored.hs b/repa/Data/Array/Repa/Repr/Cursored.hs index 3504f35c..78926aa9 100644 --- a/repa/Data/Array/Repa/Repr/Cursored.hs +++ b/repa/Data/Array/Repa/Repr/Cursored.hs @@ -26,7 +26,7 @@ data C -- | Compute elements of a cursored array. -instance Source C a where +instance Shape sh => Source C sh a where data Array C sh a = forall cursor. ACursored diff --git a/repa/Data/Array/Repa/Repr/Delayed.hs b/repa/Data/Array/Repa/Repr/Delayed.hs index fe1a182c..d7fcccd8 100644 --- a/repa/Data/Array/Repa/Repr/Delayed.hs +++ b/repa/Data/Array/Repa/Repr/Delayed.hs @@ -22,7 +22,7 @@ import GHC.Exts data D -- | Compute elements of a delayed array. -instance Source D a where +instance Shape sh => Source D sh a where data Array D sh a = ADelayed !sh @@ -99,7 +99,7 @@ fromFunction sh f -- | O(1). Produce the extent of an array, and a function to retrieve an -- arbitrary element. toFunction - :: (Shape sh, Source r1 a) + :: Source r1 sh a => Array r1 sh a -> (sh, sh -> a) toFunction arr = case delay arr of @@ -112,7 +112,7 @@ toFunction arr -- indices to elements, so consumers don't need to worry about -- what the previous representation was. -- -delay :: Shape sh => Source r e +delay :: Source r sh e => Array r sh e -> Array D sh e delay arr = ADelayed (extent arr) (unsafeIndex arr) {-# INLINE delay #-} diff --git a/repa/Data/Array/Repa/Repr/ForeignPtr.hs b/repa/Data/Array/Repa/Repr/ForeignPtr.hs index 8aacae62..1856cec6 100644 --- a/repa/Data/Array/Repa/Repr/ForeignPtr.hs +++ b/repa/Data/Array/Repa/Repr/ForeignPtr.hs @@ -19,7 +19,7 @@ import qualified Foreign.ForeignPtr.Unsafe as Unsafe data F -- | Read elements from a foreign buffer. -instance Storable a => Source F a where +instance (Storable a, Shape sh) => Source F sh a where data Array F sh a = AForeignPtr !sh !Int !(ForeignPtr a) @@ -50,8 +50,8 @@ instance Storable a => Source F a where -- Load ----------------------------------------------------------------------- -- | Filling foreign buffers. -instance Storable e => Target F e where - data MVec F e +instance (Storable e, Shape sh) => Target F sh e where + data MVec F sh e = FPVec !Int !(ForeignPtr e) newMVec n diff --git a/repa/Data/Array/Repa/Repr/HintInterleave.hs b/repa/Data/Array/Repa/Repr/HintInterleave.hs index 23386af5..971dc93d 100644 --- a/repa/Data/Array/Repa/Repr/HintInterleave.hs +++ b/repa/Data/Array/Repa/Repr/HintInterleave.hs @@ -15,7 +15,7 @@ import Debug.Trace -- and evaluation should be interleaved between the processors. data I r1 -instance Source r1 a => Source (I r1) a where +instance Source r1 sh a => Source (I r1) sh a where data Array (I r1) sh a = AInterleave !(Array r1 sh a) diff --git a/repa/Data/Array/Repa/Repr/HintSmall.hs b/repa/Data/Array/Repa/Repr/HintSmall.hs index 760abe78..82556dbe 100644 --- a/repa/Data/Array/Repa/Repr/HintSmall.hs +++ b/repa/Data/Array/Repa/Repr/HintSmall.hs @@ -12,7 +12,7 @@ import Data.Array.Repa.Shape -- in parallel on the gang. This avoids the associated scheduling overhead. data S r1 -instance Source r1 a => Source (S r1) a where +instance Source r1 sh a => Source (S r1) sh a where data Array (S r1) sh a = ASmall !(Array r1 sh a) diff --git a/repa/Data/Array/Repa/Repr/Partitioned.hs b/repa/Data/Array/Repa/Repr/Partitioned.hs index 97ddfb79..412191fc 100644 --- a/repa/Data/Array/Repa/Repr/Partitioned.hs +++ b/repa/Data/Array/Repa/Repr/Partitioned.hs @@ -36,7 +36,7 @@ inRange (Range _ _ p) ix -- Repr ----------------------------------------------------------------------- -- | Read elements from a partitioned array. -instance (Source r1 e, Source r2 e) => Source (P r1 r2) e where +instance (Source r1 sh e, Source r2 sh e) => Source (P r1 r2) sh e where data Array (P r1 r2) sh e = APart !sh -- size of the whole array !(Range sh) !(Array r1 sh e) -- if in range use this array diff --git a/repa/Data/Array/Repa/Repr/Unboxed.hs b/repa/Data/Array/Repa/Repr/Unboxed.hs index 470db164..d3609a99 100644 --- a/repa/Data/Array/Repa/Repr/Unboxed.hs +++ b/repa/Data/Array/Repa/Repr/Unboxed.hs @@ -28,7 +28,7 @@ import Prelude hiding (zip, zip3, unzip, unzip3) data U -- | Read elements from an unboxed vector array. -instance U.Unbox a => Source U a where +instance (U.Unbox a, Shape sh) => Source U sh a where data Array U sh a = AUnboxed !sh !(U.Vector a) @@ -58,8 +58,8 @@ deriving instance (Read sh, Read e, U.Unbox e) -- Fill ----------------------------------------------------------------------- -- | Filling of unboxed vector arrays. -instance U.Unbox e => Target U e where - data MVec U e +instance (U.Unbox e, Shape sh) => Target U sh e where + data MVec U sh e = UMVec (UM.IOVector e) newMVec n diff --git a/repa/Data/Array/Repa/Repr/Undefined.hs b/repa/Data/Array/Repa/Repr/Undefined.hs index 425062de..6e36c385 100644 --- a/repa/Data/Array/Repa/Repr/Undefined.hs +++ b/repa/Data/Array/Repa/Repr/Undefined.hs @@ -16,7 +16,7 @@ data X -- | Undefined array elements. Inspecting them yields `error`. -- -instance Source X e where +instance Shape sh => Source X sh e where data Array X sh e = AUndefined !sh diff --git a/repa/Data/Array/Repa/Repr/Vector.hs b/repa/Data/Array/Repa/Repr/Vector.hs index 72d2cdd6..76599691 100644 --- a/repa/Data/Array/Repa/Repr/Vector.hs +++ b/repa/Data/Array/Repa/Repr/Vector.hs @@ -21,7 +21,7 @@ import Control.Monad data V -- | Read elements from a boxed vector array. -instance Source V a where +instance Shape sh => Source V sh a where data Array V sh a = AVector !sh !(V.Vector a) @@ -51,8 +51,8 @@ deriving instance (Read sh, Read e) -- Fill ----------------------------------------------------------------------- -- | Filling of boxed vector arrays. -instance Target V e where - data MVec V e +instance Shape sh => Target V sh e where + data MVec V sh e = MVector (VM.IOVector e) newMVec n diff --git a/repa/Data/Array/Repa/Specialised/Dim2.hs b/repa/Data/Array/Repa/Specialised/Dim2.hs index 55a2f79f..b1bd499d 100644 --- a/repa/Data/Array/Repa/Specialised/Dim2.hs +++ b/repa/Data/Array/Repa/Specialised/Dim2.hs @@ -73,7 +73,7 @@ clampToBorder2 (_ :. yLen :. xLen) (sh :. j :. i) -- The border must be the same width on all sides. -- makeBordered2 - :: (Source r1 a, Source r2 a) + :: (Source r1 DIM2 a, Source r2 DIM2 a) => DIM2 -- ^ Extent of array. -> Int -- ^ Width of border. -> Array r1 DIM2 a -- ^ Array for internal elements. diff --git a/repa/Data/Array/Repa/Stencil/Dim2.hs b/repa/Data/Array/Repa/Stencil/Dim2.hs index ff1b36b7..8837ed60 100644 --- a/repa/Data/Array/Repa/Stencil/Dim2.hs +++ b/repa/Data/Array/Repa/Stencil/Dim2.hs @@ -41,7 +41,7 @@ type PC5 = P C (P (S D) (P (S D) (P (S D) (P (S D) X)))) -- Wrappers ------------------------------------------------------------------- -- | Like `mapStencil2` but with the parameters flipped. forStencil2 - :: Source r a + :: Source r DIM2 a => Boundary a -> Array r DIM2 a -> Stencil DIM2 a @@ -55,7 +55,7 @@ forStencil2 boundary arr stencil ------------------------------------------------------------------------------- -- | Apply a stencil to every element of a 2D array. mapStencil2 - :: Source r a + :: Source r DIM2 a => Boundary a -- ^ How to handle the boundary of the array. -> Stencil DIM2 a -- ^ Stencil to apply. -> Array r DIM2 a -- ^ Array to apply stencil to. @@ -128,7 +128,7 @@ mapStencil2 boundary stencil@(StencilStatic sExtent _zero _load) arr unsafeAppStencilCursor2 - :: Source r a + :: Source r DIM2 a => (DIM2 -> Cursor -> Cursor) -> Stencil DIM2 a -> Array r DIM2 a @@ -164,7 +164,7 @@ unsafeAppStencilCursor2 shift -- | Like above, but treat elements outside the array has having a constant value. unsafeAppStencilCursor2_const :: forall r a - . Source r a + . Source r DIM2 a => (DIM2 -> DIM2 -> DIM2) -> Stencil DIM2 a -> a @@ -214,7 +214,7 @@ unsafeAppStencilCursor2_const shift -- | Like above, but clamp out of bounds array values to the closest real value. unsafeAppStencilCursor2_clamp :: forall r a - . Source r a + . Source r DIM2 a => (DIM2 -> DIM2 -> DIM2) -> Stencil DIM2 a -> Array r DIM2 a From 008cd3aca2ba388ad93fdd1aafe146b216ab9615 Mon Sep 17 00:00:00 2001 From: Deven Lahoti Date: Sat, 10 Sep 2016 05:22:52 -0400 Subject: [PATCH 2/2] update repa-algorithms for new `Source` parameter --- .../Data/Array/Repa/Algorithms/DFT/Center.hs | 6 +++--- .../Data/Array/Repa/Algorithms/FFT.hs | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/repa-algorithms/Data/Array/Repa/Algorithms/DFT/Center.hs b/repa-algorithms/Data/Array/Repa/Algorithms/DFT/Center.hs index e5452918..86181fd5 100644 --- a/repa-algorithms/Data/Array/Repa/Algorithms/DFT/Center.hs +++ b/repa-algorithms/Data/Array/Repa/Algorithms/DFT/Center.hs @@ -11,7 +11,7 @@ import Data.Array.Repa.Algorithms.Complex as R -- | Apply the centering transform to a vector. center1d - :: Source r Complex + :: Source r DIM1 Complex => Array r DIM1 Complex -> Array D DIM1 Complex {-# INLINE center1d #-} center1d arr @@ -21,7 +21,7 @@ center1d arr -- | Apply the centering transform to a matrix. center2d - :: Source r Complex + :: Source r DIM2 Complex => Array r DIM2 Complex -> Array D DIM2 Complex {-# INLINE center2d #-} center2d arr @@ -31,7 +31,7 @@ center2d arr -- | Apply the centering transform to a 3d array. center3d - :: Source r Complex + :: Source r DIM3 Complex => Array r DIM3 Complex -> Array D DIM3 Complex {-# INLINE center3d #-} center3d arr diff --git a/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs b/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs index cd8c15b8..a568bfbc 100644 --- a/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs +++ b/repa-algorithms/Data/Array/Repa/Algorithms/FFT.hs @@ -49,7 +49,7 @@ isPowerOfTwo n -- 3D Transform ----------------------------------------------------------------------------------- -- | Compute the DFT of a 3d array. Array dimensions must be powers of two else `error`. -fft3dP :: (Source r Complex, Monad m) +fft3dP :: (Source r DIM3 Complex, Monad m) => Mode -> Array r DIM3 Complex -> m (Array U DIM3 Complex) @@ -76,7 +76,7 @@ fft3dP mode arr fftTrans3d - :: Source r Complex + :: Source r DIM3 Complex => Double -> Array r DIM3 Complex -> Array U DIM3 Complex @@ -88,7 +88,7 @@ fftTrans3d sign arr rotate3d - :: Source r Complex + :: Source r DIM3 Complex => Array r DIM3 Complex -> Array D DIM3 Complex rotate3d arr = backpermute (sh :. m :. k :. l) f arr @@ -100,7 +100,7 @@ rotate3d arr -- Matrix Transform ------------------------------------------------------------------------------- -- | Compute the DFT of a matrix. Array dimensions must be powers of two else `error`. -fft2dP :: (Source r Complex, Monad m) +fft2dP :: (Source r DIM2 Complex, Monad m) => Mode -> Array r DIM2 Complex -> m (Array U DIM2 Complex) @@ -124,7 +124,7 @@ fft2dP mode arr fftTrans2d - :: Source r Complex + :: Source r DIM2 Complex => Double -> Array r DIM2 Complex -> Array U DIM2 Complex @@ -137,7 +137,7 @@ fftTrans2d sign arr -- Vector Transform ------------------------------------------------------------------------------- -- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`. -fft1dP :: (Source r Complex, Monad m) +fft1dP :: (Source r DIM1 Complex, Monad m) => Mode -> Array r DIM1 Complex -> m (Array U DIM1 Complex) @@ -161,7 +161,7 @@ fft1dP mode arr fftTrans1d - :: Source r Complex + :: Source r DIM1 Complex => Double -> Array r DIM1 Complex -> Array U DIM1 Complex @@ -173,7 +173,7 @@ fftTrans1d sign arr -- Rank Generalised Worker ------------------------------------------------------------------------ -fft :: (Shape sh, Source r Complex) +fft :: (Source r (sh :. Int) Complex) => Double -> sh -> Int -> Array r (sh :. Int) Complex -> Array U (sh :. Int) Complex