Skip to content

Commit 2be192c

Browse files
authored
Return CBool when using comparison functions. (#24)
* Return CBool when using comparison functions. * Return CBool, remove getScalar. * Re-add tests.
1 parent b06a1f8 commit 2be192c

File tree

6 files changed

+123
-106
lines changed

6 files changed

+123
-106
lines changed

src/ArrayFire/Algorithm.hs

+20-24
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,26 @@
2424
--------------------------------------------------------------------------------
2525
module ArrayFire.Algorithm where
2626

27-
import ArrayFire.Array
2827
import ArrayFire.FFI
2928
import ArrayFire.Internal.Algorithm
3029
import ArrayFire.Internal.Types
3130

32-
import Foreign.C.Types
33-
import Data.Word
34-
3531
-- | Sum all of the elements in 'Array' along the specified dimension
3632
--
3733
-- >>> A.sum (A.vector @Double 10 [1..]) 0
3834
-- 55.0
3935
--
40-
-- >>> A.sum (A.matrix @Double (10,10) [[2..],[2..]]) 0
36+
-- >>> A.matrix @Double (10,10) $ replicate 10 [1..]
4137
-- 65.0
4238
sum
4339
:: AFType a
4440
=> Array a
4541
-- ^ Array to sum
4642
-> Int
47-
-- ^ Dimension along which to perform sum
48-
-> a
43+
-- ^ 0-based Dimension along which to perform sum
44+
-> Array a
4945
-- ^ Will return the sum of all values in the input array along the specified dimension
50-
sum x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_sum p a n))
46+
sum x (fromIntegral -> n) = (x `op1` (\p a -> af_sum p a n))
5147

5248
-- | Sum all of the elements in 'Array' along the specified dimension, using a default value for NaN
5349
--
@@ -61,9 +57,9 @@ sumNaN
6157
-- ^ Dimension along which to perform sum
6258
-> Double
6359
-- ^ Default value to use in the case of NaN
64-
-> a
60+
-> Array a
6561
-- ^ Will return the sum of all values in the input array along the specified dimension, substituted with the default value
66-
sumNaN n (fromIntegral -> i) d = getScalar (n `op1` (\p a -> af_sum_nan p a i d))
62+
sumNaN n (fromIntegral -> i) d = (n `op1` (\p a -> af_sum_nan p a i d))
6763

6864
-- | Product all of the elements in 'Array' along the specified dimension
6965
--
@@ -75,9 +71,9 @@ product
7571
-- ^ Array to product
7672
-> Int
7773
-- ^ Dimension along which to perform product
78-
-> a
74+
-> Array a
7975
-- ^ Will return the product of all values in the input array along the specified dimension
80-
product x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_product p a n))
76+
product x (fromIntegral -> n) = (x `op1` (\p a -> af_product p a n))
8177

8278
-- | Product all of the elements in 'Array' along the specified dimension, using a default value for NaN
8379
--
@@ -91,9 +87,9 @@ productNaN
9187
-- ^ Dimension along which to perform product
9288
-> Double
9389
-- ^ Default value to use in the case of NaN
94-
-> a
90+
-> Array a
9591
-- ^ Will return the product of all values in the input array along the specified dimension, substituted with the default value
96-
productNaN n (fromIntegral -> i) d = getScalar (n `op1` (\p a -> af_product_nan p a i d))
92+
productNaN n (fromIntegral -> i) d = n `op1` (\p a -> af_product_nan p a i d)
9793

9894
-- | Take the minimum of an 'Array' along a specific dimension
9995
--
@@ -105,9 +101,9 @@ min
105101
-- ^ Array input
106102
-> Int
107103
-- ^ Dimension along which to retrieve the min element
108-
-> a
104+
-> Array a
109105
-- ^ Will contain the minimum of all values in the input array along dim
110-
min x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_min p a n))
106+
min x (fromIntegral -> n) = x `op1` (\p a -> af_min p a n)
111107

112108
-- | Take the maximum of an 'Array' along a specific dimension
113109
--
@@ -119,9 +115,9 @@ max
119115
-- ^ Array input
120116
-> Int
121117
-- ^ Dimension along which to retrieve the max element
122-
-> a
118+
-> Array a
123119
-- ^ Will contain the maximum of all values in the input array along dim
124-
max x (fromIntegral -> n) = getScalar (x `op1` (\p a -> af_max p a n))
120+
max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n)
125121

126122
-- | Find if all elements in an 'Array' are 'True' along a dimension
127123
--
@@ -133,10 +129,10 @@ allTrue
133129
-- ^ Array input
134130
-> Int
135131
-- ^ Dimension along which to see if all elements are True
136-
-> Bool
132+
-> Array a
137133
-- ^ Will contain the maximum of all values in the input array along dim
138134
allTrue x (fromIntegral -> n) =
139-
toEnum . fromIntegral $ getScalar @CBool @a (x `op1` (\p a -> af_all_true p a n))
135+
x `op1` (\p a -> af_all_true p a n)
140136

141137
-- | Find if any elements in an 'Array' are 'True' along a dimension
142138
--
@@ -148,10 +144,10 @@ anyTrue
148144
-- ^ Array input
149145
-> Int
150146
-- ^ Dimension along which to see if all elements are True
151-
-> Bool
147+
-> Array a
152148
-- ^ Returns if all elements are true
153149
anyTrue x (fromIntegral -> n) =
154-
toEnum . fromIntegral $ getScalar @CBool @a (x `op1` (\p a -> af_any_true p a n))
150+
(x `op1` (\p a -> af_any_true p a n))
155151

156152
-- | Count elements in an 'Array' along a dimension
157153
--
@@ -163,9 +159,9 @@ count
163159
-- ^ Array input
164160
-> Int
165161
-- ^ Dimension along which to count
166-
-> Int
162+
-> Array Int
167163
-- ^ Count of all elements along dimension
168-
count x (fromIntegral -> n) = fromIntegral $ getScalar @Word32 @a (x `op1` (\p a -> af_count p a n))
164+
count x (fromIntegral -> n) = x `op1d` (\p a -> af_count p a n)
169165

170166
-- | Sum all elements in an 'Array' along all dimensions
171167
--

0 commit comments

Comments
 (0)