@@ -312,68 +312,67 @@ squeeze :
312
312
Tensor to dtype
313
313
squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x
314
314
315
- ||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
315
+ ||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for
316
316
||| details.
317
317
export
318
- data SliceOrIndex : Nat -> Type where
319
- Slice :
320
- (from, to : Nat ) ->
321
- {size : _} ->
322
- {auto 0 fromTo : from + size = to} ->
323
- {auto 0 inDim : LTE to d} ->
324
- SliceOrIndex d
325
- Index : (idx : Nat ) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
326
- DynamicSlice : Tensor [] U64 -> (size : Nat ) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
327
- DynamicIndex : Tensor [] U64 -> SliceOrIndex d
318
+ data Subset : Type where
319
+ All : Subset
320
+ Slice : (from, to : Nat ) -> Subset
321
+ Index : (idx : Nat ) -> Subset
322
+ DynamicSlice : Tensor [] U64 -> (size : Nat ) -> Subset
323
+ DynamicIndex : Tensor [] U64 -> Subset
328
324
329
325
||| Index at `idx`. See `slice` for details.
330
326
public export
331
- at : (idx : Nat ) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
327
+ at : (idx : Nat ) -> Subset
332
328
at = Index
333
329
334
330
namespace Dynamic
335
331
||| Index at the specified index. See `slice` for details.
336
332
public export
337
- at : Tensor [] U64 -> SliceOrIndex d
333
+ at : Tensor [] U64 -> Subset
338
334
at = DynamicIndex
339
335
340
336
||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
341
337
public export
342
- (. to) :
343
- (from, to : Nat ) ->
344
- {size : _} ->
345
- {auto 0 fromTo : from + size = to} ->
346
- {auto 0 inDim : LTE to d} ->
347
- SliceOrIndex d
338
+ (. to) : (from, to : Nat ) -> {auto 0 ordered : from `LT ` to} -> Subset
348
339
(. to) = Slice
349
340
350
341
||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
351
342
public export
352
- (. size) : Tensor [] U64 -> (size : Nat ) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
343
+ (. size) : Tensor [] U64 -> (size : Nat ) -> Subset
353
344
(. size) = DynamicSlice
354
345
355
346
||| Slice across all indices along an axis. See `slice` for details.
356
347
public export
357
- all : { d : _} -> SliceOrIndex d
358
- all = Slice 0 @{ % search} @{reflexive {ty = Nat }} d
348
+ all : Subset
349
+ all = All
359
350
360
- ||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
361
- ||| See `slice` for details.
362
- public export
363
- data MultiSlice : Shape -> Type where
364
- Nil : MultiSlice ds
365
- (:: ) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
351
+ public export -- is there a stdlib version of this?
352
+ assert : Bool -> e -> Either e a -> Either e a
353
+ assert True _ either = either
354
+ assert False e _ = Left e
355
+
356
+ namespace Subset
357
+ public export
358
+ data InvalidSubsetError =
359
+ ||| The number of dimensions requested and found
360
+ OutOfBounds Nat Nat
361
+
362
+ ||| The number of unaccounted-for axes
363
+ TooManyAxes Nat
366
364
367
- namespace MultiSlice
368
365
||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
369
366
||| `Tensor.slice` for details.
370
367
public export
371
- slice : {shape : _} -> MultiSlice shape -> Shape
372
- slice {shape} [] = shape
373
- slice {shape = (_ :: _ )} (Slice {size} _ _ :: xs) = size :: slice xs
374
- slice {shape = (_ :: _ )} (Index _ :: xs) = slice xs
375
- slice {shape = (_ :: _ )} (DynamicSlice _ size :: xs) = size :: slice xs
376
- slice {shape = (_ :: _ )} (DynamicIndex _ :: xs) = slice xs
368
+ slice : Shape -> List Subset -> Either InvalidSubsetError Shape
369
+ slice [] at@(_ :: _ ) = Left TooManyAxes (length at)
370
+ slice ds [] = ds
371
+ slice (d :: ds) (Slice from to :: xs) = assert (to > d) (OutOfBounds to d) $ map (size :: ) (slice ds xs)
372
+ slice (d :: ds) (Index idx :: xs) = assert (idx >= d) (OutOfBounds idx d) $ slice ds xs
373
+ slice (d :: ds) (DynamicSlice _ size :: xs) =
374
+ assert (size > d) (OutOfBounds size d) $ map (size :: ) (slice ds xs)
375
+ slice (_ :: ds) (DynamicIndex _ :: xs) = slice ds xs
377
376
378
377
||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
379
378
||| either static (`Nat`) or dynamic (scalar `U64`) indices.
@@ -463,47 +462,40 @@ namespace MultiSlice
463
462
export
464
463
slice :
465
464
Primitive dtype =>
466
- (at : MultiSlice shape ) ->
465
+ (at : List Subset ) ->
467
466
Tensor shape dtype ->
468
- Tensor (slice at) dtype
467
+ {auto 0 shape' : IsRight (slice at shape)} ->
468
+ case shape' of Right shape' => Tensor shape' dtype
469
469
slice at $ MkTensor x = MkTensor
470
- $ Reshape (mapd size id at) (MultiSlice . slice at)
471
- $ DynamicSlice (dynStarts [] at) (mapd size id at)
472
- $ Slice (mapd start (const 0 ) at) (mapd stop id at) (replicate (length shape) 1 ) x
473
-
474
- where
475
- mapd : ((Nat -> a) -> {d : Nat } -> SliceOrIndex d -> a) ->
476
- (Nat -> a) ->
477
- {shape : Shape} ->
478
- MultiSlice shape ->
479
- List a
480
- mapd _ dflt {shape} [] = Prelude . map dflt shape
481
- mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
482
-
483
- start : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
484
- start _ (Slice from _ ) = from
485
- start _ (Index idx) = idx
486
- start f {d} _ = f d
487
-
488
- stop : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
489
- stop _ (Slice _ to) = to
490
- stop _ (Index idx) = S idx
491
- stop f {d} _ = f d
492
-
493
- size : (Nat -> Nat ) -> {d : Nat } -> SliceOrIndex d -> Nat
494
- size _ (Slice {size = size'} _ _ ) = size'
495
- size _ (Index _ ) = 1
496
- size _ (DynamicSlice _ size') = size'
497
- size _ (DynamicIndex _ ) = 1
498
-
499
- zero : Expr
500
- zero = FromLiteral {shape = []} {dtype = U64 } 0
501
-
502
- dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr
503
- dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
504
- dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
505
- dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
506
- dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
470
+ $ Reshape (map size at) (MultiSlice . slice at)
471
+ $ DynamicSlice (map dynStart at ++ replicate (length shape `minus` length at) zero) (map size at)
472
+ $ Slice (map start at) (map (uncurry stop) (zip shape at)) (replicate (length shape) 1 ) x -- zip doesn't account for length difference
473
+
474
+ where
475
+
476
+ start : Subset -> Nat
477
+ start (Slice from _ ) = from
478
+ start (Index idx) = idx
479
+ start _ = 0
480
+
481
+ stop : Nat -> Subset -> Nat
482
+ stop _ (Slice _ to) = to
483
+ stop _ (Index idx) = S idx
484
+ stop d _ = d
485
+
486
+ size : Subset -> Nat
487
+ size (Slice {size = size'} _ _ ) = size'
488
+ size (Index _ ) = 1
489
+ size (DynamicSlice _ size') = size'
490
+ size (DynamicIndex _ ) = 1
491
+
492
+ zero : Expr
493
+ zero = FromLiteral {shape = []} {dtype = U64 } 0
494
+
495
+ dynStart : Subset -> Expr
496
+ dynStart (DynamicSlice (MkTensor i) _ ) = i
497
+ dynStart (DynamicIndex (MkTensor i)) = i
498
+ dynStart _ = zero
507
499
508
500
||| Concatenate two `Tensor`s along the specfied `axis`. For example,
509
501
||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and
0 commit comments