@@ -132,16 +132,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
132
132
}
133
133
134
134
namespace {
135
+
136
+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
137
+ mlir::Location loc,
138
+ std::uint64_t globalAS,
139
+ std::uint64_t programAS,
140
+ llvm::StringRef symName, mlir::Type type,
141
+ mlir::Operation *replaceOp = nullptr ) {
142
+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
143
+ if (globalAS != programAS) {
144
+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
145
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
146
+ if (replaceOp)
147
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
148
+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
149
+ llvmAddrOp);
150
+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
151
+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
152
+ }
153
+
154
+ if (replaceOp)
155
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
156
+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
157
+ return rewriter.create <mlir::LLVM::AddressOfOp>(
158
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
159
+ }
160
+
161
+ if (replaceOp)
162
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
163
+ symName);
164
+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
165
+ }
166
+
135
167
// / Lower `fir.address_of` operation to `llvm.address_of` operation.
136
168
struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
137
169
using FIROpConversion::FIROpConversion;
138
170
139
171
llvm::LogicalResult
140
172
matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
141
173
mlir::ConversionPatternRewriter &rewriter) const override {
142
- auto ty = convertType (addr.getType ());
143
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
144
- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
174
+ auto global = addr->getParentOfType <mlir::ModuleOp>()
175
+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
176
+ replaceWithAddrOfOrASCast (
177
+ rewriter, addr->getLoc (),
178
+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
179
+ getProgramAddressSpace (rewriter),
180
+ global ? global.getSymName ()
181
+ : addr.getSymbol ().getRootReference ().getValue (),
182
+ convertType (addr.getType ()), addr);
145
183
return mlir::success ();
146
184
}
147
185
};
@@ -1255,14 +1293,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1255
1293
? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
1256
1294
: fir::NameUniquer::getTypeDescriptorName (recType.getName ());
1257
1295
mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1296
+ mlir::DataLayout dataLayout (mod);
1258
1297
if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1259
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1260
- global.getSymName ());
1298
+ return replaceWithAddrOfOrASCast (
1299
+ rewriter, loc, fir::factory::getGlobalAddressSpace (&dataLayout),
1300
+ fir::factory::getProgramAddressSpace (&dataLayout),
1301
+ global.getSymName (), llvmPtrTy);
1261
1302
}
1262
1303
if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
1263
1304
// The global may have already been translated to LLVM.
1264
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1265
- global.getSymName ());
1305
+ return replaceWithAddrOfOrASCast (
1306
+ rewriter, loc, global.getAddrSpace (),
1307
+ fir::factory::getProgramAddressSpace (&dataLayout),
1308
+ global.getSymName (), llvmPtrTy);
1266
1309
}
1267
1310
// Type info derived types do not have type descriptors since they are the
1268
1311
// types defining type descriptors.
@@ -2759,12 +2802,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
2759
2802
: fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
2760
2803
auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
2761
2804
if (auto global = module.lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2762
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2763
- typeDescOp, llvmPtrTy, global.getSymName ());
2805
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2806
+ global.getAddrSpace (),
2807
+ getProgramAddressSpace (rewriter),
2808
+ global.getSymName (), llvmPtrTy, typeDescOp);
2764
2809
return mlir::success ();
2765
2810
} else if (auto global = module.lookupSymbol <fir::GlobalOp>(typeDescName)) {
2766
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2767
- typeDescOp, llvmPtrTy, global.getSymName ());
2811
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2812
+ getGlobalAddressSpace (rewriter),
2813
+ getProgramAddressSpace (rewriter),
2814
+ global.getSymName (), llvmPtrTy, typeDescOp);
2768
2815
return mlir::success ();
2769
2816
}
2770
2817
return mlir::failure ();
@@ -2855,8 +2902,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
2855
2902
mlir::SymbolRefAttr comdat;
2856
2903
llvm::ArrayRef<mlir::NamedAttribute> attrs;
2857
2904
auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2858
- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2859
- false , false , comdat, attrs, dbgExprs);
2905
+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
2906
+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
2860
2907
2861
2908
if (global.getAlignment () && *global.getAlignment () > 0 )
2862
2909
g.setAlignment (*global.getAlignment ());
0 commit comments