Skip to content

Commit 2b11ded

Browse files
committed
improve slice type inference with decidability
1 parent d319652 commit 2b11ded

File tree

1 file changed

+66
-74
lines changed

1 file changed

+66
-74
lines changed

spidr/src/Tensor.idr

Lines changed: 66 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -311,68 +311,67 @@ squeeze :
311311
Tensor to dtype
312312
squeeze $ MkTensor {shape} x = MkTensor $ Reshape shape to x
313313

314-
||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
314+
||| A `Subset d` is a valid slice or index into a dimension of size `d`. See `slice` for
315315
||| details.
316316
export
317-
data SliceOrIndex : Nat -> Type where
318-
Slice :
319-
(from, to : Nat) ->
320-
{size : _} ->
321-
{auto 0 fromTo : from + size = to} ->
322-
{auto 0 inDim : LTE to d} ->
323-
SliceOrIndex d
324-
Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
325-
DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
326-
DynamicIndex : Tensor [] U64 -> SliceOrIndex d
317+
data Subset : Type where
318+
All : Subset
319+
Slice : (from, to : Nat) -> Subset
320+
Index : (idx : Nat) -> Subset
321+
DynamicSlice : Tensor [] U64 -> (size : Nat) -> Subset
322+
DynamicIndex : Tensor [] U64 -> Subset
327323

328324
||| Index at `idx`. See `slice` for details.
329325
public export
330-
at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
326+
at : (idx : Nat) -> Subset
331327
at = Index
332328

333329
namespace Dynamic
334330
||| Index at the specified index. See `slice` for details.
335331
public export
336-
at : Tensor [] U64 -> SliceOrIndex d
332+
at : Tensor [] U64 -> Subset
337333
at = DynamicIndex
338334

339335
||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
340336
public export
341-
(.to) :
342-
(from, to : Nat) ->
343-
{size : _} ->
344-
{auto 0 fromTo : from + size = to} ->
345-
{auto 0 inDim : LTE to d} ->
346-
SliceOrIndex d
337+
(.to) : (from, to : Nat) -> {auto 0 ordered : from `LT` to} -> Subset
347338
(.to) = Slice
348339

349340
||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
350341
public export
351-
(.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
342+
(.size) : Tensor [] U64 -> (size : Nat) -> Subset
352343
(.size) = DynamicSlice
353344

354345
||| Slice across all indices along an axis. See `slice` for details.
355346
public export
356-
all : {d : _} -> SliceOrIndex d
357-
all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
347+
all : Subset
348+
all = All
358349

359-
||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
360-
||| See `slice` for details.
361-
public export
362-
data MultiSlice : Shape -> Type where
363-
Nil : MultiSlice ds
364-
(::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
350+
public export -- is there a stdlib version of this?
351+
assert : Bool -> e -> Either e a -> Either e a
352+
assert True _ either = either
353+
assert False e _ = Left e
354+
355+
namespace Subset
356+
public export
357+
data InvalidSubsetError =
358+
||| The number of dimensions requested and found
359+
OutOfBounds Nat Nat
360+
361+
||| The number of unaccounted-for axes
362+
TooManyAxes Nat
365363

366-
namespace MultiSlice
367364
||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
368365
||| `Tensor.slice` for details.
369366
public export
370-
slice : {shape : _} -> MultiSlice shape -> Shape
371-
slice {shape} [] = shape
372-
slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
373-
slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
374-
slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
375-
slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
367+
slice : Shape -> List Subset -> Either InvalidSubsetError Shape
368+
slice [] at@(_ :: _) = Left TooManyAxes (length at)
369+
slice ds [] = ds
370+
slice (d :: ds) (Slice from to :: xs) = assert (to > d) (OutOfBounds to d) $ map (size ::) (slice ds xs)
371+
slice (d :: ds) (Index idx :: xs) = assert (idx >= d) (OutOfBounds idx d) $ slice ds xs
372+
slice (d :: ds) (DynamicSlice _ size :: xs) =
373+
assert (size > d) (OutOfBounds size d) $ map (size ::) (slice ds xs)
374+
slice (_ :: ds) (DynamicIndex _ :: xs) = slice ds xs
376375

377376
||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
378377
||| either static (`Nat`) or dynamic (scalar `U64`) indices.
@@ -462,47 +461,40 @@ namespace MultiSlice
462461
export
463462
slice :
464463
Primitive dtype =>
465-
(at : MultiSlice shape) ->
464+
(at : List Subset) ->
466465
Tensor shape dtype ->
467-
Tensor (slice at) dtype
466+
{auto 0 shape' : IsRight (slice at shape)} ->
467+
case shape' of Right shape' => Tensor shape' dtype
468468
slice at $ MkTensor x = MkTensor
469-
$ Reshape (mapd size id at) (MultiSlice.slice at)
470-
$ DynamicSlice (dynStarts [] at) (mapd size id at)
471-
$ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
472-
473-
where
474-
mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
475-
(Nat -> a) ->
476-
{shape : Shape} ->
477-
MultiSlice shape ->
478-
List a
479-
mapd _ dflt {shape} [] = Prelude.map dflt shape
480-
mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
481-
482-
start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
483-
start _ (Slice from _) = from
484-
start _ (Index idx) = idx
485-
start f {d} _ = f d
486-
487-
stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
488-
stop _ (Slice _ to) = to
489-
stop _ (Index idx) = S idx
490-
stop f {d} _ = f d
491-
492-
size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
493-
size _ (Slice {size = size'} _ _) = size'
494-
size _ (Index _) = 1
495-
size _ (DynamicSlice _ size') = size'
496-
size _ (DynamicIndex _) = 1
497-
498-
zero : Expr
499-
zero = FromLiteral {shape = []} {dtype = U64} 0
500-
501-
dynStarts : List Expr -> {shape : _} -> MultiSlice shape -> List Expr
502-
dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
503-
dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
504-
dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
505-
dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
469+
$ Reshape (map size at) (MultiSlice.slice at)
470+
$ DynamicSlice (map dynStart at ++ replicate (length shape `minus` length at) zero) (map size at)
471+
$ Slice (map start at) (map (uncurry stop) (zip shape at)) (replicate (length shape) 1) x -- zip doesn't account for length difference
472+
473+
where
474+
475+
start : Subset -> Nat
476+
start (Slice from _) = from
477+
start (Index idx) = idx
478+
start _ = 0
479+
480+
stop : Nat -> Subset -> Nat
481+
stop _ (Slice _ to) = to
482+
stop _ (Index idx) = S idx
483+
stop d _ = d
484+
485+
size : Subset -> Nat
486+
size (Slice {size = size'} _ _) = size'
487+
size (Index _) = 1
488+
size (DynamicSlice _ size') = size'
489+
size (DynamicIndex _) = 1
490+
491+
zero : Expr
492+
zero = FromLiteral {shape = []} {dtype = U64} 0
493+
494+
dynStart : Subset -> Expr
495+
dynStart (DynamicSlice (MkTensor i) _) = i
496+
dynStart (DynamicIndex (MkTensor i)) = i
497+
dynStart _ = zero
506498

507499
||| Concatenate two `Tensor`s along the specfied `axis`. For example,
508500
||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and

0 commit comments

Comments
 (0)