Skip to content

Commit 7f68601

Browse files
committed
Set address space in all relevant places
1 parent 42aca92 commit 7f68601

File tree

5 files changed

+109
-65
lines changed

5 files changed

+109
-65
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

+55-22
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,25 @@ static constexpr unsigned defaultAlign = 8;
6767
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
6868
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
6969

70-
static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context) {
71-
return mlir::LLVM::LLVMPointerType::get(context);
70+
static inline unsigned getAddressSpace(mlir::ModuleOp module) {
71+
if (mlir::Attribute addrSpace =
72+
mlir::DataLayout(module).getAllocaMemorySpace())
73+
return addrSpace.cast<mlir::IntegerAttr>().getUInt();
74+
75+
return 0u;
76+
}
77+
78+
static inline unsigned
79+
getAddressSpace(mlir::ConversionPatternRewriter &rewriter) {
80+
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
81+
return parentOp
82+
? ::getAddressSpace(parentOp->getParentOfType<mlir::ModuleOp>())
83+
: 0u;
84+
}
85+
86+
static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
87+
unsigned addressSpace) {
88+
return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
7289
}
7390

7491
static inline mlir::Type getI8Type(mlir::MLIRContext *context) {
@@ -197,7 +214,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
197214
mlir::ConversionPatternRewriter &rewriter,
198215
int boxValue) const {
199216
if (box.getType().isa<mlir::LLVM::LLVMPointerType>()) {
200-
auto pty = ::getLlvmPtrType(resultTy.getContext());
217+
auto pty =
218+
::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
201219
auto p = rewriter.create<mlir::LLVM::GEPOp>(
202220
loc, pty, boxTy.llvm, box,
203221
llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
@@ -278,7 +296,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
278296
mlir::Value
279297
getBaseAddrFromBox(mlir::Location loc, TypePair boxTy, mlir::Value box,
280298
mlir::ConversionPatternRewriter &rewriter) const {
281-
mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
299+
mlir::Type resultTy =
300+
::getLlvmPtrType(boxTy.llvm.getContext(), ::getAddressSpace(rewriter));
282301
return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
283302
}
284303

@@ -350,7 +369,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
350369
mlir::ConversionPatternRewriter &rewriter,
351370
mlir::Value base, ARGS... args) const {
352371
llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
353-
auto llvmPtrTy = ::getLlvmPtrType(ty.getContext());
372+
auto llvmPtrTy =
373+
::getLlvmPtrType(ty.getContext(), ::getAddressSpace(rewriter));
354374
return rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv);
355375
}
356376

@@ -378,7 +398,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
378398
mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
379399
rewriter.setInsertionPointToStart(insertBlock);
380400
auto size = genI32Constant(loc, rewriter, 1);
381-
mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext());
401+
mlir::Type llvmPtrTy = ::getLlvmPtrType(llvmObjectTy.getContext(),
402+
::getAddressSpace(rewriter));
382403
auto al = rewriter.create<mlir::LLVM::AllocaOp>(
383404
loc, llvmPtrTy, llvmObjectTy, size, alignment);
384405
rewriter.restoreInsertionPoint(thisPt);
@@ -532,7 +553,8 @@ struct AllocaOpConversion : public FIROpConversion<fir::AllocaOp> {
532553
size = rewriter.create<mlir::LLVM::MulOp>(
533554
loc, ity, size, integerCast(loc, rewriter, ity, operands[i]));
534555
}
535-
mlir::Type llvmPtrTy = ::getLlvmPtrType(alloc.getContext());
556+
mlir::Type llvmPtrTy =
557+
::getLlvmPtrType(alloc.getContext(), ::getAddressSpace(rewriter));
536558
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
537559
// pointers! Only propagate pinned and bindc_name to help debugging, but
538560
// this should have no functional purpose (and passing the operand segment
@@ -1167,9 +1189,10 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
11671189
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
11681190
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
11691191
rewriter.getUnknownLoc(), "malloc",
1170-
mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()),
1171-
indexType,
1172-
/*isVarArg=*/false));
1192+
mlir::LLVM::LLVMFunctionType::get(
1193+
getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
1194+
indexType,
1195+
/*isVarArg=*/false));
11731196
}
11741197

