-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Add ZA directives for Flang. #76505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add ZA directives for Flang. #76505
Changes from all commits
1522c61
8ce3b25
0453b9c
69384d3
7e9b552
7eb4a3d
6ab4020
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ | |
#include "flang/Semantics/runtime-type-info.h" | ||
#include "flang/Semantics/symbol.h" | ||
#include "flang/Semantics/tools.h" | ||
#include "mlir/Dialect/ArmSME/Transforms/Passes.h" | ||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Parser/Parser.h" | ||
|
@@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
}, | ||
[&](Fortran::lower::pft::ModuleLikeUnit &m) { | ||
lowerModuleDeclScope(m); | ||
for (Fortran::lower::pft::FunctionLikeUnit &f : | ||
m.nestedFunctions) | ||
declareFunction(f); | ||
for (Fortran::lower::pft::NestedUnit &unit : | ||
m.nestedUnits) { | ||
if (auto *f = std::get_if< | ||
Fortran::lower::pft::FunctionLikeUnit>(&unit)) | ||
declareFunction(*f); | ||
} | ||
}, | ||
[&](Fortran::lower::pft::BlockDataUnit &b) { | ||
if (!globalOmpRequiresSymbol) | ||
|
@@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
[&]() { createIntrinsicModuleDefinitions(pft); }); | ||
|
||
// Primary translation pass. | ||
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { | ||
std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits(); | ||
for (auto it = units.begin(); it != units.end(); it = std::next(it)) { | ||
std::visit( | ||
Fortran::common::visitors{ | ||
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); }, | ||
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); }, | ||
[&](Fortran::lower::pft::BlockDataUnit &b) {}, | ||
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, | ||
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) { | ||
processSubprogramDirective(it, units.end(), d); | ||
}, | ||
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) { | ||
builder = new fir::FirOpBuilder(bridge.getModule(), | ||
bridge.getKindMap()); | ||
|
@@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
builder = nullptr; | ||
}, | ||
}, | ||
u); | ||
*it); | ||
} | ||
|
||
// Once all the code has been translated, create global runtime type info | ||
|
@@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
|
||
// Compute the set of host associated entities from the nested functions. | ||
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost; | ||
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) | ||
collectHostAssociatedVariables(f, escapeHost); | ||
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { | ||
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested)) | ||
collectHostAssociatedVariables(*f, escapeHost); | ||
} | ||
funit.setHostAssociatedSymbols(escapeHost); | ||
|
||
// Declare internal procedures | ||
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) | ||
declareFunction(f); | ||
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { | ||
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested)) | ||
declareFunction(*f); | ||
} | ||
} | ||
|
||
/// Get the scope that is defining or using \p sym. The returned scope is not | ||
|
@@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
endNewFunction(funit); | ||
} | ||
funit.setActiveEntry(0); | ||
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) | ||
lowerFunc(f); // internal procedure | ||
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { | ||
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested)) | ||
lowerFunc(*f); // internal procedure | ||
} | ||
} | ||
|
||
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC | ||
|
@@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
|
||
/// Lower functions contained in a module. | ||
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) { | ||
for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions) | ||
lowerFunc(f); | ||
for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end(); | ||
it = std::next(it)) { | ||
std::visit( | ||
Fortran::common::visitors{ | ||
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); }, | ||
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) { | ||
processSubprogramDirective(it, mod.nestedUnits.end(), d); | ||
}}, | ||
*it); | ||
} | ||
} | ||
|
||
void setCurrentPosition(const Fortran::parser::CharBlock &position) { | ||
|
@@ -5001,6 +5022,80 @@ class FirConverter : public Fortran::lower::AbstractConverter { | |
globalOmpRequiresSymbol); | ||
} | ||
|
||
/// Process compiler directives that apply to subprograms | ||
template <typename ITERATOR> | ||
void | ||
processSubprogramDirective(ITERATOR it, ITERATOR endIt, | ||
Fortran::lower::pft::CompilerDirectiveUnit &d) { | ||
auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>(); | ||
if (!parserDirective) | ||
return; | ||
auto *nvList = | ||
std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>( | ||
&parserDirective->u); | ||
if (!nvList) | ||
return; | ||
|
||
// get the function the directive applies to (hopefully the next unit) | ||
mlir::func::FuncOp mlirFunc; | ||
it = std::next(it); | ||
if (it != endIt) { | ||
auto *pftFunction = | ||
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it); | ||
if (pftFunction) { | ||
Fortran::lower::CalleeInterface callee{*pftFunction, *this}; | ||
mlirFunc = callee.getFuncOp(); | ||
} | ||
} | ||
|
||
for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) { | ||
std::string name = std::get<Fortran::parser::Name>(nv.t).ToString(); | ||
|
||
// arm streaming sve directives | ||
auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled; | ||
if (name == "arm_streaming") | ||
streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming; | ||
else if (name == "arm_locally_streaming") | ||
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally; | ||
else if (name == "arm_streaming_compatible") | ||
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible; | ||
Comment on lines
+5055
to
+5061
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be a generated helper for this (that returns |
||
if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) { | ||
if (!mlirFunc) { | ||
// TODO: share diagnostic code with warnings elsewhere | ||
// TODO: source location is printed as loc<"file.f90":line:col> | ||
mlir::Location loc = genLocation(parserDirective->source); | ||
llvm::errs() << loc << ": warning: ignoring directive '" << name | ||
<< "' because it has no associated subprogram\n"; | ||
continue; | ||
} | ||
llvm::StringRef attrName = | ||
mlir::arm_sme::stringifyArmStreamingMode(streamingMode); | ||
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext()); | ||
mlirFunc->setAttr(attrName, unitAttr); | ||
} | ||
auto zaMode = mlir::arm_sme::ArmZaMode::Disabled; | ||
if (name == "arm_new_za") | ||
zaMode = mlir::arm_sme::ArmZaMode::NewZA; | ||
else if (name == "arm_shared_za") | ||
zaMode = mlir::arm_sme::ArmZaMode::SharedZA; | ||
else if (name == "arm_preserves_za") | ||
zaMode = mlir::arm_sme::ArmZaMode::PreservesZA; | ||
Comment on lines
+5077
to
+5082
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be |
||
if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) { | ||
if (!mlirFunc) { | ||
// TODO: share diagnostic code with warnings elsewhere | ||
// TODO: source location is printed as loc<"file.f90":line:col> | ||
mlir::Location loc = genLocation(parserDirective->source); | ||
llvm::errs() << loc << ": warning: ignoring directive '" << name | ||
<< "' because it has no associated subprogram\n"; | ||
continue; | ||
} | ||
llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode); | ||
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext()); | ||
mlirFunc->setAttr(attrName, unitAttr); | ||
} | ||
} | ||
} | ||
|
||
//===--------------------------------------------------------------------===// | ||
|
||
Fortran::lower::LoweringBridge &bridge; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean
arm_locally_streaming
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, of course. Typical copy-pasta error. Thanks for pointing it out.