From f694f8f6cff0fe5f29820b270833fc5ea80348a8 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 17 Apr 2024 23:01:19 +0000 Subject: [PATCH] Make emitted egal code more loopy The strategy here is to look at (data, padding) pairs and RLE them into loops, so that repeated adjacent patterns use a loop rather than getting unrolled. On the test case from #54109, this makes compilation essentially instant, while also being faster at runtime (turns out LLVM spends a massive amount of time AND the answer is bad). There's some obvious further enhancements possible here: 1. The `memcmp` constant is small. LLVM has a pass to inline these with better code. However, we don't have it turned on. We should consider vendoring it, though we may want to add some shortcutting to it to avoid having it iterate through each function. 2. This only does one level of sequence matching. It could be recursed to turn things into nested loops. However, this solves the immediate issue, so hopefully it's a useful start. Fixes #54109. --- base/reflection.jl | 19 +++++- src/builtins.c | 6 +- src/cgutils.cpp | 3 +- src/codegen.cpp | 137 ++++++++++++++++++++++++++++++++++++++- src/datatype.c | 40 ++++++++---- src/julia.h | 5 +- test/compiler/codegen.jl | 51 +++++++++++++++ test/core.jl | 8 ++- 8 files changed, 246 insertions(+), 23 deletions(-) diff --git a/base/reflection.jl b/base/reflection.jl index da731cd8fd6a2..d01e1d77e0be9 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -490,8 +490,8 @@ gc_alignment(T::Type) = gc_alignment(Core.sizeof(T)) Base.datatype_haspadding(dt::DataType) -> Bool Return whether the fields of instances of this type are packed in memory, -with no intervening padding bits (defined as bits whose value does not uniquely -impact the egal test when applied to the struct fields). +with no intervening padding bits (defined as bits whose value does not impact +the semantic value of the instance itself). Can be called on any `isconcretetype`. 
""" function datatype_haspadding(dt::DataType) @@ -501,6 +501,21 @@ function datatype_haspadding(dt::DataType) return flags & 1 == 1 end +""" + Base.datatype_isbitsegal(dt::DataType) -> Bool + +Return whether egality of the (non-padding bits of the) in-memory representation +of an instance of this type implies semantic egality of the instance itself. +This may not be the case if the type contains to other values whose egality is +independent of their identity (e.g. immutable structs, some types, etc.). +""" +function datatype_isbitsegal(dt::DataType) + @_foldable_meta + dt.layout == C_NULL && throw(UndefRefError()) + flags = unsafe_load(convert(Ptr{DataTypeLayout}, dt.layout)).flags + return (flags & (1<<5)) != 0 +end + """ Base.datatype_nfields(dt::DataType) -> UInt32 diff --git a/src/builtins.c b/src/builtins.c index eb9839bcc20f3..dc62d7e4d4eba 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -115,7 +115,7 @@ static int NOINLINE compare_fields(const jl_value_t *a, const jl_value_t *b, jl_ continue; // skip this field (it is #undef) } } - if (!ft->layout->flags.haspadding) { + if (!ft->layout->flags.haspadding && ft->layout->flags.isbitsegal) { if (!bits_equal(ao, bo, ft->layout->size)) return 0; } @@ -284,7 +284,7 @@ inline int jl_egal__bits(const jl_value_t *a JL_MAYBE_UNROOTED, const jl_value_t if (sz == 0) return 1; size_t nf = jl_datatype_nfields(dt); - if (nf == 0 || !dt->layout->flags.haspadding) + if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal)) return bits_equal(a, b, sz); return compare_fields(a, b, dt); } @@ -394,7 +394,7 @@ static uintptr_t immut_id_(jl_datatype_t *dt, jl_value_t *v, uintptr_t h) JL_NOT if (sz == 0) return ~h; size_t f, nf = jl_datatype_nfields(dt); - if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->npointers == 0)) { + if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal && dt->layout->npointers == 0)) { // operate element-wise if there are unused bits inside, // 
otherwise just take the whole data block at once // a few select pointers (notably symbol) also have special hash values diff --git a/src/cgutils.cpp b/src/cgutils.cpp index 6084df729845c..e40f60997848c 100644 --- a/src/cgutils.cpp +++ b/src/cgutils.cpp @@ -2200,7 +2200,8 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx, } else if (!isboxed) { assert(jl_is_concrete_type(jltype)); - needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding; + needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding || + !((jl_datatype_t*)jltype)->layout->flags.isbitsegal; Value *SameType = emit_isa(ctx, cmp, jltype, Twine()).first; if (SameType != ConstantInt::getTrue(ctx.builder.getContext())) { BasicBlock *SkipBB = BasicBlock::Create(ctx.builder.getContext(), "skip_xchg", ctx.f); diff --git a/src/codegen.cpp b/src/codegen.cpp index c8c04356feec4..c3428b3466d46 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3352,6 +3352,58 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, return phi; } +struct egal_desc { + size_t offset; + size_t nrepeats; + size_t data_bytes; + size_t padding_bytes; +}; + +template +static size_t emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc ¤t_desc) +{ + // Memcmp, but with masked padding + size_t data_bytes = 0; + size_t padding_bytes = 0; + size_t nfields = jl_datatype_nfields(aty); + size_t total_size = jl_datatype_size(aty); + for (size_t i = 0; i < nfields; ++i) { + size_t offset = jl_field_offset(aty, i); + size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1); + size_t fsz = jl_field_size(aty, i); + jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i); + if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) { + // The field has no internal padding + data_bytes += fsz; + if (offset + fsz == fend) { + // The field has no padding after. Merge this into the current + // comparison range and go to next field. 
+ } else { + padding_bytes = fend - offset - fsz; + // Found padding. Either merge this into the current comparison + // range, or emit the old one and start a new one. + if (current_desc.data_bytes == data_bytes && + current_desc.padding_bytes == padding_bytes) { + // Same as the previous range, just note that down, so we + // emit this as a loop. + current_desc.nrepeats += 1; + } else { + if (current_desc.nrepeats != 0) + emit_desc(current_desc); + current_desc.nrepeats = 1; + current_desc.data_bytes = data_bytes; + current_desc.padding_bytes = padding_bytes; + } + data_bytes = 0; + } + } else { + // The field may have internal padding. Recurse this. + data_bytes += emit_masked_bits_compare(emit_desc, fty, current_desc); + } + } + return data_bytes; +} + static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2) { ++EmittedBitsCompares; @@ -3390,7 +3442,7 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a if (at->isAggregateType()) { // Struct or Array jl_datatype_t *sty = (jl_datatype_t*)arg1.typ; size_t sz = jl_datatype_size(sty); - if (sz > 512 && !sty->layout->flags.haspadding) { + if (sz > 512 && !sty->layout->flags.haspadding && sty->layout->flags.isbitsegal) { Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) : value_to_pointer(ctx, arg1).V; Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) : @@ -3427,6 +3479,89 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a } return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0)); } + else if (sz > 512 && jl_struct_try_layout(sty) && sty->layout->flags.isbitsegal) { + Type *TInt8 = getInt8Ty(ctx.builder.getContext()); + Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext()); + Type *TInt1 = getInt1Ty(ctx.builder.getContext()); + Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) : + value_to_pointer(ctx, arg1).V; + Value *varg2 = arg2.ispointer() ? 
data_pointer(ctx, arg2) : + value_to_pointer(ctx, arg2).V; + varg1 = emit_pointer_from_objref(ctx, varg1); + varg2 = emit_pointer_from_objref(ctx, varg2); + varg1 = emit_bitcast(ctx, varg1, TpInt8); + varg2 = emit_bitcast(ctx, varg2, TpInt8); + + Value *answer = nullptr; + auto emit_desc = [&](egal_desc desc) { + Value *ptr1 = varg1; + Value *ptr2 = varg2; + if (desc.offset != 0) { + ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset); + ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset); + } + + Value *new_ptr1 = ptr1; + Value *endptr1 = nullptr; + BasicBlock *postBB = nullptr; + BasicBlock *loopBB = nullptr; + PHINode *answerphi = nullptr; + if (desc.nrepeats != 1) { + // Set up loop + endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));; + + BasicBlock *currBB = ctx.builder.GetInsertBlock(); + loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f); + postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f); + ctx.builder.CreateBr(loopBB); + + ctx.builder.SetInsertPoint(loopBB); + answerphi = ctx.builder.CreatePHI(TInt1, 2); + answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB); + answer = answerphi; + + PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2); + PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2); + + new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes); + itr1->addIncoming(ptr1, currBB); + itr1->addIncoming(new_ptr1, loopBB); + + Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes); + itr2->addIncoming(ptr2, currBB); + itr2->addIncoming(new_ptr2, loopBB); + + ptr1 = itr1; + ptr2 = itr2; + } + + // Emit memcmp. TODO: LLVM has a pass to expand this for additional + // performance. 
+ Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func), + { ptr1, + ptr2, + ConstantInt::get(ctx.types().T_size, desc.data_bytes) }); + this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0)); + answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer; + if (endptr1) { + answerphi->addIncoming(answer, loopBB); + Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1); + ctx.builder.CreateCondBr(loopend, postBB, loopBB); + ctx.builder.SetInsertPoint(postBB); + } + }; + egal_desc current_desc = {0}; + size_t trailing_data_bytes = emit_masked_bits_compare(emit_desc, sty, current_desc); + assert(current_desc.nrepeats != 0); + emit_desc(current_desc); + if (trailing_data_bytes != 0) { + current_desc.nrepeats = 1; + current_desc.data_bytes = trailing_data_bytes; + current_desc.padding_bytes = 0; + emit_desc(current_desc); + } + return answer; + } else { jl_svec_t *types = sty->types; Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1); diff --git a/src/datatype.c b/src/datatype.c index ee33e75869ee8..abbec420bb617 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -180,6 +180,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz, uint32_t npointers, uint32_t alignment, int haspadding, + int isbitsegal, int arrayelem, jl_fielddesc32_t desc[], uint32_t pointers[]) JL_NOTSAFEPOINT @@ -226,6 +227,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz, flddesc->nfields = nfields; flddesc->alignment = alignment; flddesc->flags.haspadding = haspadding; + flddesc->flags.isbitsegal = isbitsegal; flddesc->flags.fielddesc_type = fielddesc_type; flddesc->flags.arrayelem_isboxed = arrayelem == 1; flddesc->flags.arrayelem_isunion = arrayelem == 2; @@ -504,6 +506,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st) int isunboxed = jl_islayout_inline(eltype, &elsz, &al) && (kind != (jl_value_t*)jl_atomic_sym || jl_is_datatype(eltype)); int isunion = 
isunboxed && jl_is_uniontype(eltype); int haspadding = 1; // we may want to eventually actually compute this more precisely + int isbitsegal = 0; int nfields = 0; // aka jl_is_layout_opaque int npointers = 1; int zi; @@ -562,7 +565,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st) else arrayelem = 0; assert(!st->layout); - st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, arrayelem, NULL, pointers); + st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, isbitsegal, arrayelem, NULL, pointers); st->zeroinit = zi; //st->has_concrete_subtype = 1; //st->isbitstype = 0; @@ -621,18 +624,17 @@ void jl_compute_field_offsets(jl_datatype_t *st) // if we have no fields, we can trivially skip the rest if (st == jl_symbol_type || st == jl_string_type) { // opaque layout - heap-allocated blob - static const jl_datatype_layout_t opaque_byte_layout = {0, 0, 1, -1, 1, {0}}; + static const jl_datatype_layout_t opaque_byte_layout = {0, 0, 1, -1, 1, { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }}; st->layout = &opaque_byte_layout; return; } else if (st == jl_simplevector_type || st == jl_module_type) { - static const jl_datatype_layout_t opaque_ptr_layout = {0, 0, 1, -1, sizeof(void*), {0}}; + static const jl_datatype_layout_t opaque_ptr_layout = {0, 0, 1, -1, sizeof(void*), { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }}; st->layout = &opaque_ptr_layout; return; } else { - // reuse the same layout for all singletons - static const jl_datatype_layout_t singleton_layout = {0, 0, 0, -1, 1, {0}}; + static const jl_datatype_layout_t singleton_layout = {0, 0, 0, -1, 1, { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }}; st->layout = &singleton_layout; } } @@ -673,6 +675,7 @@ void jl_compute_field_offsets(jl_datatype_t *st) size_t alignm = 1; int zeroinit = 0; int haspadding = 0; + int isbitsegal = 1; 
int homogeneous = 1; int needlock = 0; uint32_t npointers = 0; @@ -687,19 +690,30 @@ void jl_compute_field_offsets(jl_datatype_t *st) throw_ovf(should_malloc, desc, st, fsz); desc[i].isptr = 0; if (jl_is_uniontype(fld)) { - haspadding = 1; fsz += 1; // selector byte zeroinit = 1; + // TODO: Some unions could be bits comparable. + isbitsegal = 0; } else { uint32_t fld_npointers = ((jl_datatype_t*)fld)->layout->npointers; if (((jl_datatype_t*)fld)->layout->flags.haspadding) haspadding = 1; + if (!((jl_datatype_t*)fld)->layout->flags.isbitsegal) + isbitsegal = 0; if (i >= nfields - st->name->n_uninitialized && fld_npointers && fld_npointers * sizeof(void*) != fsz) { - // field may be undef (may be uninitialized and contains pointer), - // and contains non-pointer fields of non-zero sizes. - haspadding = 1; + // For field types that contain pointers, we allow inlinealloc + // as long as the field type itself is always fully initialized. + // In such a case, we use the first pointer in the inlined field + // as the #undef marker (if it is zero, we treat the whole inline + // struct as #undef). However, we do not zero-initialize the whole + // struct, so the non-pointer parts of the inline allocation may + // be arbitrary, but still need to compare egal (because all #undef + // representations are egal). Because of this, we cannot bitscompare + // them. + // TODO: Consider zero-initializing the whole struct. 
+ isbitsegal = 0; } if (!zeroinit) zeroinit = ((jl_datatype_t*)fld)->zeroinit; @@ -715,8 +729,7 @@ void jl_compute_field_offsets(jl_datatype_t *st) zeroinit = 1; npointers++; if (!jl_pointer_egal(fld)) { - // this somewhat poorly named flag says whether some of the bits can be non-unique - haspadding = 1; + isbitsegal = 0; } } if (isatomic && fsz > MAX_ATOMIC_SIZE) @@ -777,7 +790,7 @@ void jl_compute_field_offsets(jl_datatype_t *st) } } assert(ptr_i == npointers); - st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, 0, desc, pointers); + st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, isbitsegal, 0, desc, pointers); if (should_malloc) { free(desc); if (npointers) @@ -931,7 +944,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_primitivetype(jl_value_t *name, jl_module_t * bt->ismutationfree = 1; bt->isidentityfree = 1; bt->isbitstype = (parameters == jl_emptysvec); - bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 0, NULL, NULL); + bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 1, 0, NULL, NULL); bt->instance = NULL; return bt; } @@ -954,6 +967,7 @@ JL_DLLEXPORT jl_datatype_t * jl_new_foreign_type(jl_sym_t *name, layout->alignment = sizeof(void *); layout->npointers = haspointers; layout->flags.haspadding = 1; + layout->flags.isbitsegal = 0; layout->flags.fielddesc_type = 3; layout->flags.padding = 0; layout->flags.arrayelem_isboxed = 0; diff --git a/src/julia.h b/src/julia.h index e90e9653d2c85..d75a93185d727 100644 --- a/src/julia.h +++ b/src/julia.h @@ -574,7 +574,10 @@ typedef struct { // metadata bit only for GenericMemory eltype layout uint16_t arrayelem_isboxed : 1; uint16_t arrayelem_isunion : 1; - uint16_t padding : 11; + // If set, this type's egality can be determined entirely by comparing + // the non-padding bits of this datatype. 
+ uint16_t isbitsegal : 1; + uint16_t padding : 10; } flags; // union { // jl_fielddesc8_t field8[nfields]; diff --git a/test/compiler/codegen.jl b/test/compiler/codegen.jl index e66338d460f53..ff3cf50f45d21 100644 --- a/test/compiler/codegen.jl +++ b/test/compiler/codegen.jl @@ -887,3 +887,54 @@ end ex54166 = Union{Missing, Int64}[missing -2; missing -2]; dims54166 = (1,2) @test (minimum(ex54166; dims=dims54166)[1] === missing) + +# #54109 - Excessive LLVM time for egal +struct DefaultOr54109{T} + x::T + default::Bool +end + +@eval struct Torture1_54109 + $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:897)...) +end +Torture1_54109() = Torture1_54109((DefaultOr54109(1.0, false) for i = 1:897)...) + +@eval struct Torture2_54109 + $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:400)...) + $((Expr(:(::), Symbol("x$(i+400)"), DefaultOr54109{Int16}) for i = 1:400)...) +end +Torture2_54109() = Torture2_54109((DefaultOr54109(1.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...) + +@noinline egal_any54109(x, @nospecialize(y::Any)) = x === Base.compilerbarrier(:type, y) + +let ir1 = get_llvm(egal_any54109, Tuple{Torture1_54109, Any}), + ir2 = get_llvm(egal_any54109, Tuple{Torture2_54109, Any}) + + # We can't really do timing on CI, so instead, let's look at the length of + # the optimized IR. The original version had tens of thousands of lines and + # was slower, so just check here that we only have < 500 lines. If somebody, + # implements a better comparison that's larger than that, just re-benchmark + # this and adjust the threshold. + + @test count(==('\n'), ir1) < 500 + @test count(==('\n'), ir2) < 500 +end + +## Regression test for egal of a struct of this size without padding, but with +## non-bitsegal, to make sure that it doesn't accidentally go down the accelerated +## path. +@eval struct BigStructAnyInt + $((Expr(:(::), Symbol("x$i"), Pair{Any, Int}) for i = 1:33)...) 
+end +BigStructAnyInt() = BigStructAnyInt((Union{Base.inferencebarrier(Float64), Int}=>i for i = 1:33)...) +@test egal_any54109(BigStructAnyInt(), BigStructAnyInt()) + +## For completeness, also test correctness, since we don't have a lot of +## large-struct tests. + +# The two allocations of the same struct will likely have different padding, +# we want to make sure we find them egal anyway - a naive memcmp would +# accidentally look at it. +@test egal_any54109(Torture1_54109(), Torture1_54109()) +@test egal_any54109(Torture2_54109(), Torture2_54109()) +@test !egal_any54109(Torture1_54109(), Torture1_54109((DefaultOr54109(2.0, false) for i = 1:897)...)) diff --git a/test/core.jl b/test/core.jl index ed1e1fc6757d2..11071896edea9 100644 --- a/test/core.jl +++ b/test/core.jl @@ -7729,13 +7729,17 @@ struct ContainsPointerNopadding{T} end @test !Base.datatype_haspadding(PointerNopadding{Symbol}) +@test Base.datatype_isbitsegal(PointerNopadding{Int}) @test !Base.datatype_haspadding(PointerNopadding{Int}) +@test Base.datatype_isbitsegal(PointerNopadding{Int}) # Sanity check to make sure the meaning of haspadding didn't change. -@test Base.datatype_haspadding(PointerNopadding{Any}) +@test !Base.datatype_haspadding(PointerNopadding{Any}) +@test !Base.datatype_isbitsegal(PointerNopadding{Any}) @test !Base.datatype_haspadding(Tuple{PointerNopadding{Symbol}}) @test !Base.datatype_haspadding(Tuple{PointerNopadding{Int}}) @test !Base.datatype_haspadding(ContainsPointerNopadding{Symbol}) -@test Base.datatype_haspadding(ContainsPointerNopadding{Int}) +@test !Base.datatype_haspadding(ContainsPointerNopadding{Int}) +@test !Base.datatype_isbitsegal(ContainsPointerNopadding{Int}) # Test the codegen optimized version as well as the unoptimized version of `jl_egal` @noinline unopt_jl_egal(@nospecialize(a), @nospecialize(b)) =