Skip to content

Commit 636ca3c

Browse files
authored
Fix indexing for non-Real and offset ranges (#212)
* Fix indexing for non-Real and offset ranges * Extra test for diff with offset ranges
1 parent 8d81d29 commit 636ca3c

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

src/cumsum.jl

+4-7
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,13 @@ axes(c::RangeCumsum) = axes(c.range)
1313
==(a::RangeCumsum, b::RangeCumsum) = a.range == b.range
1414
BroadcastStyle(::Type{<:RangeCumsum{<:Any,RR}}) where RR = BroadcastStyle(RR)
1515

16+
_getindex(r::AbstractUnitRange{<:Integer}, k) = k * (2first(r) + k - 1) ÷ 2
17+
Base.@propagate_inbounds _getindex(r::AbstractRange, k) = sum(r[range(firstindex(r), length=k)])
1618

1719
Base.@propagate_inbounds function getindex(c::RangeCumsum{<:Any,<:AbstractRange}, k::Integer)
1820
@boundscheck checkbounds(c, k)
1921
r = c.range
20-
k * (first(r) + r[k]) ÷ 2
21-
end
22-
Base.@propagate_inbounds function getindex(c::RangeCumsum{<:Any,<:AbstractUnitRange}, k::Integer)
23-
@boundscheck checkbounds(c, k)
24-
r = c.range
25-
k * (2first(r) + k - 1) ÷ 2
22+
_getindex(r, k-firstindex(r)+1)
2623
end
2724

2825
Base.@propagate_inbounds getindex(c::RangeCumsum, kr::OneTo) = RangeCumsum(c.range[kr])
@@ -31,7 +28,7 @@ Base.@propagate_inbounds view(c::RangeCumsum, kr::OneTo) = c[kr]
3128

3229
first(r::RangeCumsum) = first(r.range)
3330
last(r::RangeCumsum) = sum(r.range)
34-
diff(r::RangeCumsum) = r.range[2:end]
31+
diff(r::RangeCumsum) = r.range[firstindex(r)+1:end]
3532
isempty(r::RangeCumsum) = isempty(r.range)
3633

3734
union(a::RangeCumsum{<:Any,<:OneTo}, b::RangeCumsum{<:Any,<:OneTo}) =

test/test_cumsum.jl

+14-9
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@ using ArrayLayouts, Test
55
include("infinitearrays.jl")
66

77
@testset "RangeCumsum" begin
8-
for r in (RangeCumsum(Base.OneTo(5)), RangeCumsum(2:5), RangeCumsum(2:2:6), RangeCumsum(6:-2:1))
9-
@test r == cumsum(r.range)
8+
@testset for p in (Base.OneTo(5), 2:5, 2:2:6, 6:-2:1, -1.0:3.0:5.0, (-1.0:3.0:5.0)*im,
9+
Base.IdentityUnitRange(4:6))
10+
r = RangeCumsum(p)
1011
@test r == r
11-
@test r .+ 1 == cumsum(r.range) .+ 1
12-
@test r[Base.OneTo(3)] == r[1:3]
13-
@test @view(r[Base.OneTo(3)]) === r[Base.OneTo(3)] == r[1:3]
14-
@test @view(r[Base.OneTo(3)]) isa RangeCumsum
15-
@test last(r) == r[end]
16-
@test diff(r) == diff(Vector(r))
17-
@test first(r) == r[1]
12+
if axes(r) isa Base.OneTo
13+
@test r == cumsum(p)
14+
@test r .+ 1 == cumsum(p) .+ 1
15+
@test r[Base.OneTo(3)] == r[1:3]
16+
@test @view(r[Base.OneTo(3)]) === r[Base.OneTo(3)] == r[1:3]
17+
@test @view(r[Base.OneTo(3)]) isa RangeCumsum
18+
@test diff(r) == diff(Vector(r))
19+
end
20+
@test diff(r) == p[firstindex(p)+1:end]
21+
@test last(r) == r[end] == sum(p)
22+
@test first(r) == r[firstindex(r)] == first(p)
1823
end
1924

2025
a,b = RangeCumsum(Base.OneTo(5)), RangeCumsum(Base.OneTo(6))

0 commit comments

Comments
 (0)