Skip to content

[Flang] Set address space during FIR pointer-like types lowering #69599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/include/flang/Optimizer/CodeGen/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
}

template <typename A> mlir::Type convertPointerLike(A &ty) const {
return mlir::LLVM::LLVMPointerType::get(ty.getContext());
return mlir::LLVM::LLVMPointerType::get(ty.getContext(), addressSpace);
}

// convert a front-end kind value to either a std or LLVM IR dialect type
Expand All @@ -127,6 +127,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
KindMapping kindMapping;
std::unique_ptr<CodeGenSpecifics> specifics;
std::unique_ptr<TBAABuilder> tbaaBuilder;
unsigned addressSpace;
};

} // namespace fir
Expand Down
77 changes: 55 additions & 22 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,25 @@ static constexpr unsigned defaultAlign = 8;
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;

static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context) {
return mlir::LLVM::LLVMPointerType::get(context);
static inline unsigned getAddressSpace(mlir::ModuleOp module) {
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getAllocaMemorySpace())
return addrSpace.cast<mlir::IntegerAttr>().getUInt();

return 0u;
}

static inline unsigned
getAddressSpace(mlir::ConversionPatternRewriter &rewriter) {
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
return parentOp
? ::getAddressSpace(parentOp->getParentOfType<mlir::ModuleOp>())
: 0u;
}

static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
unsigned addressSpace) {
return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
}

static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
Expand Down Expand Up @@ -197,7 +214,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
mlir::ConversionPatternRewriter &rewriter,
int boxValue) const {
if (box.getType().isa<mlir::LLVM::LLVMPointerType>()) {
auto pty = ::getLlvmPtrType(resultTy.getContext());
auto pty =
::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
auto p = rewriter.create<mlir::LLVM::GEPOp>(
loc, pty, boxTy.llvm, box,
llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
Expand Down Expand Up @@ -278,7 +296,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
mlir::Value
getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
mlir::Type resultTy =
::getLlvmPtrType(boxTy.llvm.getContext(), ::getAddressSpace(rewriter));
return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
}

Expand Down Expand Up @@ -350,7 +369,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
mlir::ConversionPatternRewriter &rewriter,
mlir::Value base, ARGS... args) const {
llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
auto llvmPtrTy = ::getLlvmPtrType(ty.getContext());
auto llvmPtrTy =
::getLlvmPtrType(ty.getContext(), ::getAddressSpace(rewriter));
return rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv);
}

Expand Down Expand Up @@ -378,7 +398,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
rewriter.setInsertionPointToStart(insertBlock);
auto size = genI32Constant(loc, rewriter, 1);
mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext());
mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext(),
::getAddressSpace(rewriter));
auto al = rewriter.create<mlir::LLVM::AllocaOp>(
loc, llvmPtrTy, llvmObjectTy, size, alignment);
rewriter.restoreInsertionPoint(thisPt);
Expand Down Expand Up @@ -532,7 +553,8 @@ struct AllocaOpConversion : public FIROpConversion<fir::AllocaOp> {
size = rewriter.create<mlir::LLVM::MulOp>(
loc, ity, size, integerCast(loc, rewriter, ity, operands[i]));
}
mlir::Type llvmPtrTy = ::getLlvmPtrType(alloc.getContext());
mlir::Type llvmPtrTy =
::getLlvmPtrType(alloc.getContext(), ::getAddressSpace(rewriter));
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
// pointers! Only propagate pinned and bindc_name to help debugging, but
// this should have no functional purpose (and passing the operand segment
Expand Down Expand Up @@ -1167,9 +1189,10 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()),
indexType,
/*isVarArg=*/false));
mlir::LLVM::LLVMFunctionType::get(
getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
indexType,
/*isVarArg=*/false));
}

/// Helper function for generating the LLVM IR that computes the distance
Expand All @@ -1189,7 +1212,8 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
// *)0 + 1)' trick for all types. The generated instructions are optimized
// into constant by the first pass of InstCombine, so it should not be a
// performance issue.
auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext());
auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext(),
::getAddressSpace(rewriter));
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
auto gep = rewriter.create<mlir::LLVM::GEPOp>(
loc, llvmPtrTy, llvmObjectType, nullPtr,
Expand Down Expand Up @@ -1232,7 +1256,8 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
heap, ::getLlvmPtrType(heap.getContext(), ::getAddressSpace(rewriter)),
size, heap->getAttrs());
return mlir::success();
}

