Skip to content

Commit 181209a

Browse files
committed
[flang] Implement !DIR$ NOINLINE, INLINE and FORCEINLINE directives
1 parent aeb06c6 commit 181209a

File tree

17 files changed

+375
-12
lines changed

17 files changed

+375
-12
lines changed

flang/docs/Directives.md

+8
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ A list of non-standard directives supported by Flang
5353
* `!dir$ novector` disabling vectorization on the following loop.
5454
* `!dir$ nounroll` disabling unrolling on the following loop.
5555
* `!dir$ nounroll_and_jam` disabling unrolling and jamming on the following loop.
56+
* `!dir$ inline` tells the compiler to attempt to inline routines if
57+
this directive is specified before a call statement or for all call function statements
58+
within a DO LOOP. This directive can be improved later to support other place(s) for
59+
inlining function calls.
60+
* `!dir$ forceinline` works in the same way as the `inline` directive, but it force
61+
inlining by the compiler on a function call statement.
62+
* `!dir$ noinline` works in the same way as the `inline` directive, but prevents
63+
any attempt of inlining by the compiler on a function call statement.
5664

5765
# Directive Details
5866

flang/include/flang/Evaluate/call.h

+10
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ class ProcedureRef {
254254
bool IsElemental() const { return proc_.IsElemental(); }
255255
bool hasAlternateReturns() const { return hasAlternateReturns_; }
256256

257+
bool hasNoInline() const { return noInline_; }
258+
void set_noInline(bool ni) { noInline_ = ni; }
259+
bool hasAlwaysInline() const { return alwaysInline_; }
260+
void set_alwaysInline(bool ai) { alwaysInline_ = ai; }
261+
bool hasInlineHint() const { return inlineHint_; }
262+
void set_inlineHint(bool ih) { inlineHint_ = ih; }
263+
257264
Expr<SomeType> *UnwrapArgExpr(int n) {
258265
if (static_cast<std::size_t>(n) < arguments_.size() && arguments_[n]) {
259266
return arguments_[n]->UnwrapExpr();
@@ -277,6 +284,9 @@ class ProcedureRef {
277284
ActualArguments arguments_;
278285
Chevrons chevrons_;
279286
bool hasAlternateReturns_;
287+
bool noInline_{false};
288+
bool alwaysInline_{false};
289+
bool inlineHint_{false};
280290
};
281291

282292
template <typename A> class FunctionRef : public ProcedureRef {

flang/include/flang/Optimizer/Dialect/FIRAttr.td

+20
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,24 @@ def fir_OpenMPSafeTempArrayCopyAttr : fir_Attr<"OpenMPSafeTempArrayCopy"> {
200200
}];
201201
}
202202

203+
/// Fortran inline attribute
204+
def FIRinlineNone : I32BitEnumAttrCaseNone<"none">;
205+
def FIRinlineNo : I32BitEnumAttrCaseBit<"no_inline", 0>;
206+
def FIRinlineAlways : I32BitEnumAttrCaseBit<"always_inline", 1>;
207+
def FIRinlineHint : I32BitEnumAttrCaseBit<"inline_hint", 2>;
208+
209+
def fir_FortranInlineEnum
210+
: I32BitEnumAttr<"FortranInlineEnum", "Fortran inline attributes",
211+
[FIRinlineNone, FIRinlineNo, FIRinlineAlways,
212+
FIRinlineHint]> {
213+
let separator = ", ";
214+
let cppNamespace = "::fir";
215+
let genSpecializedAttr = 0;
216+
let printBitEnumPrimaryGroups = 1;
217+
}
218+
219+
def fir_FortranInlineAttr
220+
: EnumAttr<FIROpsDialect, fir_FortranInlineEnum, "inline_attrs"> {
221+
let assemblyFormat = "`<` $value `>`";
222+
}
203223
#endif // FIR_DIALECT_FIR_ATTRS

flang/include/flang/Optimizer/Dialect/FIROps.td

+1
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,7 @@ def fir_CallOp : fir_Op<"call",
24942494
OptionalAttr<DictArrayAttr>:$arg_attrs,
24952495
OptionalAttr<DictArrayAttr>:$res_attrs,
24962496
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
2497+
OptionalAttr<fir_FortranInlineAttr>:$inline_attr,
24972498
DefaultValuedAttr<Arith_FastMathAttr,
24982499
"::mlir::arith::FastMathFlags::none">:$fastmath
24992500
);

flang/include/flang/Parser/dump-parse-tree.h

+3
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,11 @@ class ParseTreeDumper {
204204
NODE(parser, CompilerDirective)
205205
NODE(CompilerDirective, AssumeAligned)
206206
NODE(CompilerDirective, IgnoreTKR)
207+
NODE(CompilerDirective, Inline)
208+
NODE(CompilerDirective, ForceInline)
207209
NODE(CompilerDirective, LoopCount)
208210
NODE(CompilerDirective, NameValue)
211+
NODE(CompilerDirective, NoInline)
209212
NODE(CompilerDirective, Unrecognized)
210213
NODE(CompilerDirective, VectorAlways)
211214
NODE(CompilerDirective, Unroll)

flang/include/flang/Parser/parse-tree.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -3354,6 +3354,9 @@ struct StmtFunctionStmt {
33543354
// !DIR$ NOVECTOR
33553355
// !DIR$ NOUNROLL
33563356
// !DIR$ NOUNROLL_AND_JAM
3357+
// !DIR$ FORCEINLINE
3358+
// !DIR$ INLINE
3359+
// !DIR$ NOINLINE
33573360
// !DIR$ <anything else>
33583361
struct CompilerDirective {
33593362
UNION_CLASS_BOILERPLATE(CompilerDirective);
@@ -3382,11 +3385,14 @@ struct CompilerDirective {
33823385
EMPTY_CLASS(NoVector);
33833386
EMPTY_CLASS(NoUnroll);
33843387
EMPTY_CLASS(NoUnrollAndJam);
3388+
EMPTY_CLASS(ForceInline);
3389+
EMPTY_CLASS(Inline);
3390+
EMPTY_CLASS(NoInline);
33853391
EMPTY_CLASS(Unrecognized);
33863392
CharBlock source;
33873393
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
33883394
VectorAlways, std::list<NameValue>, Unroll, UnrollAndJam, Unrecognized,
3389-
NoVector, NoUnroll, NoUnrollAndJam>
3395+
NoVector, NoUnroll, NoUnrollAndJam, ForceInline, Inline, NoInline>
33903396
u;
33913397
};
33923398

flang/lib/Lower/Bridge.cpp

+122-7
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
18281828
setCurrentPosition(stmt.source);
18291829
assert(stmt.typedCall && "Call was not analyzed");
18301830
mlir::Value res{};
1831+
1832+
// Set 'no_inline', 'inline_hint' or 'always_inline' to true on the
1833+
// ProcedureRef. The NoInline and AlwaysInline attribute will be set in
1834+
// genProcedureRef later.
1835+
for (const auto *dir : eval.dirs) {
1836+
Fortran::common::visit(
1837+
Fortran::common::visitors{
1838+
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
1839+
stmt.typedCall->set_alwaysInline(true);
1840+
},
1841+
[&](const Fortran::parser::CompilerDirective::Inline &) {
1842+
stmt.typedCall->set_inlineHint(true);
1843+
},
1844+
[&](const Fortran::parser::CompilerDirective::NoInline &) {
1845+
stmt.typedCall->set_noInline(true);
1846+
},
1847+
[&](const auto &) {}},
1848+
dir->u);
1849+
}
1850+
18311851
if (lowerToHighLevelFIR()) {
18321852
std::optional<mlir::Type> resultType;
18331853
if (stmt.typedCall->hasAlternateReturns())
@@ -2053,6 +2073,50 @@ class FirConverter : public Fortran::lower::AbstractConverter {
20532073
// so no clean-up needs to be generated for these entities.
20542074
}
20552075

2076+
void attachInlineAttributes(
2077+
mlir::Operation &op,
2078+
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs) {
2079+
if (dirs.empty())
2080+
return;
2081+
2082+
for (mlir::Value operand : op.getOperands()) {
2083+
if (operand.getDefiningOp())
2084+
attachInlineAttributes(*operand.getDefiningOp(), dirs);
2085+
}
2086+
2087+
if (fir::CallOp callOp = mlir::dyn_cast<fir::CallOp>(op)) {
2088+
for (const auto *dir : dirs) {
2089+
Fortran::common::visit(
2090+
Fortran::common::visitors{
2091+
[&](const Fortran::parser::CompilerDirective::NoInline &) {
2092+
callOp.setInlineAttr(fir::FortranInlineEnum::no_inline);
2093+
},
2094+
[&](const Fortran::parser::CompilerDirective::Inline &) {
2095+
callOp.setInlineAttr(fir::FortranInlineEnum::inline_hint);
2096+
},
2097+
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
2098+
callOp.setInlineAttr(fir::FortranInlineEnum::always_inline);
2099+
},
2100+
[&](const auto &) {}},
2101+
dir->u);
2102+
}
2103+
}
2104+
}
2105+
2106+
void attachAttributesToDoLoopOperations(
2107+
fir::DoLoopOp &doLoop,
2108+
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
2109+
if (!doLoop.getOperation() || dirs.empty())
2110+
return;
2111+
2112+
for (mlir::Block &block : doLoop.getRegion()) {
2113+
for (mlir::Operation &op : block.getOperations()) {
2114+
if (!dirs.empty())
2115+
attachInlineAttributes(op, dirs);
2116+
}
2117+
}
2118+
}
2119+
20562120
/// Generate FIR for a DO construct. There are six variants:
20572121
/// - unstructured infinite and while loops
20582122
/// - structured and unstructured increment loops
@@ -2162,6 +2226,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
21622226

21632227
// This call may generate a branch in some contexts.
21642228
genFIR(endDoEval, unstructuredContext);
2229+
2230+
// Add attribute(s) on operations in fir::DoLoopOp if necessary
2231+
for (IncrementLoopInfo &info : incrementLoopNestInfo)
2232+
attachAttributesToDoLoopOperations(info.doLoop, doStmtEval.dirs);
21652233
}
21662234

21672235
/// Generate FIR to evaluate loop control values (lower, upper and step).
@@ -2935,6 +3003,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
29353003
e->dirs.push_back(&dir);
29363004
}
29373005

3006+
void
3007+
attachInliningDirectiveToStmt(const Fortran::parser::CompilerDirective &dir,
3008+
Fortran::lower::pft::Evaluation *e) {
3009+
while (e->isDirective())
3010+
e = e->lexicalSuccessor;
3011+
3012+
// If the successor is a statement or a do loop, the compiler
3013+
// will perform inlining.
3014+
if (e->isA<Fortran::parser::CallStmt>() ||
3015+
e->isA<Fortran::parser::NonLabelDoStmt>() ||
3016+
e->isA<Fortran::parser::AssignmentStmt>()) {
3017+
e->dirs.push_back(&dir);
3018+
} else {
3019+
mlir::Location loc = toLocation();
3020+
mlir::emitWarning(loc,
3021+
"Inlining directive not in front of loops, function"
3022+
"call or assignment.\n");
3023+
}
3024+
}
3025+
29383026
void genFIR(const Fortran::parser::CompilerDirective &dir) {
29393027
Fortran::lower::pft::Evaluation &eval = getEval();
29403028

@@ -2958,6 +3046,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
29583046
[&](const Fortran::parser::CompilerDirective::NoUnrollAndJam &) {
29593047
attachDirectiveToLoop(dir, &eval);
29603048
},
3049+
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
3050+
attachInliningDirectiveToStmt(dir, &eval);
3051+
},
3052+
[&](const Fortran::parser::CompilerDirective::Inline &) {
3053+
attachInliningDirectiveToStmt(dir, &eval);
3054+
},
3055+
[&](const Fortran::parser::CompilerDirective::NoInline &) {
3056+
attachInliningDirectiveToStmt(dir, &eval);
3057+
},
29613058
[&](const auto &) {}},
29623059
dir.u);
29633060
}
@@ -4761,7 +4858,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
47614858

