@@ -134,16 +134,65 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
134
134
}
135
135
136
136
namespace {
137
+
138
+ // Creates an existing operation with an AddressOfOp or an AddrSpaceCastOp
139
+ // depending on the existing address spaces of the type.
140
+ mlir::Value createAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
141
+ mlir::Location loc, std::uint64_t globalAS,
142
+ std::uint64_t programAS,
143
+ llvm::StringRef symName, mlir::Type type) {
144
+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
145
+ if (globalAS != programAS) {
146
+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
147
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
148
+ return rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
149
+ loc, getLlvmPtrType (rewriter.getContext (), programAS), llvmAddrOp);
150
+ }
151
+ return rewriter.create <mlir::LLVM::AddressOfOp>(
152
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
153
+ }
154
+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, type, symName);
155
+ }
156
+
157
+ // Replaces an existing operation with an AddressOfOp or an AddrSpaceCastOp
158
+ // depending on the existing address spaces of the type.
159
+ mlir::Value replaceWithAddrOfOrASCast (mlir::ConversionPatternRewriter &rewriter,
160
+ mlir::Location loc,
161
+ std::uint64_t globalAS,
162
+ std::uint64_t programAS,
163
+ llvm::StringRef symName, mlir::Type type,
164
+ mlir::Operation *replaceOp) {
165
+ if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
166
+ if (globalAS != programAS) {
167
+ auto llvmAddrOp = rewriter.create <mlir::LLVM::AddressOfOp>(
168
+ loc, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
169
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddrSpaceCastOp>(
170
+ replaceOp, ::getLlvmPtrType (rewriter.getContext (), programAS),
171
+ llvmAddrOp);
172
+ }
173
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
174
+ replaceOp, getLlvmPtrType (rewriter.getContext (), globalAS), symName);
175
+ }
176
+ return rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(replaceOp, type,
177
+ symName);
178
+ }
179
+
137
180
// / Lower `fir.address_of` operation to `llvm.address_of` operation.
138
181
struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
139
182
using FIROpConversion::FIROpConversion;
140
183
141
184
llvm::LogicalResult
142
185
matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
143
186
mlir::ConversionPatternRewriter &rewriter) const override {
144
- auto ty = convertType (addr.getType ());
145
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
146
- addr, ty, addr.getSymbol ().getRootReference ().getValue ());
187
+ auto global = addr->getParentOfType <mlir::ModuleOp>()
188
+ .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
189
+ replaceWithAddrOfOrASCast (
190
+ rewriter, addr->getLoc (),
191
+ global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
192
+ getProgramAddressSpace (rewriter),
193
+ global ? global.getSymName ()
194
+ : addr.getSymbol ().getRootReference ().getValue (),
195
+ convertType (addr.getType ()), addr);
147
196
return mlir::success ();
148
197
}
149
198
};
@@ -1350,14 +1399,34 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
1350
1399
? fir::NameUniquer::getTypeDescriptorAssemblyName (recType.getName ())
1351
1400
: fir::NameUniquer::getTypeDescriptorName (recType.getName ());
1352
1401
mlir::Type llvmPtrTy = ::getLlvmPtrType (mod.getContext ());
1402
+
1403
+ // We allow the module to be set to a default layout if it's a regular module
1404
+ // however, we prevent this if it's a GPU module, as the datalayout in these
1405
+ // cases will currently be the union of the GPU Module and the parent builtin
1406
+ // module, with the GPU module overriding the parent where there are duplicates.
1407
+ // However, if we force the default layout onto a GPU module, with no datalayout
1408
+ // it'll result in issues as the infrastructure does not support the union of
1409
+ // two layouts with builtin data layout entries currently (and it doesn't look
1410
+ // like it was intended to).
1411
+ std::optional<mlir::DataLayout> dataLayout =
1412
+ fir::support::getOrSetMLIRDataLayout (
1413
+ mod, /* allowDefaultLayout*/ mlir::isa<mlir::gpu::GPUModuleOp>(mod)
1414
+ ? false
1415
+ : true );
1416
+ assert (dataLayout.has_value () && " Module missing DataLayout information" );
1417
+
1353
1418
if (auto global = mod.template lookupSymbol <fir::GlobalOp>(name)) {
1354
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1355
- global.getSymName ());
1419
+ return createAddrOfOrASCast (
1420
+ rewriter, loc, fir::factory::getGlobalAddressSpace (&*dataLayout),
1421
+ fir::factory::getProgramAddressSpace (&*dataLayout),
1422
+ global.getSymName (), llvmPtrTy);
1356
1423
}
1357
1424
if (auto global = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(name)) {
1358
1425
// The global may have already been translated to LLVM.
1359
- return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
1360
- global.getSymName ());
1426
+ return createAddrOfOrASCast (
1427
+ rewriter, loc, global.getAddrSpace (),
1428
+ fir::factory::getProgramAddressSpace (&*dataLayout),
1429
+ global.getSymName (), llvmPtrTy);
1361
1430
}
1362
1431
// Type info derived types do not have type descriptors since they are the
1363
1432
// types defining type descriptors.
@@ -2896,12 +2965,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
2896
2965
: fir::NameUniquer::getTypeDescriptorName (recordType.getName ());
2897
2966
auto llvmPtrTy = ::getLlvmPtrType (typeDescOp.getContext ());
2898
2967
if (auto global = module.lookupSymbol <mlir::LLVM::GlobalOp>(typeDescName)) {
2899
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2900
- typeDescOp, llvmPtrTy, global.getSymName ());
2968
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2969
+ global.getAddrSpace (),
2970
+ getProgramAddressSpace (rewriter),
2971
+ global.getSymName (), llvmPtrTy, typeDescOp);
2901
2972
return mlir::success ();
2902
2973
} else if (auto global = module.lookupSymbol <fir::GlobalOp>(typeDescName)) {
2903
- rewriter.replaceOpWithNewOp <mlir::LLVM::AddressOfOp>(
2904
- typeDescOp, llvmPtrTy, global.getSymName ());
2974
+ replaceWithAddrOfOrASCast (rewriter, typeDescOp->getLoc (),
2975
+ getGlobalAddressSpace (rewriter),
2976
+ getProgramAddressSpace (rewriter),
2977
+ global.getSymName (), llvmPtrTy, typeDescOp);
2905
2978
return mlir::success ();
2906
2979
}
2907
2980
return mlir::failure ();
@@ -2992,8 +3065,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
2992
3065
mlir::SymbolRefAttr comdat;
2993
3066
llvm::ArrayRef<mlir::NamedAttribute> attrs;
2994
3067
auto g = rewriter.create <mlir::LLVM::GlobalOp>(
2995
- loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 , 0 ,
2996
- false , false , comdat, attrs, dbgExprs);
3068
+ loc, tyAttr, isConst, linkage, global.getSymName (), initAttr, 0 ,
3069
+ getGlobalAddressSpace (rewriter), false , false , comdat, attrs, dbgExprs);
2997
3070
2998
3071
if (global.getAlignment () && *global.getAlignment () > 0 )
2999
3072
g.setAlignment (*global.getAlignment ());
0 commit comments