Skip to content

Commit 9711062

Browse files
committed
improve slice type inference with decidability
1 parent 5c2f9b1 commit 9711062

File tree

1 file changed

+66
-74
lines changed

1 file changed

+66
-74
lines changed

spidr/src/Tensor.idr

Lines changed: 66 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -312,68 +312,67 @@ squeeze :
312312
Tensor to dtype
313313
squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x
314314

315-
||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
315+
||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for
316316
||| details.
317317
export
318-
data SliceOrIndex : Nat -> Type where
319-
Slice :
320-
(from, to : Nat) ->
321-
{size : _} ->
322-
{auto 0 fromTo : from + size = to} ->
323-
{auto 0 inDim : LTE to d} ->
324-
SliceOrIndex d
325-
Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
326-
DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
327-
DynamicIndex : Tensor [] U64 -> SliceOrIndex d
318+
data Subset : Type where
319+
All : Subset
320+
Slice : (from, to : Nat) -> Subset
321+
Index : (idx : Nat) -> Subset
322+
DynamicSlice : Tensor [] U64 -> (size : Nat) -> Subset
323+
DynamicIndex : Tensor [] U64 -> Subset
328324

329325
||| Index at `idx`. See `slice` for details.
330326
public export
331-
at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
327+
at : (idx : Nat) -> Subset
332328
at = Index
333329

334330
namespace Dynamic
335331
||| Index at the specified index. See `slice` for details.
336332
public export
337-
at : Tensor [] U64 -> SliceOrIndex d
333+
at : Tensor [] U64 -> Subset
338334
at = DynamicIndex
339335

340336
||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
341337
public export
342-
(.to) :
343-
(from, to : Nat) ->
344-
{size : _} ->
345-
{auto 0 fromTo : from + size = to} ->
346-
{auto 0 inDim : LTE to d} ->
347-
SliceOrIndex d
338+
(.to) : (from, to : Nat) -> {auto 0 ordered : from `LT` to} -> Subset
348339
(.to) = Slice
349340

350341
||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
351342
public export
352-
(.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
343+
(.size) : Tensor [] U64 -> (size : Nat) -> Subset
353344
(.size) = DynamicSlice
354345

355346
||| Slice across all indices along an axis. See `slice` for details.
356347
public export
357-
all : {d : _} -> SliceOrIndex d
358-
all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
348+
all : Subset
349+
all = All
359350

360-
||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
361-
||| See `slice` for details.
362-
public export
363-
data MultiSlice : Shape -> Type where
364-
Nil : MultiSlice ds
365-
(::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
351+
public export -- is there a stdlib version of this?
352+
assert : Bool -> e -> Either e a -> Either e a
353+
assert True _ either = either
354+
assert False e _ = Left e
355+
356+
namespace Subset
357+
public export
358+
data InvalidSubsetError =
359+
||| The number of dimensions requested and found
360+
OutOfBounds Nat Nat
361+
362+
||| The number of unaccounted-for axes
363+
TooManyAxes Nat
366364

367-
namespace MultiSlice
368365
||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
369366
||| `Tensor.slice` for details.
370367
public export
371-
slice : {shape : _} -> MultiSlice shape -> Shape
372-
slice {shape} [] = shape
373-
slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
374-
slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
375-
slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
376-
slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
368+
slice : Shape -> List Subset -> Either InvalidSubsetError Shape
369+
slice [] at@(_ :: _) = Left TooManyAxes (length at)
370+
slice ds [] = ds
371+
slice (d :: ds) (Slice from to :: xs) = assert (to > d) (OutOfBounds to d) $ map (size ::) (slice ds xs)
372+
slice (d :: ds) (Index idx :: xs) = assert (idx >= d) (OutOfBounds idx d) $ slice ds xs
373+
slice (d :: ds) (DynamicSlice _ size :: xs) =
374+
assert (size > d) (OutOfBounds size d) $ map (size ::) (slice ds xs)
375+
slice (_ :: ds) (DynamicIndex _ :: xs) = slice ds xs
377376

378377
||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
379378
||| either static (`Nat`) or dynamic (scalar `U64`) indices.
@@ -463,47 +462,40 @@ namespace MultiSlice
463462
export
464463
slice :
465464
Primitive dtype =>
466-
(at : MultiSlice shape) ->
465+
(at : List Subset) ->
467466
Tensor shape dtype ->
468-
Tensor (slice at) dtype
467+
{auto 0 shape' : IsRight (slice at shape)} ->
468+
case shape' of Right shape' => Tensor shape' dtype
469469
slice at $ MkTensor x = MkTensor
470-
$ Reshape (mapd size id at) (MultiSlice.slice at)
471-
$ DynamicSlice (dynStarts [] at) (mapd size id at)
472-
$ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
473-
474-
where
475-
mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
476-
(Nat -> a) ->
477-
{shape : Shape} ->
478-
MultiSlice shape ->
479-
List a
480-
mapd _ dflt {shape} [] = Prelude.map dflt shape
481-
mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
482-
483-
start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
484-
start _ (Slice from _) = from
485-
start _ (Index idx) = idx
486-
start f {d} _ = f d
487-
488-
stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
489-
stop _ (Slice _ to) = to
490-
stop _ (Index idx) = S idx
491-
stop f {d} _ = f d
492-
493-
size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
494-
size _ (Slice {size = size'} _ _) = size'
495-
size _ (Index _) = 1
496-
size _ (DynamicSlice _ size') = size'
497-
size _ (DynamicIndex _) = 1
498-
499-
zero : Expr
500-
zero = FromLiteral {shape = []} {dtype = U64} 0
501-
502-
dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr
503-
dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
504-
dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
505-
dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
506-
dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
470+
$ Reshape (map size at) (MultiSlice.slice at)
471+
$ DynamicSlice (map dynStart at ++ replicate (length shape `minus` length at) zero) (map size at)
472+
$ Slice (map start at) (map (uncurry stop) (zip shape at)) (replicate (length shape) 1) x -- zip doesn't account for length difference
473+
474+
where
475+
476+
start : Subset -> Nat
477+
start (Slice from _) = from
478+
start (Index idx) = idx
479+
start _ = 0
480+
481+
stop : Nat -> Subset -> Nat
482+
stop _ (Slice _ to) = to
483+
stop _ (Index idx) = S idx
484+
stop d _ = d
485+
486+
size : Subset -> Nat
487+
size (Slice {size = size'} _ _) = size'
488+
size (Index _) = 1
489+
size (DynamicSlice _ size') = size'
490+
size (DynamicIndex _) = 1
491+
492+
zero : Expr
493+
zero = FromLiteral {shape = []} {dtype = U64} 0
494+
495+
dynStart : Subset -> Expr
496+
dynStart (DynamicSlice (MkTensor i) _) = i
497+
dynStart (DynamicIndex (MkTensor i)) = i
498+
dynStart _ = zero
507499

508500
||| Concatenate two `Tensor`s along the specfied `axis`. For example,
509501
||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and

0 commit comments

Comments
 (0)