Expand All @@ -1258,9 +1283,10 @@ getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
mlir::LLVM::LLVMFunctionType::get(voidType,
getLlvmPtrType(op.getContext()),
/*isVarArg=*/false));
mlir::LLVM::LLVMFunctionType::get(
voidType,
getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
/*isVarArg=*/false));
}

static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
Expand Down Expand Up @@ -1386,7 +1412,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
typeCodeVal};
if (fir::isa_ref_type(boxEleTy)) {
auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
auto ptrTy =
::getLlvmPtrType(rewriter.getContext(), ::getAddressSpace(rewriter));
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
}
if (boxEleTy.isa<fir::RecordType>())
Expand Down Expand Up @@ -1447,7 +1474,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
fir::RecordType recType) const {
std::string name =
fir::NameUniquer::getTypeDescriptorName(recType.getName());
mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
mlir::Type llvmPtrTy =
::getLlvmPtrType(mod.getContext(), ::getAddressSpace(rewriter));
if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
Expand Down Expand Up @@ -1505,7 +1533,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
// Unlimited polymorphic type descriptor with no record type. Set
// type descriptor address to a clean state.
typeDesc = rewriter.create<mlir::LLVM::ZeroOp>(
loc, ::getLlvmPtrType(mod.getContext()));
loc, ::getLlvmPtrType(mod.getContext(),
::getAddressSpace(rewriter)));
}
} else {
typeDesc = getTypeDescriptor(mod, rewriter, loc,
Expand Down Expand Up @@ -1653,7 +1682,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
loc, outterOffsetTy, gepArgs[0].get<mlir::Value>(), cast);
}
}
mlir::Type llvmPtrTy = ::getLlvmPtrType(resultTy.getContext());
mlir::Type llvmPtrTy =
::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
return rewriter.create<mlir::LLVM::GEPOp>(
loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs);
}
Expand Down Expand Up @@ -2673,7 +2703,8 @@ struct CoordinateOpConversion
getBaseAddrFromBox(loc, boxTyPair, boxBaseAddr, rewriter);
// Component Type
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
mlir::Type llvmPtrTy =
::getLlvmPtrType(coor.getContext(), ::getAddressSpace(rewriter));
mlir::Type byteTy = ::getI8Type(coor.getContext());
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get(
Expand Down Expand Up @@ -2890,7 +2921,8 @@ struct TypeDescOpConversion : public FIROpConversion<fir::TypeDescOp> {
auto module = typeDescOp.getOperation()->getParentOfType<mlir::ModuleOp>();
std::string typeDescName =
fir::NameUniquer::getTypeDescriptorName(recordType.getName());
auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext());
auto llvmPtrTy =
::getLlvmPtrType(typeDescOp.getContext(), ::getAddressSpace(rewriter));
if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
typeDescOp, llvmPtrTy, global.getSymName());
Expand Down Expand Up @@ -3678,7 +3710,8 @@ struct BoxOffsetOpConversion : public FIROpConversion<fir::BoxOffsetOp> {
matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext());
mlir::Type pty =
::getLlvmPtrType(boxOffset.getContext(), ::getAddressSpace(rewriter));
mlir::Type boxType = fir::unwrapRefType(boxOffset.getBoxRef().getType());
mlir::Type llvmBoxTy =
lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(boxType));
Expand Down
58 changes: 33 additions & 25 deletions flang/lib/Optimizer/CodeGen/DescriptorModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,72 +31,80 @@

namespace fir {

using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *, unsigned);

/// Get the LLVM IR dialect model for building a particular C++ type, `T`.
template <typename T>
TypeBuilderFunc getModel();

template <>
TypeBuilderFunc getModel<void *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::LLVM::LLVMPointerType::get(context);
return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
};
}
template <>
TypeBuilderFunc getModel<unsigned>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
};
}
template <>
TypeBuilderFunc getModel<int>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(int) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(int) * 8);
};
}
template <>
TypeBuilderFunc getModel<unsigned long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
};
}
template <>
TypeBuilderFunc getModel<unsigned long long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
};
}
template <>
TypeBuilderFunc getModel<long long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(long long) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(long long) * 8);
};
}
template <>
TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return [](mlir::MLIRContext *context,
unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context,
sizeof(Fortran::ISO::CFI_rank_t) * 8);
};
}
template <>
TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return [](mlir::MLIRContext *context,
unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context,
sizeof(Fortran::ISO::CFI_type_t) * 8);
};
}
template <>
TypeBuilderFunc getModel<long>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(long) * 8);
};
return
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
return mlir::IntegerType::get(context, sizeof(long) * 8);
};
}
template <>
TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context, addressSpace);
return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
};
}
Expand Down
Loading