11751198
/// Helper function for generating the LLVM IR that computes the distance
@@ -1189,7 +1212,8 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
11891212
// *)0 + 1)' trick for all types. The generated instructions are optimized
11901213
// into constant by the first pass of InstCombine, so it should not be a
11911214
// performance issue.
1192-
auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext());
1215+
auto llvmPtrTy = ::getLlvmPtrType(llvmObjectType.getContext(),
1216+
::getAddressSpace(rewriter));
11931217
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
11941218
auto gep = rewriter.create<mlir::LLVM::GEPOp>(
11951219
loc, llvmPtrTy, llvmObjectType, nullPtr,
@@ -1232,7 +1256,8 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
12321256
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
12331257
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
12341258
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1235-
heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs());
1259+
heap, ::getLlvmPtrType(heap.getContext(), ::getAddressSpace(rewriter)),
1260+
size, heap->getAttrs());
12361261
return mlir::success();
12371262
}
12381263

@@ -1258,9 +1283,10 @@ getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
12581283
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
12591284
return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
12601285
rewriter.getUnknownLoc(), "free",
1261-
mlir::LLVM::LLVMFunctionType::get(voidType,
1262-
getLlvmPtrType(op.getContext()),
1263-
/*isVarArg=*/false));
1286+
mlir::LLVM::LLVMFunctionType::get(
1287+
voidType,
1288+
getLlvmPtrType(op.getContext(), ::getAddressSpace(rewriter)),
1289+
/*isVarArg=*/false));
12641290
}
12651291

12661292
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
@@ -1386,7 +1412,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
13861412
return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
13871413
typeCodeVal};
13881414
if (fir::isa_ref_type(boxEleTy)) {
1389-
auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
1415+
auto ptrTy =
1416+
::getLlvmPtrType(rewriter.getContext(), ::getAddressSpace(rewriter));
13901417
return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
13911418
}
13921419
if (boxEleTy.isa<fir::RecordType>())
@@ -1447,7 +1474,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
14471474
fir::RecordType recType) const {
14481475
std::string name =
14491476
fir::NameUniquer::getTypeDescriptorName(recType.getName());
1450-
mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
1477+
mlir::Type llvmPtrTy =
1478+
::getLlvmPtrType(mod.getContext(), ::getAddressSpace(rewriter));
14511479
if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
14521480
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
14531481
global.getSymName());
@@ -1505,7 +1533,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
15051533
// Unlimited polymorphic type descriptor with no record type. Set
15061534
// type descriptor address to a clean state.
15071535
typeDesc = rewriter.create<mlir::LLVM::ZeroOp>(
1508-
loc, ::getLlvmPtrType(mod.getContext()));
1536+
loc, ::getLlvmPtrType(mod.getContext(),
1537+
::getAddressSpace(rewriter)));
15091538
}
15101539
} else {
15111540
typeDesc = getTypeDescriptor(mod, rewriter, loc,
@@ -1653,7 +1682,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
16531682
loc, outterOffsetTy, gepArgs[0].get<mlir::Value>(), cast);
16541683
}
16551684
}
1656-
mlir::Type llvmPtrTy = ::getLlvmPtrType(resultTy.getContext());
1685+
mlir::Type llvmPtrTy =
1686+
::getLlvmPtrType(resultTy.getContext(), ::getAddressSpace(rewriter));
16571687
return rewriter.create<mlir::LLVM::GEPOp>(
16581688
loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs);
16591689
}
@@ -2673,7 +2703,8 @@ struct CoordinateOpConversion
26732703
getBaseAddrFromBox(loc, boxTyPair, boxBaseAddr, rewriter);
26742704
// Component Type
26752705
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
2676-
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
2706+
mlir::Type llvmPtrTy =
2707+
::getLlvmPtrType(coor.getContext(), ::getAddressSpace(rewriter));
26772708
mlir::Type byteTy = ::getI8Type(coor.getContext());
26782709
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
26792710
mlir::LLVM::IntegerOverflowFlagsAttr::get(
@@ -2890,7 +2921,8 @@ struct TypeDescOpConversion : public FIROpConversion<fir::TypeDescOp> {
28902921
auto module = typeDescOp.getOperation()->getParentOfType<mlir::ModuleOp>();
28912922
std::string typeDescName =
28922923
fir::NameUniquer::getTypeDescriptorName(recordType.getName());
2893-
auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext());
2924+
auto llvmPtrTy =
2925+
::getLlvmPtrType(typeDescOp.getContext(), ::getAddressSpace(rewriter));
28942926
if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
28952927
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
28962928
typeDescOp, llvmPtrTy, global.getSymName());
@@ -3678,7 +3710,8 @@ struct BoxOffsetOpConversion : public FIROpConversion<fir::BoxOffsetOp> {
36783710
matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor,
36793711
mlir::ConversionPatternRewriter &rewriter) const override {
36803712

3681-
mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext());
3713+
mlir::Type pty =
3714+
::getLlvmPtrType(boxOffset.getContext(), ::getAddressSpace(rewriter));
36823715
mlir::Type boxType = fir::unwrapRefType(boxOffset.getBoxRef().getType());
36833716
mlir::Type llvmBoxTy =
36843717
lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(boxType));

