Skip to content

Commit b668837

Browse files
Add PagedMergeSort, a merge sort using O(√n) space (#71)
Co-authored-by: Lilith Orion Hafner <lilithhafner@gmail.com>
1 parent ef22e53 commit b668837

File tree

4 files changed

+306
-7
lines changed

4 files changed

+306
-7
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ The `SortingAlgorithms` package provides three sorting algorithms that can be us
99
- [HeapSort] – an unstable, general purpose, in-place, O(n log n) comparison sort that works by heapifying an array and repeatedly taking the maximal element from the heap.
1010
- [TimSort] – a stable, general purpose, hybrid, O(n log n) comparison sort that adapts to different common patterns of partially ordered input data.
1111
- [CombSort] – an unstable, general purpose, in-place, O(n log n) comparison sort with O(n^2) pathological cases that can attain good efficiency through SIMD instructions and instruction level parallelism on modern hardware.
12+
- [PagedMergeSort] – a stable, general purpose, O(n log n) time and O(sqrt n) space comparison sort.
1213

1314
[HeapSort]: https://en.wikipedia.org/wiki/Heapsort
1415
[TimSort]: https://en.wikipedia.org/wiki/Timsort
1516
[CombSort]: https://en.wikipedia.org/wiki/Comb_sort
17+
[PagedMergeSort]: https://link.springer.com/chapter/10.1007/BFb0016253
1618

1719
## Usage
1820

docs/pagedMerge_130_130.gif

786 KB
Loading

src/SortingAlgorithms.jl

+278-1
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ using Base.Order
99
import Base.Sort: sort!
1010
import DataStructures: heapify!, percolate_down!
1111

12-
export HeapSort, TimSort, RadixSort, CombSort
12+
export HeapSort, TimSort, RadixSort, CombSort, PagedMergeSort
1313

1414
struct HeapSortAlg <: Algorithm end
1515
struct TimSortAlg <: Algorithm end
1616
struct RadixSortAlg <: Algorithm end
1717
struct CombSortAlg <: Algorithm end
18+
struct PagedMergeSortAlg <: Algorithm end
1819

1920
function maybe_optimize(x::Algorithm)
2021
isdefined(Base.Sort, :InitialOptimizations) ? Base.Sort.InitialOptimizations(x) : x
@@ -51,6 +52,34 @@ Characteristics:
5152
"""
5253
const CombSort = maybe_optimize(CombSortAlg())
5354

55+
"""
56+
PagedMergeSort
57+
58+
Indicates that a sorting function should use the paged merge sort
59+
algorithm. Paged merge sort uses is a merge sort, that uses different
60+
merge routines to achieve stable sorting with a scratch space of size O(√n).
61+
The merge routine for merging large subarrays merges
62+
pages of size O(√n) almost in place, before reordering them using a page table.
63+
At deeper recursion levels, where the scratch space is big enough,
64+
normal merging is used, where one input is copied into the scratch space.
65+
When the scratch space is large enough to hold the complete subarray,
66+
the input is merged interleaved from both sides, which increases performance
67+
for random data.
68+
69+
Characteristics:
70+
- *stable*: does preserve the ordering of elements which
71+
compare equal (e.g. "a" and "A" in a sort of letters which
72+
ignores case).
73+
- *`O(√n)`* auxilary memory usage.
74+
- *`O(n log n)`* garuanteed runtime.
75+
76+
## References
77+
- Dvořák, S., Ďurian, B. (1986). Towards an efficient merging. In: Gruska, J., Rovan, B., Wiedermann,
78+
J. (eds) Mathematical Foundations of Computer Science 1986. MFCS 1986. Lecture Notes in Computer Science, vol 233.
79+
Springer, Berlin, Heidelberg. https://doi.org/10.1007/BFb0016253
80+
- https://max-arbuzov.blogspot.com/2021/10/merge-sort-with-osqrtn-auxiliary-memory.html
81+
"""
82+
const PagedMergeSort = maybe_optimize(PagedMergeSortAlg())
5483

5584
## Heap sort
5685

@@ -652,4 +681,252 @@ else
652681
end
653682
end
654683

684+
###
685+
# PagedMergeSort
686+
###
687+
688+
# unsafe version of copyto!
689+
# as workaround for https://github.com/JuliaLang/julia/issues/50900
690+
function _unsafe_copyto!(dest, doffs, src, soffs, n)
691+
@inbounds for i in 0:n-1
692+
dest[doffs + i] = src[soffs + i]
693+
end
694+
dest
695+
end
696+
697+
function _unsafe_copyto!(dest::Array, doffs, src::Array, soffs, n)
698+
unsafe_copyto!(dest, doffs, src, soffs, n)
699+
end
700+
701+
# merge v[lo:m] and v[m+1:hi] ([A;B]) using scratch[1:1+hi-lo]
702+
# This is faster than merge! but requires twice as much auxiliary memory.
703+
function twoended_merge!(v::AbstractVector{T}, lo::Integer, m::Integer, hi::Integer, o::Ordering, scratch::AbstractVector{T}) where T
704+
@assert lo m hi
705+
@assert abs((m-lo) - (hi-(m+1))) 1 "twoended_merge! only supports balanced merges"
706+
len = 1 + hi - lo
707+
# input array indices
708+
a_lo = lo
709+
a_hi = m
710+
b_lo = m + 1
711+
b_hi = hi
712+
# output array indices
713+
k_lo = 1
714+
k_hi = len
715+
@inbounds begin
716+
# two ended merge
717+
while k_lo <= len ÷ 2
718+
if lt(o, v[b_lo], v[a_lo])
719+
scratch[k_lo] = v[b_lo]
720+
b_lo += 1
721+
else
722+
scratch[k_lo] = v[a_lo]
723+
a_lo += 1
724+
end
725+
k_lo +=1
726+
if !lt(o, v[b_hi], v[a_hi])
727+
scratch[k_hi] = v[b_hi]
728+
b_hi -= 1
729+
else
730+
scratch[k_hi] = v[a_hi]
731+
a_hi -= 1
732+
end
733+
k_hi -=1
734+
end
735+
# if the input length is odd,
736+
# one item remains
737+
if a_lo <= a_hi
738+
scratch[k_lo] = v[a_lo]
739+
elseif b_lo <= b_hi
740+
scratch[k_lo] = v[b_lo]
741+
end
742+
# copy back from t to v
743+
offset = lo-1
744+
for i = 1:len
745+
v[offset+i] = scratch[i]
746+
end
747+
end
748+
end
749+
750+
# core merging loop used throughout PagedMergeSort
751+
Base.@propagate_inbounds function merge!(f::Function,
752+
target::AbstractVector{T}, source_a::AbstractVector{T}, source_b::AbstractVector{T},
753+
o::Ordering, a::Integer, b::Integer, k::Integer) where T
754+
@inbounds while f(a,b,k)
755+
if lt(o, source_b[b], source_a[a])
756+
target[k] = source_b[b]
757+
b += 1
758+
else
759+
target[k] = source_a[a]
760+
a += 1
761+
end
762+
k += 1
763+
end
764+
a,b,k
765+
end
766+
767+
# merge v[lo:m] and v[m+1:hi] using scratch[1:1+m-lo]
768+
# based on Base.Sort MergeSort
769+
function merge!(v::AbstractVector{T}, lo::Integer, m::Integer, hi::Integer, o::Ordering, scratch::AbstractVector{T}) where {T}
770+
_unsafe_copyto!(scratch, 1, v, lo, m - lo + 1)
771+
f(_, b, k) = k < b <= hi
772+
a, b, k = merge!(f, v, scratch, v, o, 1, m + 1, lo)
773+
_unsafe_copyto!(v, k, scratch, a, b - k)
774+
end
775+
776+
struct Pages
777+
current::Int # current page being merged into
778+
currentNumber::Int # number of current page (=index in pageLocations)
779+
nextA::Int # next possible page in A
780+
nextB::Int # next possible page in B
781+
end
782+
783+
next_page_A(pages::Pages) = Pages(pages.nextA, pages.currentNumber + 1, pages.nextA + 1, pages.nextB)
784+
next_page_B(pages::Pages) = Pages(pages.nextB, pages.currentNumber + 1, pages.nextA, pages.nextB + 1)
785+
786+
Base.@propagate_inbounds function next_page!(pageLocations, pages, pagesize, lo, a)
787+
if a > pages.nextA * pagesize + lo
788+
pages = next_page_A(pages)
789+
else
790+
pages = next_page_B(pages)
791+
end
792+
pageLocations[pages.currentNumber] = pages.current
793+
pages
794+
end
795+
796+
Base.@propagate_inbounds function permute_pages!(f, v, pageLocations, page_offset, pagesize, page)
797+
while f(page)
798+
plc = pageLocations[page-3] # plc has data belonging to page
799+
pageLocations[page-3] = page
800+
_unsafe_copyto!(v, page_offset(page) + 1, v, page_offset(plc) + 1, pagesize)
801+
page = plc
802+
end
803+
page
804+
end
805+
806+
# merge v[lo:m] (A) and v[m+1:hi] (B) using scratch[] in O(sqrt(n)) space
807+
function paged_merge!(v::AbstractVector{T}, lo::Integer, m::Integer, hi::Integer, o::Ordering, scratch::AbstractVector{T}, pageLocations::AbstractVector{<:Integer}) where {T}
808+
@assert lo < m < hi
809+
lenA = 1 + m - lo
810+
lenB = hi - m
811+
812+
# this function only supports merges with length(A) <= length(B),
813+
# which is guaranteed by pagedmergesort!
814+
@assert lenA <= lenB
815+
816+
# regular merge if scratch is big enough
817+
lenA <= length(scratch) && return merge!(v, lo, m, hi, o, scratch)
818+
819+
len = lenA + lenB
820+
pagesize = isqrt(len)
821+
nPages = len ÷ pagesize # a partial page at the end does not count
822+
@assert length(scratch) >= 3pagesize
823+
@assert length(pageLocations) >= nPages - 3
824+
825+
@inline page_offset(page) = (page - 1) * pagesize + lo - 1
826+
827+
@inbounds begin
828+
##################
829+
# merge
830+
##################
831+
# merge the first 3 pages into scratch
832+
a, b, _ = merge!((_, _, k) -> k <= 3pagesize, scratch, v, v, o, lo, m + 1, 1)
833+
# initialize variables for merging into pages
834+
pages = Pages(-17, 0, 1, (m - lo) ÷ pagesize + 2) # first argument is unused
835+
# more efficient loop while more than pagesize elements of A and B are remaining
836+
while_condition1(offset) = (_, _, k) -> k <= offset + pagesize
837+
while a < m - pagesize && b < hi - pagesize
838+
pages = next_page!(pageLocations, pages, pagesize, lo, a)
839+
offset = page_offset(pages.current)
840+
a, b, _ = merge!(while_condition1(offset), v, v, v, o, a, b, offset + 1)
841+
end
842+
# merge until either A or B is empty or the last page is reached
843+
k, offset = nothing, nothing
844+
while_condition2(offset) = (a, b, k) -> k <= offset + pagesize && a <= m && b <= hi
845+
while a <= m && b <= hi && pages.currentNumber + 3 < nPages
846+
pages = next_page!(pageLocations, pages, pagesize, lo, a)
847+
offset = page_offset(pages.current)
848+
a, b, k = merge!(while_condition2(offset), v, v, v, o, a, b, offset + 1)
849+
end
850+
# if the last page is reached, merge the remaining elements into the final partial page
851+
if pages.currentNumber + 3 == nPages && a <= m && b <= hi
852+
a, b, k = merge!((a, b, _) -> a <= m && b <= hi, v, v, v, o, a, b, nPages * pagesize + lo)
853+
_unsafe_copyto!(v, k, v, a <= m ? a : b, hi - k + 1)
854+
else
855+
use_a = a <= m
856+
# copy the incomplete page
857+
partial_page_size = offset + pagesize - k + 1
858+
_unsafe_copyto!(v, k, v, use_a ? a : b, partial_page_size)
859+
use_a && (a += partial_page_size)
860+
use_a || (b += partial_page_size)
861+
# copy the remaining full pages
862+
while use_a ? a <= m - pagesize + 1 : b <= hi - pagesize + 1
863+
pages = next_page!(pageLocations, pages, pagesize, lo, a)
864+
offset = page_offset(pages.current)
865+
_unsafe_copyto!(v, offset + 1, v, use_a ? a : b, pagesize)
866+
use_a && (a += pagesize)
867+
use_a || (b += pagesize)
868+
end
869+
# copy the final partial page only if sourcing from A.
870+
# If sourcing from B, it is already in place.
871+
use_a && _unsafe_copyto!(v, hi - m + a, v, a, m - a + 1)
872+
end
873+
874+
##################
875+
# rearrange pages
876+
##################
877+
# copy pages belonging to the 3 permutation chains ending with a page in the scratch space
878+
nextA, nextB = pages.nextA, pages.nextB
879+
880+
for _ in 1:3
881+
page = (nextB > nPages ? (nextA += 1) : (nextB += 1)) - 1
882+
page = permute_pages!(>(3), v, pageLocations, page_offset, pagesize, page)
883+
_unsafe_copyto!(v, page_offset(page) + 1, scratch, (page - 1) * pagesize + 1, pagesize)
884+
end
885+
886+
# copy remaining permutation cycles
887+
for donePageIndex = 5:nPages
888+
# linear scan through pageLocations to make sure no cycle is missed
889+
page = pageLocations[donePageIndex-3]
890+
page == donePageIndex && continue
891+
892+
# copy the data belonging to donePageIndex into scratch
893+
_unsafe_copyto!(scratch, 1, v, page_offset(page) + 1, pagesize)
894+
895+
# follow the cycle starting with the newly freed page
896+
permute_pages!(!=(donePageIndex), v, pageLocations, page_offset, pagesize, page)
897+
_unsafe_copyto!(v, page_offset(donePageIndex) + 1, scratch, 1, pagesize)
898+
end
899+
end
900+
end
901+
902+
# midpoint was added to Base.sort in version 1.4 and later moved to Base
903+
# -> redefine for compatibility with earlier versions
904+
midpoint(lo::Integer, hi::Integer) = lo + ((hi - lo) >>> 0x01)
905+
906+
function pagedmergesort!(v::AbstractVector{T}, lo::Integer, hi::Integer, o::Ordering, scratch::AbstractVector{T}, pageLocations) where {T}
907+
len = hi + 1 - lo
908+
if len <= Base.SMALL_THRESHOLD
909+
return Base.Sort.sort!(v, lo, hi, Base.Sort.InsertionSortAlg(), o)
910+
end
911+
m = midpoint(lo, hi - 1) # hi-1: ensure midpoint is rounded down. OK, because lo < hi is satisfied here
912+
pagedmergesort!(v, lo, m, o, scratch, pageLocations)
913+
pagedmergesort!(v, m + 1, hi, o, scratch, pageLocations)
914+
if len <= length(scratch)
915+
twoended_merge!(v, lo, m, hi, o, scratch)
916+
else
917+
paged_merge!(v, lo, m, hi, o, scratch, pageLocations)
918+
end
919+
return v
920+
end
921+
922+
function sort!(v::AbstractVector, lo::Integer, hi::Integer, ::PagedMergeSortAlg, o::Ordering)
923+
lo >= hi && return v
924+
n = hi + 1 - lo
925+
pagesize = isqrt(n)
926+
scratch = Vector{eltype(v)}(undef, 3pagesize)
927+
nPages = n ÷ pagesize
928+
pageLocations = Vector{Int}(undef, max(0, nPages - 3))
929+
pagedmergesort!(v, lo, hi, o, scratch, pageLocations)
930+
return v
931+
end
655932
end # module

test/runtests.jl

+26-6
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ using Test
33
using StatsBase
44
using Random
55

6+
stable_algorithms = [TimSort, RadixSort, PagedMergeSort]
7+
unstable_algorithms = [HeapSort, CombSort]
8+
69
a = rand(1:10000, 1000)
710
am = [rand() < .9 ? i : missing for i in a]
811

9-
for alg in [TimSort, HeapSort, RadixSort, CombSort, SortingAlgorithms.TimSortAlg()]
12+
for alg in [stable_algorithms; unstable_algorithms; SortingAlgorithms.TimSortAlg()]
1013
b = sort(a, alg=alg)
1114
@test issorted(b)
1215
ix = sortperm(a, alg=alg)
@@ -94,8 +97,7 @@ for n in [0:10..., 100, 101, 1000, 1001]
9497
invpermute!(c, pi)
9598
@test c == v
9699

97-
# stable algorithms
98-
for alg in [TimSort, RadixSort]
100+
for alg in stable_algorithms
99101
p = sortperm(v, alg=alg, order=ord)
100102
@test p == pi
101103
s = copy(v)
@@ -105,8 +107,7 @@ for n in [0:10..., 100, 101, 1000, 1001]
105107
@test s == v
106108
end
107109

108-
# unstable algorithms
109-
for alg in [HeapSort, CombSort]
110+
for alg in unstable_algorithms
110111
p = sortperm(v, alg=alg, order=ord)
111112
@test isperm(p)
112113
@test v[p] == si
@@ -120,7 +121,7 @@ for n in [0:10..., 100, 101, 1000, 1001]
120121

121122
v = randn_with_nans(n,0.1)
122123
for ord in [Base.Order.Forward, Base.Order.Reverse],
123-
alg in [TimSort, HeapSort, RadixSort, CombSort]
124+
alg in [stable_algorithms; unstable_algorithms]
124125
# test float sorting with NaNs
125126
s = sort(v, alg=alg, order=ord)
126127
@test issorted(s, order=ord)
@@ -138,3 +139,22 @@ for n in [0:10..., 100, 101, 1000, 1001]
138139
@test reinterpret(UInt64,vp) == reinterpret(UInt64,s)
139140
end
140141
end
142+
143+
for T in (Float64, Int, UInt8)
144+
for alg in stable_algorithms
145+
for ord in [Base.Order.By(identity), Base.Order.By(_ -> 0), Base.Order.By(Base.Fix2(÷, 100))]
146+
for n in vcat(0:31, 40:11:100, 110:51:1000)
147+
v = rand(T, n)
148+
# use MergeSort to guarantee stable sorting in Julia 1.0
149+
@test sort(v, alg=alg, order=ord) == sort(v, alg=MergeSort, order=ord)
150+
end
151+
end
152+
end
153+
end
154+
155+
# PagedMergeSort with small input without InitialOptimizations
156+
# (https://github.com/JuliaCollections/SortingAlgorithms.jl/pull/71#discussion_r1292774352)
157+
if isdefined(Base.Sort, :InitialOptimizations)
158+
v = [0,1]
159+
@test sort(v, alg=SortingAlgorithms.PagedMergeSortAlg()) == sort(v, alg=MergeSort)
160+
end

0 commit comments

Comments
 (0)