Skip to content

[flang] Implement !DIR$ [NO]INLINE and FORCEINLINE directives #134350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flang/docs/Directives.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ A list of non-standard directives supported by Flang
* `!dir$ novector` disabling vectorization on the following loop.
* `!dir$ nounroll` disabling unrolling on the following loop.
* `!dir$ nounroll_and_jam` disabling unrolling and jamming on the following loop.
* `!dir$ inline` tells the compiler to attempt to inline routines if
this directive is specified before a call statement or for all call function statements
within a DO LOOP. This directive can be improved later to support other place(s) for
inlining function calls.
* `!dir$ forceinline` works in the same way as the `inline` directive, but it force
inlining by the compiler on a function call statement.
* `!dir$ noinline` works in the same way as the `inline` directive, but prevents
any attempt of inlining by the compiler on a function call statement.

# Directive Details

Expand Down
10 changes: 10 additions & 0 deletions flang/include/flang/Evaluate/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ class ProcedureRef {
bool IsElemental() const { return proc_.IsElemental(); }
bool hasAlternateReturns() const { return hasAlternateReturns_; }

bool hasNoInline() const { return noInline_; }
void set_noInline(bool ni) { noInline_ = ni; }
bool hasAlwaysInline() const { return alwaysInline_; }
void set_alwaysInline(bool ai) { alwaysInline_ = ai; }
bool hasInlineHint() const { return inlineHint_; }
void set_inlineHint(bool ih) { inlineHint_ = ih; }

Expr<SomeType> *UnwrapArgExpr(int n) {
if (static_cast<std::size_t>(n) < arguments_.size() && arguments_[n]) {
return arguments_[n]->UnwrapExpr();
Expand All @@ -277,6 +284,9 @@ class ProcedureRef {
ActualArguments arguments_;
Chevrons chevrons_;
bool hasAlternateReturns_;
bool noInline_{false};
bool alwaysInline_{false};
bool inlineHint_{false};
};

template <typename A> class FunctionRef : public ProcedureRef {
Expand Down
20 changes: 20 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,24 @@ def fir_OpenMPSafeTempArrayCopyAttr : fir_Attr<"OpenMPSafeTempArrayCopy"> {
}];
}

/// Fortran inline attribute
def FIRinlineNone : I32BitEnumAttrCaseNone<"none">;
def FIRinlineNo : I32BitEnumAttrCaseBit<"no_inline", 0>;
def FIRinlineAlways : I32BitEnumAttrCaseBit<"always_inline", 1>;
def FIRinlineHint : I32BitEnumAttrCaseBit<"inline_hint", 2>;

def fir_FortranInlineEnum
: I32BitEnumAttr<"FortranInlineEnum", "Fortran inline attributes",
[FIRinlineNone, FIRinlineNo, FIRinlineAlways,
FIRinlineHint]> {
let separator = ", ";
let cppNamespace = "::fir";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def fir_FortranInlineAttr
: EnumAttr<FIROpsDialect, fir_FortranInlineEnum, "inline_attrs"> {
let assemblyFormat = "`<` $value `>`";
}
#endif // FIR_DIALECT_FIR_ATTRS
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,7 @@ def fir_CallOp : fir_Op<"call",
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
OptionalAttr<fir_FortranInlineAttr>:$inline_attr,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
);
Expand Down
3 changes: 3 additions & 0 deletions flang/include/flang/Parser/dump-parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,11 @@ class ParseTreeDumper {
NODE(parser, CompilerDirective)
NODE(CompilerDirective, AssumeAligned)
NODE(CompilerDirective, IgnoreTKR)
NODE(CompilerDirective, Inline)
NODE(CompilerDirective, ForceInline)
NODE(CompilerDirective, LoopCount)
NODE(CompilerDirective, NameValue)
NODE(CompilerDirective, NoInline)
NODE(CompilerDirective, Unrecognized)
NODE(CompilerDirective, VectorAlways)
NODE(CompilerDirective, Unroll)
Expand Down
8 changes: 7 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -3354,6 +3354,9 @@ struct StmtFunctionStmt {
// !DIR$ NOVECTOR
// !DIR$ NOUNROLL
// !DIR$ NOUNROLL_AND_JAM
// !DIR$ FORCEINLINE
// !DIR$ INLINE
// !DIR$ NOINLINE
// !DIR$ <anything else>
struct CompilerDirective {
UNION_CLASS_BOILERPLATE(CompilerDirective);
Expand Down Expand Up @@ -3382,11 +3385,14 @@ struct CompilerDirective {
EMPTY_CLASS(NoVector);
EMPTY_CLASS(NoUnroll);
EMPTY_CLASS(NoUnrollAndJam);
EMPTY_CLASS(ForceInline);
EMPTY_CLASS(Inline);
EMPTY_CLASS(NoInline);
EMPTY_CLASS(Unrecognized);
CharBlock source;
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
VectorAlways, std::list<NameValue>, Unroll, UnrollAndJam, Unrecognized,
NoVector, NoUnroll, NoUnrollAndJam>
NoVector, NoUnroll, NoUnrollAndJam, ForceInline, Inline, NoInline>
u;
};

Expand Down
129 changes: 122 additions & 7 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
setCurrentPosition(stmt.source);
assert(stmt.typedCall && "Call was not analyzed");
mlir::Value res{};

// Set 'no_inline', 'inline_hint' or 'always_inline' to true on the
// ProcedureRef. The NoInline and AlwaysInline attribute will be set in
// genProcedureRef later.
for (const auto *dir : eval.dirs) {
Fortran::common::visit(
Fortran::common::visitors{
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
stmt.typedCall->set_alwaysInline(true);
},
[&](const Fortran::parser::CompilerDirective::Inline &) {
stmt.typedCall->set_inlineHint(true);
},
[&](const Fortran::parser::CompilerDirective::NoInline &) {
stmt.typedCall->set_noInline(true);
},
[&](const auto &) {}},
dir->u);
}

if (lowerToHighLevelFIR()) {
std::optional<mlir::Type> resultType;
if (stmt.typedCall->hasAlternateReturns())
Expand Down Expand Up @@ -2053,6 +2073,50 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// so no clean-up needs to be generated for these entities.
}

void attachInlineAttributes(
mlir::Operation &op,
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs) {
if (dirs.empty())
return;

for (mlir::Value operand : op.getOperands()) {
if (operand.getDefiningOp())
attachInlineAttributes(*operand.getDefiningOp(), dirs);
}

if (fir::CallOp callOp = mlir::dyn_cast<fir::CallOp>(op)) {
for (const auto *dir : dirs) {
Fortran::common::visit(
Fortran::common::visitors{
[&](const Fortran::parser::CompilerDirective::NoInline &) {
callOp.setInlineAttr(fir::FortranInlineEnum::no_inline);
},
[&](const Fortran::parser::CompilerDirective::Inline &) {
callOp.setInlineAttr(fir::FortranInlineEnum::inline_hint);
},
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
callOp.setInlineAttr(fir::FortranInlineEnum::always_inline);
},
[&](const auto &) {}},
dir->u);
}
}
}

void attachAttributesToDoLoopOperations(
fir::DoLoopOp &doLoop,
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
if (!doLoop.getOperation() || dirs.empty())
return;

for (mlir::Block &block : doLoop.getRegion()) {
for (mlir::Operation &op : block.getOperations()) {
if (!dirs.empty())
attachInlineAttributes(op, dirs);
}
}
}

/// Generate FIR for a DO construct. There are six variants:
/// - unstructured infinite and while loops
/// - structured and unstructured increment loops
Expand Down Expand Up @@ -2162,6 +2226,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {

// This call may generate a branch in some contexts.
genFIR(endDoEval, unstructuredContext);

// Add attribute(s) on operations in fir::DoLoopOp if necessary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there special handling for fir.do_loop but not for fir.if and other operations with regions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the moment, I've focused on do_loop, assignment and callstmt. And in doc/Directives.md I've added a message to say that these directives are partially supported. But if it's really necessary I can add support for the other operations.

for (IncrementLoopInfo &info : incrementLoopNestInfo)
attachAttributesToDoLoopOperations(info.doLoop, doStmtEval.dirs);
}

/// Generate FIR to evaluate loop control values (lower, upper and step).
Expand Down Expand Up @@ -2935,6 +3003,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
e->dirs.push_back(&dir);
}

void
attachInliningDirectiveToStmt(const Fortran::parser::CompilerDirective &dir,
Fortran::lower::pft::Evaluation *e) {
while (e->isDirective())
e = e->lexicalSuccessor;

// If the successor is a statement or a do loop, the compiler
// will perform inlining.
if (e->isA<Fortran::parser::CallStmt>() ||
e->isA<Fortran::parser::NonLabelDoStmt>() ||
e->isA<Fortran::parser::AssignmentStmt>()) {
e->dirs.push_back(&dir);
} else {
mlir::Location loc = toLocation();
mlir::emitWarning(loc,
"Inlining directive not in front of loops, function"
"call or assignment.\n");
}
}

void genFIR(const Fortran::parser::CompilerDirective &dir) {
Fortran::lower::pft::Evaluation &eval = getEval();

Expand All @@ -2958,6 +3046,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](const Fortran::parser::CompilerDirective::NoUnrollAndJam &) {
attachDirectiveToLoop(dir, &eval);
},
[&](const Fortran::parser::CompilerDirective::ForceInline &) {
attachInliningDirectiveToStmt(dir, &eval);
},
[&](const Fortran::parser::CompilerDirective::Inline &) {
attachInliningDirectiveToStmt(dir, &eval);
},
[&](const Fortran::parser::CompilerDirective::NoInline &) {
attachInliningDirectiveToStmt(dir, &eval);
},
[&](const auto &) {}},
dir.u);
}
Expand Down Expand Up @@ -4761,7 +4858,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {

void genDataAssignment(
const Fortran::evaluate::Assignment &assign,
const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
const Fortran::evaluate::ProcedureRef *userDefinedAssignment,
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs =
{}) {
mlir::Location loc = getCurrentLocation();
fir::FirOpBuilder &builder = getFirOpBuilder();

Expand Down Expand Up @@ -4834,12 +4933,22 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::StatementContext localStmtCtx;
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
if (isCUDATransfer && !hasCUDAImplicitTransfer)
if (isCUDATransfer && !hasCUDAImplicitTransfer) {
genCUDADataTransfer(builder, loc, assign, lhs, rhs);
else
} else {
// If RHS or LHS have a CallOp in their expression, this operation will
// have the 'no_inline' or 'always_inline' attribute if there is a
// directive just before the assignement.
if (!dirs.empty()) {
if (rhs.getDefiningOp())
attachInlineAttributes(*rhs.getDefiningOp(), dirs);
if (lhs.getDefiningOp())
attachInlineAttributes(*lhs.getDefiningOp(), dirs);
}
builder.create<hlfir::AssignOp>(loc, rhs, lhs,
isWholeAllocatableAssignment,
keepLhsLengthInAllocatableAssignment);
}
if (hasCUDAImplicitTransfer && !isInDeviceContext) {
localSymbols.popScope();
for (mlir::Value temp : implicitTemps)
Expand Down Expand Up @@ -4907,16 +5016,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

/// Shared for both assignments and pointer assignments.
void genAssignment(const Fortran::evaluate::Assignment &assign) {
void
genAssignment(const Fortran::evaluate::Assignment &assign,
const llvm::ArrayRef<const Fortran::parser::CompilerDirective *>
&dirs = {}) {
mlir::Location loc = toLocation();
if (lowerToHighLevelFIR()) {
Fortran::common::visit(
Fortran::common::visitors{
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
genDataAssignment(assign, /*userDefinedAssignment=*/nullptr);
genDataAssignment(assign, /*userDefinedAssignment=*/nullptr,
dirs);
},
[&](const Fortran::evaluate::ProcedureRef &procRef) {
genDataAssignment(assign, /*userDefinedAssignment=*/&procRef);
genDataAssignment(assign, /*userDefinedAssignment=*/&procRef,
dirs);
},
[&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
if (isInsideHlfirForallOrWhere())
Expand Down Expand Up @@ -5321,7 +5435,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

void genFIR(const Fortran::parser::AssignmentStmt &stmt) {
genAssignment(*stmt.typedAssignment->v);
Fortran::lower::pft::Evaluation &eval = getEval();
genAssignment(*stmt.typedAssignment->v, eval.dirs);
}

void genFIR(const Fortran::parser::SyncAllStmt &stmt) {
Expand Down
13 changes: 12 additions & 1 deletion flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,20 @@ Fortran::lower::genCallOpAndResult(
callResult = dispatch.getResult(0);
} else {
// Standard procedure call with fir.call.
fir::FortranInlineEnumAttr inlineAttr;

if (caller.getCallDescription().hasNoInline())
inlineAttr = fir::FortranInlineEnumAttr::get(
builder.getContext(), fir::FortranInlineEnum::no_inline);
else if (caller.getCallDescription().hasInlineHint())
inlineAttr = fir::FortranInlineEnumAttr::get(
builder.getContext(), fir::FortranInlineEnum::inline_hint);
else if (caller.getCallDescription().hasAlwaysInline())
inlineAttr = fir::FortranInlineEnumAttr::get(
builder.getContext(), fir::FortranInlineEnum::always_inline);
auto call = builder.create<fir::CallOp>(
loc, funcType.getResults(), funcSymbolAttr, operands,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr);

callNumResults = call.getNumResults();
if (callNumResults != 0)
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
llvmCall.setResAttrsAttr(resAttrs);

if (auto inlineAttr = call.getInlineAttrAttr()) {
llvmCall->removeAttr("inline_attr");
if (inlineAttr.getValue() == fir::FortranInlineEnum::no_inline) {
llvmCall.setNoInlineAttr(rewriter.getUnitAttr());
} else if (inlineAttr.getValue() == fir::FortranInlineEnum::inline_hint) {
llvmCall.setInlineHintAttr(rewriter.getUnitAttr());
} else if (inlineAttr.getValue() ==
fir::FortranInlineEnum::always_inline) {
llvmCall.setAlwaysInlineAttr(rewriter.getUnitAttr());
}
}

if (memAttr)
llvmCall.setMemoryEffectsAttr(
mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr));
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
rewriter.replaceOpWithNewOp<fir::CallOp>(
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr(),
/*inline_attr*/ fir::FortranInlineEnumAttr{});
return mlir::success();
}

Expand Down
8 changes: 8 additions & 0 deletions flang/lib/Parser/Fortran-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,11 @@ constexpr auto novector{"NOVECTOR" >> construct<CompilerDirective::NoVector>()};
constexpr auto nounroll{"NOUNROLL" >> construct<CompilerDirective::NoUnroll>()};
constexpr auto nounrollAndJam{
"NOUNROLL_AND_JAM" >> construct<CompilerDirective::NoUnrollAndJam>()};
constexpr auto forceinlineDir{
"FORCEINLINE" >> construct<CompilerDirective::ForceInline>()};
constexpr auto noinlineDir{
"NOINLINE" >> construct<CompilerDirective::NoInline>()};
constexpr auto inlineDir{"INLINE" >> construct<CompilerDirective::Inline>()};
TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
sourced((construct<CompilerDirective>(ignore_tkr) ||
construct<CompilerDirective>(loopCount) ||
Expand All @@ -1324,6 +1329,9 @@ TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
construct<CompilerDirective>(novector) ||
construct<CompilerDirective>(nounrollAndJam) ||
construct<CompilerDirective>(nounroll) ||
construct<CompilerDirective>(noinlineDir) ||
construct<CompilerDirective>(forceinlineDir) ||
construct<CompilerDirective>(inlineDir) ||
construct<CompilerDirective>(
many(construct<CompilerDirective::NameValue>(
name, maybe(("="_tok || ":"_tok) >> digitString64))))) /
Expand Down
7 changes: 7 additions & 0 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,13 @@ class UnparseVisitor {
[&](const CompilerDirective::NoUnrollAndJam &) {
Word("!DIR$ NOUNROLL_AND_JAM");
},
[&](const CompilerDirective::ForceInline &) {
Word("!DIR$ FORCEINLINE");
},
[&](const CompilerDirective::Inline &) { Word("!DIR$ INLINE"); },
[&](const CompilerDirective::NoInline &) {
Word("!DIR$ NOINLINE");
},
[&](const CompilerDirective::Unrecognized &) {
Word("!DIR$ ");
Word(x.source.ToString());
Expand Down
Loading