@@ -311,68 +311,67 @@ squeeze :
311
311
Tensor to dtype
312
312
squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x
313
313
314
- ||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
314
+ ||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for
315
315
||| details.
316
316
export
317
- data SliceOrIndex : Nat -> Type where
318
- Slice :
319
- (from, to : Nat ) ->
320
- {size : _} ->
321
- {auto 0 fromTo : from + size = to} ->
322
- {auto 0 inDim : LTE to d} ->
323
- SliceOrIndex d
324
- Index : (idx : Nat ) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
325
- DynamicSlice : Tensor [] U64 -> (size : Nat ) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
326
- DynamicIndex : Tensor [] U64 -> SliceOrIndex d
317
+ data Subset : Type where
318
+ All : Subset
319
+ Slice : (from, to : Nat ) -> Subset
320
+ Index : (idx : Nat ) -> Subset
321
+ DynamicSlice : Tensor [] U64 -> (size : Nat ) -> Subset
322
+ DynamicIndex : Tensor [] U64 -> Subset
327
323
328
324
||| Index at `idx`. See `slice` for details.
329
325
public export
330
- at : (idx : Nat ) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
326
+ at : (idx : Nat ) -> Subset
331
327
at = Index
332
328
333
329
namespace Dynamic
334
330
||| Index at the specified index. See `slice` for details.
335
331
public export
336
- at : Tensor [] U64 -> SliceOrIndex d
332
+ at : Tensor [] U64 -> Subset
337
333
at = DynamicIndex
338
334
339
335
||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
340
336
public export
341
- (. to) :
342
- (from, to : Nat ) ->
343
- {size : _} ->
344
- {auto 0 fromTo : from + size = to} ->
345
- {auto 0 inDim : LTE to d} ->
346
- SliceOrIndex d
337
+ (. to) : (from, to : Nat ) -> {auto 0 ordered : from `LT ` to} -> Subset
347
338
(. to) = Slice
348
339
349
340
||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
350
341
public export
351
- (. size) : Tensor [] U64 -> (size : Nat ) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
342
+ (. size) : Tensor [] U64 -> (size : Nat ) -> Subset
352
343
(. size) = DynamicSlice
353
344
354
345
||| Slice across all indices along an axis. See `slice` for details.
355
346
public export
356
- all : { d : _} -> SliceOrIndex d
357
- all = Slice 0 @{ % search} @{reflexive {ty = Nat }} d
347
+ all : Subset
348
+ all = All
358
349
359
- ||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
360
- ||| See `slice` for details.
361
- public export
362
- data MultiSlice : Shape -> Type where
363
- Nil : MultiSlice ds
364
- (:: ) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
350
+ public export -- is there a stdlib version of this?
351
+ assert : Bool -> e -> Either e a -> Either e a
352
+ assert True _ either = either
353
+ assert False e _ = Left e
354
+
355
+ namespace Subset
356
+ public export
357
+ data InvalidSubsetError =
358
+ ||| The number of dimensions requested and found
359
+ OutOfBounds Nat Nat
360
+
361
+ ||| The number of unaccounted-for axes
362
+ TooManyAxes Nat
365
363
366
- namespace MultiSlice
367
364
||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
368
365
||| `Tensor.slice` for details.
369
366
public export
370
- slice : {shape : _} -> MultiSlice shape -> Shape
371
- slice {shape} [] = shape
372
- slice {shape = (_ :: _ )} (Slice {size} _ _ :: xs) = size :: slice xs
373
- slice {shape = (_ :: _ )} (Index _ :: xs) = slice xs
374
- slice {shape = (_ :: _ )} (DynamicSlice _ size :: xs) = size :: slice xs
375
- slice {shape = (_ :: _ )} (DynamicIndex _ :: xs) = slice xs
367
+ slice : Shape -> List Subset -> Either InvalidSubsetError Shape
368
+ slice [] at@(_ :: _ ) = Left TooManyAxes (length at)
369
+ slice ds [] = ds
370
+ slice (d :: ds) (Slice from to :: xs) = assert (to > d) (OutOfBounds to d) $ map (size :: ) (slice ds xs)
371
+ slice (d :: ds) (Index idx :: xs) = assert (idx >= d) (OutOfBounds idx d) $ slice ds xs
372
+ slice (d :: ds) (DynamicSlice _ size :: xs) =
373
+ assert (size > d) (OutOfBounds size d) $ map (size :: ) (slice ds xs)
374
+ slice (_ :: ds) (DynamicIndex _ :: xs) = slice ds xs
376
375
377
376
||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
378
377
||| either static (`Nat`) or dynamic (scalar `U64`) indices.
@@ -462,47 +461,40 @@ namespace MultiSlice
462
461
export
463
462
slice :
464
463
Primitive dtype =>
465
- (at : MultiSlice shape ) ->
464
+ (at : List Subset ) ->
466
465
Tensor shape dtype ->
467
- Tensor (slice at) dtype
466
+ {auto 0 shape' : IsRight (slice at shape)} ->
467
+ case shape' of Right shape' => Tensor shape' dtype
468
468
slice at $ MkTensor x = MkTensor
469
- $ Reshape (mapd size id at) (MultiSlice . slice at)
470
- $ DynamicSlice (dynStarts [] at) (mapd size id at)
471
- $ Slice (mapd start (const 0 ) at) (mapd stop id at) (replicate (length shape) 1 ) x
472
-
473
- where
474
- mapd : ((Nat -> a) -> {d : Nat } -> SliceOrIndex d -> a) ->
475
- (Nat -> a) ->
476
- {shape : Shape} ->
477
- MultiSlice shape ->
478
- List a
479
- mapd _ dflt {shape} [] = Prelude . map dflt shape
480
- mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
481
-
482
- start : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
483
- start _ (Slice from _ ) = from
484
- start _ (Index idx) = idx
485
- start f {d} _ = f d
486
-
487
- stop : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
488
- stop _ (Slice _ to) = to
489
- stop _ (Index idx) = S idx
490
- stop f {d} _ = f d
491
-
492
- size : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
493
- size _ (Slice {size = size'} _ _ ) = size'
494
- size _ (Index _ ) = 1
495
- size _ (DynamicSlice _ size') = size'
496
- size _ (DynamicIndex _ ) = 1
497
-
498
- zero : Expr
499
- zero = FromLiteral {shape = []} {dtype = U64 } 0
500
-
501
- dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr
502
- dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
503
- dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
504
- dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
505
- dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
469
+ $ Reshape (map size at) (MultiSlice . slice at)
470
+ $ DynamicSlice (map dynStart at ++ replicate (length shape `minus` length at) zero) (map size at)
471
+ $ Slice (map start at) (map (uncurry stop) (zip shape at)) (replicate (length shape) 1 ) x -- zip doesn't account for length difference
472
+
473
+ where
474
+
475
+ start : Subset -> Nat
476
+ start (Slice from _ ) = from
477
+ start (Index idx) = idx
478
+ start _ = 0
479
+
480
+ stop : Nat -> Subset -> Nat
481
+ stop _ (Slice _ to) = to
482
+ stop _ (Index idx) = S idx
483
+ stop d _ = d
484
+
485
+ size : Subset -> Nat
486
+ size (Slice {size = size'} _ _ ) = size'
487
+ size (Index _ ) = 1
488
+ size (DynamicSlice _ size') = size'
489
+ size (DynamicIndex _ ) = 1
490
+
491
+ zero : Expr
492
+ zero = FromLiteral {shape = []} {dtype = U64 } 0
493
+
494
+ dynStart : Subset -> Expr
495
+ dynStart (DynamicSlice (MkTensor i) _ ) = i
496
+ dynStart (DynamicIndex (MkTensor i)) = i
497
+ dynStart _ = zero
506
498
507
499
||| Concatenate two `Tensor`s along the specfied `axis`. For example,
508
500
||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and
0 commit comments