@@ -67,8 +67,25 @@ static constexpr unsigned defaultAlign = 8;
67
67
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
68
68
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
69
69
70
- static inline mlir::Type getLlvmPtrType (mlir::MLIRContext *context) {
71
- return mlir::LLVM::LLVMPointerType::get (context);
70
+ static inline unsigned getAddressSpace (mlir::ModuleOp module) {
71
+ if (mlir::Attribute addrSpace =
72
+ mlir::DataLayout (module).getAllocaMemorySpace ())
73
+ return addrSpace.cast <mlir::IntegerAttr>().getUInt ();
74
+
75
+ return 0u ;
76
+ }
77
+
78
+ static inline unsigned
79
+ getAddressSpace (mlir::ConversionPatternRewriter &rewriter) {
80
+ mlir::Operation *parentOp = rewriter.getInsertionBlock ()->getParentOp ();
81
+ return parentOp
82
+ ? ::getAddressSpace (parentOp->getParentOfType <mlir::ModuleOp>())
83
+ : 0u ;
84
+ }
85
+
86
+ static inline mlir::Type getLlvmPtrType (mlir::MLIRContext *context,
87
+ unsigned addressSpace) {
88
+ return mlir::LLVM::LLVMPointerType::get (context, addressSpace);
72
89
}
73
90
74
91
static inline mlir::Type getI8Type (mlir::MLIRContext *context) {
@@ -197,7 +214,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
197
214
mlir::ConversionPatternRewriter &rewriter,
198
215
int boxValue) const {
199
216
if (box.getType ().isa <mlir::LLVM::LLVMPointerType>()) {
200
- auto pty = ::getLlvmPtrType (resultTy.getContext ());
217
+ auto pty =
218
+ ::getLlvmPtrType (resultTy.getContext(), ::getAddressSpace(rewriter));
201
219
auto p = rewriter.create <mlir::LLVM::GEPOp>(
202
220
loc, pty, boxTy.llvm , box,
203
221
llvm::ArrayRef<mlir::LLVM::GEPArg>{0 , boxValue});
@@ -278,7 +296,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
278
296
mlir::Value
279
297
getBaseAddrFromBox (mlir::Location loc, TypePair boxTy, mlir::Value box,
280
298
mlir::ConversionPatternRewriter &rewriter) const {
281
- mlir::Type resultTy = ::getLlvmPtrType (boxTy.llvm .getContext ());
299
+ mlir::Type resultTy =
300
+ ::getLlvmPtrType (boxTy.llvm.getContext(), ::getAddressSpace(rewriter));
282
301
return getValueFromBox (loc, boxTy, box, resultTy, rewriter, kAddrPosInBox );
283
302
}
284
303
@@ -350,7 +369,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
350
369
mlir::ConversionPatternRewriter &rewriter,
351
370
mlir::Value base, ARGS... args) const {
352
371
llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
353
- auto llvmPtrTy = ::getLlvmPtrType (ty.getContext ());
372
+ auto llvmPtrTy =
373
+ ::getLlvmPtrType (ty.getContext(), ::getAddressSpace(rewriter));
354
374
return rewriter.create <mlir::LLVM::GEPOp>(loc, llvmPtrTy, ty, base, cv);
355
375
}
356
376
@@ -378,7 +398,8 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
378
398
mlir::Block *insertBlock = getBlockForAllocaInsert (parentOp);
379
399
rewriter.setInsertionPointToStart (insertBlock);
380
400
auto size = genI32Constant (loc, rewriter, 1 );
381
- mlir::Type llvmPtrTy = ::getLlvmPtrType (llvmObjectTy.getContext ());
401
+ mlir::Type llvmPtrTy = ::getLlvmPtrType (llvmObjectTy.getContext (),
402
+ ::getAddressSpace (rewriter));
382
403
auto al = rewriter.create <mlir::LLVM::AllocaOp>(
383
404
loc, llvmPtrTy, llvmObjectTy, size, alignment);
384
405
rewriter.restoreInsertionPoint (thisPt);
@@ -532,7 +553,8 @@ struct AllocaOpConversion : public FIROpConversion<fir::AllocaOp> {
532
553
size = rewriter.create <mlir::LLVM::MulOp>(
533
554
loc, ity, size, integerCast (loc, rewriter, ity, operands[i]));
534
555
}
535
- mlir::Type llvmPtrTy = ::getLlvmPtrType (alloc.getContext ());
556
+ mlir::Type llvmPtrTy =
557
+ ::getLlvmPtrType (alloc.getContext(), ::getAddressSpace(rewriter));
536
558
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
537
559
// pointers! Only propagate pinned and bindc_name to help debugging, but
538
560
// this should have no functional purpose (and passing the operand segment
@@ -1167,9 +1189,10 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1167
1189
auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
1168
1190
return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1169
1191
rewriter.getUnknownLoc (), " malloc" ,
1170
- mlir::LLVM::LLVMFunctionType::get (getLlvmPtrType (op.getContext ()),
1171
- indexType,
1172
- /* isVarArg=*/ false ));
1192
+ mlir::LLVM::LLVMFunctionType::get (
1193
+ getLlvmPtrType (op.getContext (), ::getAddressSpace (rewriter)),
1194
+ indexType,
1195
+ /* isVarArg=*/ false ));
1173
1196
}
1174
1197
1175
1198
// / Helper function for generating the LLVM IR that computes the distance
@@ -1189,7 +1212,8 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
1189
1212
// *)0 + 1)' trick for all types. The generated instructions are optimized
1190
1213
// into constant by the first pass of InstCombine, so it should not be a
1191
1214
// performance issue.
1192
- auto llvmPtrTy = ::getLlvmPtrType (llvmObjectType.getContext ());
1215
+ auto llvmPtrTy = ::getLlvmPtrType (llvmObjectType.getContext (),
1216
+ ::getAddressSpace (rewriter));
1193
1217
auto nullPtr = rewriter.create <mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
1194
1218
auto gep = rewriter.create <mlir::LLVM::GEPOp>(
1195
1219
loc, llvmPtrTy, llvmObjectType, nullPtr,
@@ -1232,7 +1256,8 @@ struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1232
1256
loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1233
1257
heap->setAttr (" callee" , mlir::SymbolRefAttr::get (mallocFunc));
1234
1258
rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1235
- heap, ::getLlvmPtrType (heap.getContext ()), size, heap->getAttrs ());
1259
+ heap, ::getLlvmPtrType (heap.getContext (), ::getAddressSpace (rewriter)),
1260
+ size, heap->getAttrs ());
1236
1261
return mlir::success ();
1237
1262
}
1238
1263
@@ -1258,9 +1283,10 @@ getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1258
1283
auto voidType = mlir::LLVM::LLVMVoidType::get (op.getContext ());
1259
1284
return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1260
1285
rewriter.getUnknownLoc (), " free" ,
1261
- mlir::LLVM::LLVMFunctionType::get (voidType,
1262
- getLlvmPtrType (op.getContext ()),
1263
- /* isVarArg=*/ false ));
1286
+ mlir::LLVM::LLVMFunctionType::get (
1287
+ voidType,
1288
+ getLlvmPtrType (op.getContext (), ::getAddressSpace (rewriter)),
1289
+ /* isVarArg=*/ false ));
1264
1290
}
1265
1291
1266
1292
static unsigned getDimension (mlir::LLVM::LLVMArrayType ty) {
@@ -1386,7 +1412,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1386
1412
return {getCharacterByteSize (loc, rewriter, charTy, lenParams),
1387
1413
typeCodeVal};
1388
1414
if (fir::isa_ref_type (boxEleTy)) {
1389
- auto ptrTy = ::getLlvmPtrType (rewriter.getContext ());
1415
+ auto ptrTy =
1416
+ ::getLlvmPtrType (rewriter.getContext(), ::getAddressSpace(rewriter));
1390
1417
return {genTypeStrideInBytes (loc, i64Ty, rewriter, ptrTy), typeCodeVal};
1391
1418
}
1392
1419
if (boxEleTy.isa <fir::RecordType>())
@@ -1447,7 +1474,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1447
1474
fir::RecordType recType) const {
1448
1475
std::string name =
1449
1476
fir::NameUniquer::getTypeDescriptorName (recType.getName ());
1450
- mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1477
+ mlir::Type llvmPtrTy =
1478
+ ::getLlvmPtrType (mod.getContext(), ::getAddressSpace(rewriter));
1451
1479
if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1452
1480
return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1453
1481
global.getSymName ());
@@ -1505,7 +1533,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1505
1533
// Unlimited polymorphic type descriptor with no record type. Set
1506
1534
// type descriptor address to a clean state.
1507
1535
typeDesc = rewriter.create <mlir::LLVM::ZeroOp>(
1508
- loc, ::getLlvmPtrType (mod.getContext ()));
1536
+ loc, ::getLlvmPtrType (mod.getContext (),
1537
+ ::getAddressSpace (rewriter)));
1509
1538
}
1510
1539
} else {
1511
1540
typeDesc = getTypeDescriptor (mod, rewriter, loc,
@@ -1653,7 +1682,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
1653
1682
loc, outterOffsetTy, gepArgs[0 ].get <mlir::Value>(), cast);
1654
1683
}
1655
1684
}
1656
- mlir::Type llvmPtrTy = ::getLlvmPtrType (resultTy.getContext ());
1685
+ mlir::Type llvmPtrTy =
1686
+ ::getLlvmPtrType (resultTy.getContext(), ::getAddressSpace(rewriter));
1657
1687
return rewriter.create <mlir::LLVM::GEPOp>(
1658
1688
loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs);
1659
1689
}
@@ -2673,7 +2703,8 @@ struct CoordinateOpConversion
2673
2703
getBaseAddrFromBox (loc, boxTyPair, boxBaseAddr, rewriter);
2674
2704
// Component Type
2675
2705
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy (boxObjTy);
2676
- mlir::Type llvmPtrTy = ::getLlvmPtrType (coor.getContext ());
2706
+ mlir::Type llvmPtrTy =
2707
+ ::getLlvmPtrType (coor.getContext(), ::getAddressSpace(rewriter));
2677
2708
mlir::Type byteTy = ::getI8Type (coor.getContext ());
2678
2709
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
2679
2710
mlir::LLVM::IntegerOverflowFlagsAttr::get (
@@ -2890,7 +2921,8 @@ struct TypeDescOpConversion : public FIROpConversion<fir::TypeDescOp> {
2890
2921
auto module = typeDescOp.getOperation ()->getParentOfType <mlir::ModuleOp>();
2891
2922
std::string typeDescName =
2892
2923
fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
2893
- auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
2924
+ auto llvmPtrTy =
2925
+ ::getLlvmPtrType (typeDescOp.getContext(), ::getAddressSpace(rewriter));
2894
2926
if (auto global = module.lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2895
2927
rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2896
2928
typeDescOp, llvmPtrTy, global.getSymName ());
@@ -3678,7 +3710,8 @@ struct BoxOffsetOpConversion : public FIROpConversion<fir::BoxOffsetOp> {
3678
3710
matchAndRewrite (fir::BoxOffsetOp boxOffset, OpAdaptor adaptor,
3679
3711
mlir::ConversionPatternRewriter &rewriter) const override {
3680
3712
3681
- mlir::Type pty = ::getLlvmPtrType (boxOffset.getContext ());
3713
+ mlir::Type pty =
3714
+ ::getLlvmPtrType (boxOffset.getContext(), ::getAddressSpace(rewriter));
3682
3715
mlir::Type boxType = fir::unwrapRefType (boxOffset.getBoxRef ().getType ());
3683
3716
mlir::Type llvmBoxTy =
3684
3717
lowerTy ().convertBoxTypeAsStruct (mlir::cast<fir::BaseBoxType>(boxType));
0 commit comments