@@ -134,16 +134,65 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
134
134
}
135
135
136
136
namespace {
137
+
138
+ // Creates an existing operation with an AddressOfOp or an AddrSpaceCastOp
139
+ // depending on the existing address spaces of the type.
140
+ mlir::Value createAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
141
+ mlir::Location loc, std::uint64_t globalAS,
142
+ std::uint64_t programAS,
143
+ llvm::StringRef symName, mlir::Type type) {
144
+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
145
+ if (globalAS != programAS) {
146
+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
147
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
148
+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
149
+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
150
+ }
151
+ return rewriter.create <mlir::LLVM::AddressOfOp>(
152
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
153
+ }
154
+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
155
+ }
156
+
157
+ // Replaces an existing operation with an AddressOfOp or an AddrSpaceCastOp
158
+ // depending on the existing address spaces of the type.
159
+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
160
+ mlir::Location loc,
161
+ std::uint64_t globalAS,
162
+ std::uint64_t programAS,
163
+ llvm::StringRef symName, mlir::Type type,
164
+ mlir::Operation *replaceOp) {
165
+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
166
+ if (globalAS != programAS) {
167
+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
168
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
169
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
170
+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
171
+ llvmAddrOp);
172
+ }
173
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
174
+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
175
+ }
176
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
177
+ symName);
178
+ }
179
+
137
180
// / Lower `fir.address_of` operation to `llvm.address_of` operation.
138
181
struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
139
182
using FIROpConversion::FIROpConversion;
140
183
141
184
llvm::LogicalResult
142
185
matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
143
186
mlir::ConversionPatternRewriter &rewriter) const override {
144
- auto ty = convertType (addr.getType ());
145
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
146
- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
187
+ auto global = addr->getParentOfType <mlir::ModuleOp>()
188
+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
189
+ replaceWithAddrOfOrASCast (
190
+ rewriter, addr->getLoc (),
191
+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
192
+ getProgramAddressSpace (rewriter),
193
+ global ? global.getSymName ()
194
+ : addr.getSymbol ().getRootReference ().getValue (),
195
+ convertType (addr.getType ()), addr);
147
196
return mlir::success ();
148
197
}
149
198
};
@@ -1350,14 +1399,26 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1350
1399
? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
1351
1400
: fir::NameUniquer::getTypeDescriptorName (recType.getName ());
1352
1401
mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1402
+
1403
+ // As we set allowDefaultLayout to true, there should be no chance the
1404
+ // optional returns null even if the module has no layout information,
1405
+ // however, assert just incase.
1406
+ std::optional<mlir::DataLayout> dataLayout =
1407
+ fir::support::getOrSetDataLayout (mod, /* allowDefaultLayout=*/ true );
1408
+ assert (!dataLayout.has_value ());
1409
+
1353
1410
if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1354
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1355
- global.getSymName ());
1411
+ return createAddrOfOrASCast (
1412
+ rewriter, loc, fir::factory::getGlobalAddressSpace (&*dataLayout),
1413
+ fir::factory::getProgramAddressSpace (&*dataLayout),
1414
+ global.getSymName (), llvmPtrTy);
1356
1415
}
1357
1416
if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
1358
1417
// The global may have already been translated to LLVM.
1359
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1360
- global.getSymName ());
1418
+ return createAddrOfOrASCast (
1419
+ rewriter, loc, global.getAddrSpace (),
1420
+ fir::factory::getProgramAddressSpace (&*dataLayout),
1421
+ global.getSymName (), llvmPtrTy);
1361
1422
}
1362
1423
// Type info derived types do not have type descriptors since they are the
1363
1424
// types defining type descriptors.
@@ -2896,12 +2957,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
2896
2957
: fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
2897
2958
auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
2898
2959
if (auto global = module.lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2899
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2900
- typeDescOp, llvmPtrTy, global.getSymName ());
2960
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2961
+ global.getAddrSpace (),
2962
+ getProgramAddressSpace (rewriter),
2963
+ global.getSymName (), llvmPtrTy, typeDescOp);
2901
2964
return mlir::success ();
2902
2965
} else if (auto global = module.lookupSymbol <fir::GlobalOp>(typeDescName)) {
2903
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2904
- typeDescOp, llvmPtrTy, global.getSymName ());
2966
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2967
+ getGlobalAddressSpace (rewriter),
2968
+ getProgramAddressSpace (rewriter),
2969
+ global.getSymName (), llvmPtrTy, typeDescOp);
2905
2970
return mlir::success ();
2906
2971
}
2907
2972
return mlir::failure ();
@@ -2992,8 +3057,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
2992
3057
mlir::SymbolRefAttr comdat;
2993
3058
llvm::ArrayRef<mlir::NamedAttribute> attrs;
2994
3059
auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2995
- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2996
- false , false , comdat, attrs, dbgExprs);
3060
+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
3061
+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
2997
3062
2998
3063
if (global.getAlignment () && *global.getAlignment () > 0 )
2999
3064
g.setAlignment (*global.getAlignment ());
0 commit comments