flang/lib/Optimizer/CodeGen/DescriptorModel.h

+33-25
Original file line numberDiff line numberDiff line change
@@ -31,72 +31,80 @@
3131

3232
namespace fir {
3333

34-
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
34+
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *, unsigned);
3535

3636
/// Get the LLVM IR dialect model for building a particular C++ type, `T`.
3737
template <typename T>
3838
TypeBuilderFunc getModel();
3939

4040
template <>
4141
TypeBuilderFunc getModel<void *>() {
42-
return [](mlir::MLIRContext *context) -> mlir::Type {
43-
return mlir::LLVM::LLVMPointerType::get(context);
42+
return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
43+
return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
4444
};
4545
}
4646
template <>
4747
TypeBuilderFunc getModel<unsigned>() {
48-
return [](mlir::MLIRContext *context) -> mlir::Type {
49-
return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
50-
};
48+
return
49+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
50+
return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
51+
};
5152
}
5253
template <>
5354
TypeBuilderFunc getModel<int>() {
54-
return [](mlir::MLIRContext *context) -> mlir::Type {
55-
return mlir::IntegerType::get(context, sizeof(int) * 8);
56-
};
55+
return
56+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
57+
return mlir::IntegerType::get(context, sizeof(int) * 8);
58+
};
5759
}
5860
template <>
5961
TypeBuilderFunc getModel<unsigned long>() {
60-
return [](mlir::MLIRContext *context) -> mlir::Type {
61-
return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
62-
};
62+
return
63+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
64+
return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
65+
};
6366
}
6467
template <>
6568
TypeBuilderFunc getModel<unsigned long long>() {
66-
return [](mlir::MLIRContext *context) -> mlir::Type {
67-
return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
68-
};
69+
return
70+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
71+
return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
72+
};
6973
}
7074
template <>
7175
TypeBuilderFunc getModel<long long>() {
72-
return [](mlir::MLIRContext *context) -> mlir::Type {
73-
return mlir::IntegerType::get(context, sizeof(long long) * 8);
74-
};
76+
return
77+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
78+
return mlir::IntegerType::get(context, sizeof(long long) * 8);
79+
};
7580
}
7681
template <>
7782
TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
78-
return [](mlir::MLIRContext *context) -> mlir::Type {
83+
return [](mlir::MLIRContext *context,
84+
unsigned /*addressSpace*/) -> mlir::Type {
7985
return mlir::IntegerType::get(context,
8086
sizeof(Fortran::ISO::CFI_rank_t) * 8);
8187
};
8288
}
8389
template <>
8490
TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
85-
return [](mlir::MLIRContext *context) -> mlir::Type {
91+
return [](mlir::MLIRContext *context,
92+
unsigned /*addressSpace*/) -> mlir::Type {
8693
return mlir::IntegerType::get(context,
8794
sizeof(Fortran::ISO::CFI_type_t) * 8);
8895
};
8996
}
9097
template <>
9198
TypeBuilderFunc getModel<long>() {
92-
return [](mlir::MLIRContext *context) -> mlir::Type {
93-
return mlir::IntegerType::get(context, sizeof(long) * 8);
94-
};
99+
return
100+
[](mlir::MLIRContext *context, unsigned /*addressSpace*/) -> mlir::Type {
101+
return mlir::IntegerType::get(context, sizeof(long) * 8);
102+
};
95103
}
96104
template <>
97105
TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
98-
return [](mlir::MLIRContext *context) -> mlir::Type {
99-
auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
106+
return [](mlir::MLIRContext *context, unsigned addressSpace) -> mlir::Type {
107+
auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context, addressSpace);
100108
return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
101109
};
102110
}

flang/lib/Optimizer/CodeGen/TypeConverter.cpp

