diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h index 396c136392555..d8072b57b6c94 100644 --- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h +++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h @@ -101,7 +101,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { } template mlir::Type convertPointerLike(A &ty) const { - return mlir::LLVM::LLVMPointerType::get(ty.getContext()); + return mlir::LLVM::LLVMPointerType::get(ty.getContext(), addressSpace); } // convert a front-end kind value to either a std or LLVM IR dialect type @@ -127,6 +127,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter { KindMapping kindMapping; std::unique_ptr specifics; std::unique_ptr tbaaBuilder; + unsigned addressSpace; }; } // namespace fir diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index e07732d57880c..02b0accc9d20f 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -67,8 +67,25 @@ static constexpr unsigned defaultAlign = 8; static constexpr unsigned kAttrPointer = CFI_attribute_pointer; static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable; -static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context) { - return mlir::LLVM::LLVMPointerType::get(context); +static inline unsigned getAddressSpace(mlir::ModuleOp module) { + if (mlir::Attribute addrSpace = + mlir::DataLayout(module).getAllocaMemorySpace()) + return addrSpace.cast().getUInt(); + + return 0u; +} + +static inline unsigned +getAddressSpace(mlir::ConversionPatternRewriter &rewriter) { + mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); + return parentOp + ? ::getAddressSpace(parentOp->getParentOfType()) + : 0u; +} + +static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context, + unsigned addressSpace) { + return mlir::LLVM::LLVMPointerType::get(context, addressSpace); } static inline mlir::Type getI8Type(mlir::MLIRContext *context) { @@ -197,7 +214,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { mlir::ConversionPatternRewriter &rewriter, int boxValue) const { if (box.getType().isa()) { - auto pty = ::getLlvmPtrType(resultTy.getContext()); + auto pty = + ::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter)); auto p = rewriter.create( loc, pty, boxTy.llvm, box, llvm::ArrayRef{0, boxValue}); @@ -278,7 +296,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { mlir::Value getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::ConversionPatternRewriter &rewriter) const { - mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext()); + mlir::Type resultTy = + ::getLlvmPtrType(boxTy.llvm.getContext(), ::getAddressSpace(rewriter)); return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox); } @@ -350,7 +369,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { mlir::ConversionPatternRewriter &rewriter, mlir::Value base, ARGS... args) const { llvm::SmallVector cv = {args...}; - auto llvmPtrTy = ::getLlvmPtrType(ty.getContext()); + auto llvmPtrTy = + ::getLlvmPtrType(ty.getContext(), ::getAddressSpace(rewriter)); return rewriter.create(loc, llvmPtrTy, ty, base, cv); } @@ -378,7 +398,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp); rewriter.setInsertionPointToStart(insertBlock); auto size = genI32Constant(loc, rewriter, 1); - mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext()); + mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext(), + ::getAddressSpace(rewriter)); auto al = rewriter.create( loc, llvmPtrTy, llvmObjectTy, size, alignment); rewriter.restoreInsertionPoint(thisPt); @@ -532,7 +553,8 @@ struct AllocaOpConversion : public FIROpConversion { size = rewriter.create( loc, ity, size, integerCast(loc, rewriter, ity, operands[i])); } - mlir::Type llvmPtrTy = ::getLlvmPtrType(alloc.getContext()); + mlir::Type llvmPtrTy = + ::getLlvmPtrType(alloc.getContext(), ::getAddressSpace(rewriter)); // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque // pointers! Only propagate pinned and bindc_name to help debugging, but // this should have no functional purpose (and passing the operand segment @@ -1167,9 +1189,10 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) { auto indexType = mlir::IntegerType::get(op.getContext(), 64); return moduleBuilder.create( rewriter.getUnknownLoc(), "malloc", - mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()), - indexType, - /*isVarArg=*/false)); + mlir::LLVM::LLVMFunctionType::get( + getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)), + indexType, + /*isVarArg=*/false)); } /// Helper function for generating the LLVM IR that computes the distance @@ -1189,7 +1212,8 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, // *)0 + 1)' trick for all types. The generated instructions are optimized // into constant by the first pass of InstCombine, so it should not be a // performance issue. - auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext()); + auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext(), + ::getAddressSpace(rewriter)); auto nullPtr = rewriter.create(loc, llvmPtrTy); auto gep = rewriter.create( loc, llvmPtrTy, llvmObjectType, nullPtr, @@ -1232,7 +1256,8 @@ struct AllocMemOpConversion : public FIROpConversion { loc, ity, size, integerCast(loc, rewriter, ity, opnd)); heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); rewriter.replaceOpWithNewOp( - heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs()); + heap, ::getLlvmPtrType(heap.getContext(), ::getAddressSpace(rewriter)), + size, heap->getAttrs()); return mlir::success(); } @@ -1258,9 +1283,10 @@ getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) { auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext()); return moduleBuilder.create( rewriter.getUnknownLoc(), "free", - mlir::LLVM::LLVMFunctionType::get(voidType, - getLlvmPtrType(op.getContext()), - /*isVarArg=*/false)); + mlir::LLVM::LLVMFunctionType::get( + voidType, + getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)), + /*isVarArg=*/false)); } static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { @@ -1386,7 +1412,8 @@ struct EmboxCommonConversion : public FIROpConversion { return {getCharacterByteSize(loc, rewriter, charTy, lenParams), typeCodeVal}; if (fir::isa_ref_type(boxEleTy)) { - auto ptrTy = ::getLlvmPtrType(rewriter.getContext()); + auto ptrTy = + ::getLlvmPtrType(rewriter.getContext(), ::getAddressSpace(rewriter)); return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal}; } if (boxEleTy.isa()) @@ -1447,7 +1474,8 @@ struct EmboxCommonConversion : public FIROpConversion { fir::RecordType recType) const { std::string name = fir::NameUniquer::getTypeDescriptorName(recType.getName()); - mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext()); + mlir::Type llvmPtrTy = + ::getLlvmPtrType(mod.getContext(), ::getAddressSpace(rewriter)); if (auto global = mod.template lookupSymbol(name)) { return rewriter.create(loc, llvmPtrTy, global.getSymName()); @@ -1505,7 +1533,8 @@ struct EmboxCommonConversion : public FIROpConversion { // Unlimited polymorphic type descriptor with no record type. Set // type descriptor address to a clean state. typeDesc = rewriter.create( - loc, ::getLlvmPtrType(mod.getContext())); + loc, ::getLlvmPtrType(mod.getContext(), + ::getAddressSpace(rewriter))); } } else { typeDesc = getTypeDescriptor(mod, rewriter, loc, @@ -1653,7 +1682,8 @@ struct EmboxCommonConversion : public FIROpConversion { loc, outterOffsetTy, gepArgs[0].get(), cast); } } - mlir::Type llvmPtrTy = ::getLlvmPtrType(resultTy.getContext()); + mlir::Type llvmPtrTy = + ::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter)); return rewriter.create( loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs); } @@ -2673,7 +2703,8 @@ struct CoordinateOpConversion getBaseAddrFromBox(loc, boxTyPair, boxBaseAddr, rewriter); // Component Type auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy); - mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext()); + mlir::Type llvmPtrTy = + ::getLlvmPtrType(coor.getContext(), ::getAddressSpace(rewriter)); mlir::Type byteTy = ::getI8Type(coor.getContext()); mlir::LLVM::IntegerOverflowFlagsAttr nsw = mlir::LLVM::IntegerOverflowFlagsAttr::get( @@ -2890,7 +2921,8 @@ struct TypeDescOpConversion : public FIROpConversion { auto module = typeDescOp.getOperation()->getParentOfType(); std::string typeDescName = fir::NameUniquer::getTypeDescriptorName(recordType.getName()); - auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext()); + auto llvmPtrTy = + ::getLlvmPtrType(typeDescOp.getContext(), ::getAddressSpace(rewriter)); if (auto global = module.lookupSymbol(typeDescName)) { rewriter.replaceOpWithNewOp( typeDescOp, llvmPtrTy, global.getSymName()); @@ -3678,7 +3710,8 @@ struct BoxOffsetOpConversion : public FIROpConversion { matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext()); + mlir::Type pty = + ::getLlvmPtrType(boxOffset.getContext(), ::getAddressSpace(rewriter)); mlir::Type boxType = fir::unwrapRefType(boxOffset.getBoxRef().getType()); mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(mlir::cast(boxType)); diff --git a/flang/lib/Optimizer/CodeGen/DescriptorModel.h b/flang/lib/Optimizer/CodeGen/DescriptorModel.h index ed35caef93014..9f62f60596ac4 100644 --- a/flang/lib/Optimizer/CodeGen/DescriptorModel.h +++ b/flang/lib/Optimizer/CodeGen/DescriptorModel.h @@ -31,7 +31,7 @@ namespace fir { -using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *); +using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *, unsigned); /// Get the LLVM IR dialect model for building a particular C++ type, `T`. template @@ -39,64 +39,72 @@ TypeBuilderFunc getModel(); template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::LLVM::LLVMPointerType::get(context); + return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type { + return mlir::LLVM::LLVMPointerType::get(context, addressSpace); }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(unsigned) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(int) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(int) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(unsigned long) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned long) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(long long) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(long long) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { + return [](mlir::MLIRContext *context, + unsigned /*addressSpace*/) -> mlir::Type { return mlir::IntegerType::get(context, sizeof(Fortran::ISO::CFI_rank_t) * 8); }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { + return [](mlir::MLIRContext *context, + unsigned /*addressSpace*/) -> mlir::Type { return mlir::IntegerType::get(context, sizeof(Fortran::ISO::CFI_type_t) * 8); }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(context, sizeof(long) * 8); - }; + return + [](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(long) * 8); + }; } template <> TypeBuilderFunc getModel() { - return [](mlir::MLIRContext *context) -> mlir::Type { - auto indexTy = getModel()(context); + return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type { + auto indexTy = getModel()(context, addressSpace); return mlir::LLVM::LLVMArrayType::get(indexTy, 3); }; } diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp index 209c586411f41..7404ee6c6244d 100644 --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -35,7 +35,13 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA, getTargetTriple(module), getKindMapping(module), dl)), tbaaBuilder(std::make_unique(module->getContext(), applyTBAA, - forceUnifiedTBAATree)) { + forceUnifiedTBAATree)), + addressSpace(0) { + // Get default alloca address space for the current target + if (mlir::Attribute addrSpace = + mlir::DataLayout(module).getAllocaMemorySpace()) + addressSpace = addrSpace.cast().getUInt(); + LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n"); // Each conversion should return a value of type mlir::Type. @@ -67,7 +73,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA, addConversion([&](fir::LenType field) { // Get size of len paramter from the descriptor. return getModel()( - &getContext()); + &getContext(), addressSpace); }); addConversion([&](fir::LogicalType boolTy) { return mlir::IntegerType::get( @@ -214,25 +220,25 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box, dataDescFields.push_back(eleTy); else dataDescFields.push_back( - mlir::LLVM::LLVMPointerType::get(eleTy.getContext())); + mlir::LLVM::LLVMPointerType::get(eleTy.getContext(), addressSpace)); // elem_len dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + getDescFieldTypeModel()(&getContext(), addressSpace)); // version dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + getDescFieldTypeModel()(&getContext(), addressSpace)); // rank dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + getDescFieldTypeModel()(&getContext(), addressSpace)); // type dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + getDescFieldTypeModel()(&getContext(), addressSpace)); // attribute dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + getDescFieldTypeModel()(&getContext(), addressSpace)); // f18Addendum - dataDescFields.push_back( - getDescFieldTypeModel()(&getContext())); + dataDescFields.push_back(getDescFieldTypeModel()( + &getContext(), addressSpace)); // [dims] if (rank == unknownRank()) { if (auto seqTy = ele.dyn_cast()) @@ -241,15 +247,17 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box, rank = 0; } if (rank > 0) { - auto rowTy = getDescFieldTypeModel()(&getContext()); + auto rowTy = + getDescFieldTypeModel()(&getContext(), addressSpace); dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, rank)); } // opt-type-ptr: i8* (see fir.tdesc) if (requiresExtendedDesc(ele) || fir::isUnlimitedPolymorphicType(box)) { dataDescFields.push_back( - getExtendedDescFieldTypeModel()(&getContext())); - auto rowTy = - getExtendedDescFieldTypeModel()(&getContext()); + getExtendedDescFieldTypeModel()(&getContext(), + addressSpace)); + auto rowTy = getExtendedDescFieldTypeModel()( + &getContext(), addressSpace); dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, 1)); if (auto recTy = fir::unwrapSequenceType(ele).dyn_cast()) if (recTy.getNumLenParams() > 0) { @@ -272,13 +280,14 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box, mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const { // TODO: send the box type and the converted LLVM structure layout // to tbaaBuilder for proper creation of TBAATypeDescriptorOp. - return mlir::LLVM::LLVMPointerType::get(box.getContext()); + return mlir::LLVM::LLVMPointerType::get(box.getContext(), addressSpace); } // fir.boxproc --> llvm<"{ any*, i8* }"> mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const { auto funcTy = convertType(boxproc.getEleTy()); - auto voidPtrTy = mlir::LLVM::LLVMPointerType::get(boxproc.getContext()); + auto voidPtrTy = + mlir::LLVM::LLVMPointerType::get(boxproc.getContext(), addressSpace); llvm::SmallVector tuple = {funcTy, voidPtrTy}; return mlir::LLVM::LLVMStructType::getLiteral(boxproc.getContext(), tuple, /*isPacked=*/false); @@ -329,7 +338,7 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const { // the f18 object v. class distinction (F2003). mlir::Type LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const { - return mlir::LLVM::LLVMPointerType::get(ctx); + return mlir::LLVM::LLVMPointerType::get(ctx, addressSpace); } // Relay TBAA tag attachment to TBAABuilder. diff --git a/flang/test/Fir/alloca-addrspace-2.fir b/flang/test/Fir/alloca-addrspace-2.fir new file mode 100644 index 0000000000000..6ba23630dba13 --- /dev/null +++ b/flang/test/Fir/alloca-addrspace-2.fir @@ -0,0 +1,12 @@ +// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s +// RUN: tco --fir-to-llvm-ir %s | FileCheck %s + +module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>> } { + // CHECK-LABEL: llvm.func @set_addrspace + func.func @set_addrspace() { + // CHECK: llvm.alloca {{.*}} x i32 + // CHECK-SAME: -> !llvm.ptr<5> + %0 = fir.alloca i32 + return + } +} diff --git a/flang/test/Fir/alloca-addrspace.fir b/flang/test/Fir/alloca-addrspace.fir new file mode 100644 index 0000000000000..a5f3a18355ad3 --- /dev/null +++ b/flang/test/Fir/alloca-addrspace.fir @@ -0,0 +1,12 @@ +// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s +// RUN: tco --fir-to-llvm-ir %s | FileCheck %s + +module { + // CHECK-LABEL: llvm.func @default_addrspace + func.func @default_addrspace() { + // CHECK: llvm.alloca {{.*}} x i32 + // CHECK-SAME: -> !llvm.ptr + %0 = fir.alloca i32 + return + } +}