47624859
void genDataAssignment(
47634860
const Fortran::evaluate::Assignment &assign,
4764-
const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
4861+
const Fortran::evaluate::ProcedureRef *userDefinedAssignment,
4862+
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs =
4863+
{}) {
47654864
mlir::Location loc = getCurrentLocation();
47664865
fir::FirOpBuilder &builder = getFirOpBuilder();
47674866

@@ -4834,12 +4933,22 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48344933
Fortran::lower::StatementContext localStmtCtx;
48354934
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
48364935
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
4837-
if (isCUDATransfer && !hasCUDAImplicitTransfer)
4936+
if (isCUDATransfer && !hasCUDAImplicitTransfer) {
48384937
genCUDADataTransfer(builder, loc, assign, lhs, rhs);
4839-
else
4938+
} else {
4939+
// If RHS or LHS have a CallOp in their expression, this operation will
4940+
// have the 'no_inline' or 'always_inline' attribute if there is a
4941+
// directive just before the assignement.
4942+
if (!dirs.empty()) {
4943+
if (rhs.getDefiningOp())
4944+
attachInlineAttributes(*rhs.getDefiningOp(), dirs);
4945+
if (lhs.getDefiningOp())
4946+
attachInlineAttributes(*lhs.getDefiningOp(), dirs);
4947+
}
48404948
builder.create<hlfir::AssignOp>(loc, rhs, lhs,
48414949
isWholeAllocatableAssignment,
48424950
keepLhsLengthInAllocatableAssignment);
4951+
}
48434952
if (hasCUDAImplicitTransfer && !isInDeviceContext) {
48444953
localSymbols.popScope();
48454954
for (mlir::Value temp : implicitTemps)
@@ -4907,16 +5016,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
49075016
}
49085017

49095018
/// Shared for both assignments and pointer assignments.
4910-
void genAssignment(const Fortran::evaluate::Assignment &assign) {
5019+
void
5020+
genAssignment(const Fortran::evaluate::Assignment &assign,
5021+
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *>
5022+
&dirs = {}) {
49115023
mlir::Location loc = toLocation();
49125024
if (lowerToHighLevelFIR()) {
49135025
Fortran::common::visit(
49145026
Fortran::common::visitors{
49155027
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
4916-
genDataAssignment(assign, /*userDefinedAssignment=*/nullptr);
5028+
genDataAssignment(assign, /*userDefinedAssignment=*/nullptr,
5029+
dirs);
49175030
},
49185031
[&](const Fortran::evaluate::ProcedureRef &procRef) {
4919-
genDataAssignment(assign, /*userDefinedAssignment=*/&procRef);
5032+
genDataAssignment(assign, /*userDefinedAssignment=*/&procRef,
5033+
dirs);
49205034
},
49215035
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
49225036
if (isInsideHlfirForallOrWhere())
@@ -5321,7 +5435,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
53215435
}
53225436

53235437
void genFIR(const Fortran::parser::AssignmentStmt &stmt) {
5324-
genAssignment(*stmt.typedAssignment->v);
5438+
Fortran::lower::pft::Evaluation &eval = getEval();
5439+
genAssignment(*stmt.typedAssignment->v, eval.dirs);
53255440
}
53265441

53275442
void genFIR(const Fortran::parser::SyncAllStmt &stmt) {

flang/lib/Lower/ConvertCall.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,20 @@ Fortran::lower::genCallOpAndResult(
645645
callResult = dispatch.getResult(0);
646646
} else {
647647
// Standard procedure call with fir.call.
648+
fir::FortranInlineEnumAttr inlineAttr;
649+
650+
if (caller.getCallDescription().hasNoInline())
651+
inlineAttr = fir::FortranInlineEnumAttr::get(
652+
builder.getContext(), fir::FortranInlineEnum::no_inline);
653+
else if (caller.getCallDescription().hasInlineHint())
654+
inlineAttr = fir::FortranInlineEnumAttr::get(
655+
builder.getContext(), fir::FortranInlineEnum::inline_hint);
656+
else if (caller.getCallDescription().hasAlwaysInline())
657+
inlineAttr = fir::FortranInlineEnumAttr::get(
658+
builder.getContext(), fir::FortranInlineEnum::always_inline);
648659
auto call = builder.create<fir::CallOp>(
649660
loc, funcType.getResults(), funcSymbolAttr, operands,
650-
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
661+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr);
651662

652663
callNumResults = call.getNumResults();
653664
if (callNumResults != 0)

flang/lib/Optimizer/CodeGen/CodeGen.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
634634
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
635635
llvmCall.setResAttrsAttr(resAttrs);
636636

637+
if (auto inlineAttr = call.getInlineAttrAttr()) {
638+
llvmCall->removeAttr("inline_attr");
639+
if (inlineAttr.getValue() == fir::FortranInlineEnum::no_inline) {
640+
llvmCall.setNoInlineAttr(rewriter.getUnitAttr());
641+
} else if (inlineAttr.getValue() == fir::FortranInlineEnum::inline_hint) {
642+
llvmCall.setInlineHintAttr(rewriter.getUnitAttr());
643+
} else if (inlineAttr.getValue() ==
644+
fir::FortranInlineEnum::always_inline) {
645+
llvmCall.setAlwaysInlineAttr(rewriter.getUnitAttr());
646+
}
647+
}
648+
637649
if (memAttr)
638650
llvmCall.setMemoryEffectsAttr(
639651
mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr));

flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
207207
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
208208
rewriter.replaceOpWithNewOp<fir::CallOp>(
209209
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
210-
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
210+
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr(),
211+
/*inline_attr*/ fir::FortranInlineEnumAttr{});
211212
return mlir::success();
212213
}
213214

flang/lib/Parser/Fortran-parsers.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,11 @@ constexpr auto novector{"NOVECTOR" >> construct<CompilerDirective::NoVector>()};
13141314
constexpr auto nounroll{"NOUNROLL" >> construct<CompilerDirective::NoUnroll>()};
13151315
constexpr auto nounrollAndJam{
13161316
"NOUNROLL_AND_JAM" >> construct<CompilerDirective::NoUnrollAndJam>()};
1317+
constexpr auto forceinlineDir{
1318+
"FORCEINLINE" >> construct<CompilerDirective::ForceInline>()};
1319+
constexpr auto noinlineDir{
1320+
"NOINLINE" >> construct<CompilerDirective::NoInline>()};
1321+
constexpr auto inlineDir{"INLINE" >> construct<CompilerDirective::Inline>()};
13171322
TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
13181323
sourced((construct<CompilerDirective>(ignore_tkr) ||
13191324
construct<CompilerDirective>(loopCount) ||
@@ -1324,6 +1329,9 @@ TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
13241329
construct<CompilerDirective>(novector) ||
13251330
construct<CompilerDirective>(nounrollAndJam) ||
13261331
construct<CompilerDirective>(nounroll) ||
1332+
construct<CompilerDirective>(noinlineDir) ||
1333+
construct<CompilerDirective>(forceinlineDir) ||
1334+
construct<CompilerDirective>(inlineDir) ||
13271335
construct<CompilerDirective>(
13281336
many(construct<CompilerDirective::NameValue>(
13291337
name, maybe(("="_tok || ":"_tok) >> digitString64))))) /

flang/lib/Parser/unparse.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1864,6 +1864,13 @@ class UnparseVisitor {
18641864
[&](const CompilerDirective::NoUnrollAndJam &) {
18651865
Word("!DIR$ NOUNROLL_AND_JAM");
18661866
},
1867+
[&](const CompilerDirective::ForceInline &) {
1868+
Word("!DIR$ FORCEINLINE");
1869+
},
1870+
[&](const CompilerDirective::Inline &) { Word("!DIR$ INLINE"); },
1871+
[&](const CompilerDirective::NoInline &) {
1872+
Word("!DIR$ NOINLINE");
1873+
},
18671874
[&](const CompilerDirective::Unrecognized &) {
18681875
Word("!DIR$ ");
18691876
Word(x.source.ToString());

0 commit comments

Comments
 (0)