+19-16
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
7373
addConversion([&](fir::LenType field) {
7474
// Get size of len paramter from the descriptor.
7575
return getModel<Fortran::runtime::typeInfo::TypeParameterValue>()(
76-
&getContext());
76+
&getContext(), addressSpace);
7777
});
7878
addConversion([&](fir::LogicalType boolTy) {
7979
return mlir::IntegerType::get(
@@ -220,25 +220,25 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
220220
dataDescFields.push_back(eleTy);
221221
else
222222
dataDescFields.push_back(
223-
mlir::LLVM::LLVMPointerType::get(eleTy.getContext()));
223+
mlir::LLVM::LLVMPointerType::get(eleTy.getContext(), addressSpace));
224224
// elem_len
225225
dataDescFields.push_back(
226-
getDescFieldTypeModel<kElemLenPosInBox>()(&getContext()));
226+
getDescFieldTypeModel<kElemLenPosInBox>()(&getContext(), addressSpace));
227227
// version
228228
dataDescFields.push_back(
229-
getDescFieldTypeModel<kVersionPosInBox>()(&getContext()));
229+
getDescFieldTypeModel<kVersionPosInBox>()(&getContext(), addressSpace));
230230
// rank
231231
dataDescFields.push_back(
232-
getDescFieldTypeModel<kRankPosInBox>()(&getContext()));
232+
getDescFieldTypeModel<kRankPosInBox>()(&getContext(), addressSpace));
233233
// type
234234
dataDescFields.push_back(
235-
getDescFieldTypeModel<kTypePosInBox>()(&getContext()));
235+
getDescFieldTypeModel<kTypePosInBox>()(&getContext(), addressSpace));
236236
// attribute
237237
dataDescFields.push_back(
238-
getDescFieldTypeModel<kAttributePosInBox>()(&getContext()));
238+
getDescFieldTypeModel<kAttributePosInBox>()(&getContext(), addressSpace));
239239
// f18Addendum
240-
dataDescFields.push_back(
241-
getDescFieldTypeModel<kF18AddendumPosInBox>()(&getContext()));
240+
dataDescFields.push_back(getDescFieldTypeModel<kF18AddendumPosInBox>()(
241+
&getContext(), addressSpace));
242242
// [dims]
243243
if (rank == unknownRank()) {
244244
if (auto seqTy = ele.dyn_cast<SequenceType>())
@@ -247,15 +247,17 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
247247
rank = 0;
248248
}
249249
if (rank > 0) {
250-
auto rowTy = getDescFieldTypeModel<kDimsPosInBox>()(&getContext());
250+
auto rowTy =
251+
getDescFieldTypeModel<kDimsPosInBox>()(&getContext(), addressSpace);
251252
dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, rank));
252253
}
253254
// opt-type-ptr: i8* (see fir.tdesc)
254255
if (requiresExtendedDesc(ele) || fir::isUnlimitedPolymorphicType(box)) {
255256
dataDescFields.push_back(
256-
getExtendedDescFieldTypeModel<kOptTypePtrPosInBox>()(&getContext()));
257-
auto rowTy =
258-
getExtendedDescFieldTypeModel<kOptRowTypePosInBox>()(&getContext());
257+
getExtendedDescFieldTypeModel<kOptTypePtrPosInBox>()(&getContext(),
258+
addressSpace));
259+
auto rowTy = getExtendedDescFieldTypeModel<kOptRowTypePosInBox>()(
260+
&getContext(), addressSpace);
259261
dataDescFields.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, 1));
260262
if (auto recTy = fir::unwrapSequenceType(ele).dyn_cast<fir::RecordType>())
261263
if (recTy.getNumLenParams() > 0) {
@@ -278,13 +280,14 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
278280
mlir::Type LLVMTypeConverter::convertBoxType(BaseBoxType box, int rank) const {
279281
// TODO: send the box type and the converted LLVM structure layout
280282
// to tbaaBuilder for proper creation of TBAATypeDescriptorOp.
281-
return mlir::LLVM::LLVMPointerType::get(box.getContext());
283+
return mlir::LLVM::LLVMPointerType::get(box.getContext(), addressSpace);
282284
}
283285

284286
// fir.boxproc<any> --> llvm<"{ any*, i8* }">
285287
mlir::Type LLVMTypeConverter::convertBoxProcType(BoxProcType boxproc) const {
286288
auto funcTy = convertType(boxproc.getEleTy());
287-
auto voidPtrTy = mlir::LLVM::LLVMPointerType::get(boxproc.getContext());
289+
auto voidPtrTy =
290+
mlir::LLVM::LLVMPointerType::get(boxproc.getContext(), addressSpace);
288291
llvm::SmallVector<mlir::Type, 2> tuple = {funcTy, voidPtrTy};
289292
return mlir::LLVM::LLVMStructType::getLiteral(boxproc.getContext(), tuple,
290293
/*isPacked=*/false);
@@ -335,7 +338,7 @@ mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
335338
// the f18 object v. class distinction (F2003).
336339
mlir::Type
337340
LLVMTypeConverter::convertTypeDescType(mlir::MLIRContext *ctx) const {
338-
return mlir::LLVM::LLVMPointerType::get(ctx);
341+
return mlir::LLVM::LLVMPointerType::get(ctx, addressSpace);
339342
}
340343

341344
// Relay TBAA tag attachment to TBAABuilder.

0 commit comments

Comments
 (0)