Skip to content

Add ZA directives for Flang. #76505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions flang/docs/Directives.md
Original file line number Diff line number Diff line change
@@ -29,3 +29,28 @@ A list of non-standard directives supported by Flang
end
end interface
```

## ARM Streaming SVE directives

These directives are added to support ARM specific instructions. All of
these attributes apply to a specific subroutine or function. These directives
are identical to the attributes provided in C and C++ for the same purpose.
See https://arm-software.github.io/acle/main/acle.html#controlling-the-use-of-streaming-mode for more in depth details. (For the following, function is used
to mean both subroutine and function).

### Directives relating to ARM Streaming mode

* `!dir$ arm_streaming` - The function is intended to be used in streaming
mode.
* `!dir$ arm_streaming_compatible` - The function can work both in streaming
mode and non-streaming mode.
* `!dir$ arm_streaming` - The function will enter streaming mode, and return to

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean arm_locally_streaming?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, of course. Typical copy-pasta error. Thanks for pointing it out.

non-streaming mode when reaturning.

### Directives relating to ZA

* `!dir$ arm_shared_za` - A function that uses ZA for input or output.
* `!dir$ arm_new_za` - A function that has ZA state created and destroyed within
the function.
* `!dir$ arm_preserves_za` - Optimisation hint for the compiler that the
function either doesn't alter, or saves and restores the ZA state.
10 changes: 8 additions & 2 deletions flang/include/flang/Lower/PFTBuilder.h
Original file line number Diff line number Diff line change
@@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);

void dump(VariableList &, std::string s = {}); // `s` is an optional dump label

/// Things that can be nested inside of a module or function
/// TODO: add the rest
struct FunctionLikeUnit;
struct CompilerDirectiveUnit;
using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;

/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
@@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {

ModuleStatement beginStmt;
ModuleStatement endStmt;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};

3 changes: 2 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
@@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
common::Indirection<SeparateModuleSubprogram>>
common::Indirection<SeparateModuleSubprogram>,
common::Indirection<CompilerDirective>>
u;
};

123 changes: 109 additions & 14 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
@@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
@@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &unit :
m.nestedUnits) {
if (auto *f = std::get_if<
Fortran::lower::pft::FunctionLikeUnit>(&unit))
declareFunction(*f);
}
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
@@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });

// Primary translation pass.
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, units.end(), d);
},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
@@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
u);
*it);
}

// Once all the code has been translated, create global runtime type info
@@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {

// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
collectHostAssociatedVariables(f, escapeHost);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
collectHostAssociatedVariables(*f, escapeHost);
}
funit.setHostAssociatedSymbols(escapeHost);

// Declare internal procedures
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
declareFunction(*f);
}
}

/// Get the scope that is defining or using \p sym. The returned scope is not
@@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
lowerFunc(f); // internal procedure
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
lowerFunc(*f); // internal procedure
}
}

/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {

/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
lowerFunc(f);
for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, mod.nestedUnits.end(), d);
}},
*it);
}
}

void setCurrentPosition(const Fortran::parser::CharBlock &position) {
@@ -5001,6 +5022,80 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}

/// Process compiler directives that apply to subprograms
template <typename ITERATOR>
void
processSubprogramDirective(ITERATOR it, ITERATOR endIt,
Fortran::lower::pft::CompilerDirectiveUnit &d) {
auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
if (!parserDirective)
return;
auto *nvList =
std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
&parserDirective->u);
if (!nvList)
return;

// get the function the directive applies to (hopefully the next unit)
mlir::func::FuncOp mlirFunc;
it = std::next(it);
if (it != endIt) {
auto *pftFunction =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
if (pftFunction) {
Fortran::lower::CalleeInterface callee{*pftFunction, *this};
mlirFunc = callee.getFuncOp();
}
}

for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();

// arm streaming sve directives
auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
if (name == "arm_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
else if (name == "arm_locally_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
else if (name == "arm_streaming_compatible")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
Comment on lines +5055 to +5061
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a generated helper for this (that returns std::optional<arm_sme::ArmStreamingMode>). I believe it would be mlir::arm_sme::symbolizeArmStreamingMode().

if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
if (!mlirFunc) {
// TODO: share diagnostic code with warnings elsewhere
// TODO: source location is printed as loc<"file.f90":line:col>
mlir::Location loc = genLocation(parserDirective->source);
llvm::errs() << loc << ": warning: ignoring directive '" << name
<< "' because it has no associated subprogram\n";
continue;
}
llvm::StringRef attrName =
mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
if (name == "arm_new_za")
zaMode = mlir::arm_sme::ArmZaMode::NewZA;
else if (name == "arm_shared_za")
zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
else if (name == "arm_preserves_za")
zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
Comment on lines +5077 to +5082
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be mlir::arm_sme::symbolizeArmZaMode() for this too.

if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
if (!mlirFunc) {
// TODO: share diagnostic code with warnings elsewhere
// TODO: source location is printed as loc<"file.f90":line:col>
mlir::Location loc = genLocation(parserDirective->source);
llvm::errs() << loc << ": warning: ignoring directive '" << name
<< "' because it has no associated subprogram\n";
continue;
}
llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
}
}

//===--------------------------------------------------------------------===//

Fortran::lower::LoweringBridge &bridge;
Loading