Skip to content

[flang][openmp]Add UserReductionDetails and use in DECLARE REDUCTION #131628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
9 changes: 9 additions & 0 deletions flang/include/flang/Semantics/semantics.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ class SemanticsContext {
// Top-level ProgramTrees are owned by the SemanticsContext for persistence.
ProgramTree &SaveProgramTree(ProgramTree &&);

// Store (and get a reference to the stored string) for mangled names
// used for OpenMP DECLARE REDUCTION.
std::string &StoreUserReductionName(const std::string &name);

private:
struct ScopeIndexComparator {
bool operator()(parser::CharBlock, parser::CharBlock) const;
Expand Down Expand Up @@ -343,6 +347,11 @@ class SemanticsContext {
std::map<const Symbol *, SourceName> moduleFileOutputRenamings_;
UnorderedSymbolSet isDefined_;
std::list<ProgramTree> programTrees_;

// Storage for mangled names used in OMP DECLARE REDUCTION.
// use std::list to avoid re-allocating the string when adding
// more content to the container.
std::list<std::string> userReductionNames_;
};

class Semantics {
Expand Down
38 changes: 37 additions & 1 deletion flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class raw_ostream;
}
namespace Fortran::parser {
struct Expr;
struct OpenMPDeclareReductionConstruct;
struct OmpMetadirectiveDirective;
}

namespace Fortran::semantics {
Expand Down Expand Up @@ -701,14 +703,48 @@ class GenericDetails {
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &);

// Used for OpenMP DECLARE REDUCTION, it holds the information
// needed to resolve which declaration (there could be multiple
// with the same name) to use for a given type.
class UserReductionDetails {
public:
using TypeVector = std::vector<const DeclTypeSpec *>;
using DeclInfo = std::variant<const parser::OpenMPDeclareReductionConstruct *,
const parser::OmpMetadirectiveDirective *>;
using DeclVector = std::vector<DeclInfo>;

UserReductionDetails() = default;

void AddType(const DeclTypeSpec &type) { typeList_.push_back(&type); }
const TypeVector &GetTypeList() const { return typeList_; }

bool SupportsType(const DeclTypeSpec &type) const {
// We have to compare the actual type, not the pointer, as some
// types are not guaranteed to be the same object.
for (auto t : typeList_) {
if (*t == type) {
return true;
}
}
return false;
}

void AddDecl(const DeclInfo &decl) { declList_.push_back(decl); }
const DeclVector &GetDeclList() const { return declList_; }

private:
TypeVector typeList_;
DeclVector declList_;
};

class UnknownDetails {};

using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
SubprogramDetails, SubprogramNameDetails, EntityDetails,
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
TypeParamDetails, MiscDetails>;
TypeParamDetails, MiscDetails, UserReductionDetails>;
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
std::string DetailsToString(const Details &);

Expand Down
7 changes: 7 additions & 0 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3325,4 +3325,11 @@ template void Unparse<Program>(llvm::raw_ostream &, const Program &, Encoding,
bool, bool, preStatementType *, AnalyzedObjectsAsFortran *);
template void Unparse<Expr>(llvm::raw_ostream &, const Expr &, Encoding, bool,
bool, preStatementType *, AnalyzedObjectsAsFortran *);

template void Unparse<parser::OpenMPDeclareReductionConstruct>(
llvm::raw_ostream &, const parser::OpenMPDeclareReductionConstruct &,
Encoding, bool, bool, preStatementType *, AnalyzedObjectsAsFortran *);
template void Unparse<parser::OmpMetadirectiveDirective>(llvm::raw_ostream &,
const parser::OmpMetadirectiveDirective &, Encoding, bool, bool,
preStatementType *, AnalyzedObjectsAsFortran *);
} // namespace Fortran::parser
10 changes: 10 additions & 0 deletions flang/lib/Semantics/assignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class AssignmentContext {
void Analyze(const parser::PointerAssignmentStmt &);
void Analyze(const parser::ConcurrentControl &);
int deviceConstructDepth_{0};
SemanticsContext &context() { return context_; }

private:
bool CheckForPureContext(const SomeExpr &rhs, parser::CharBlock rhsSource);
Expand Down Expand Up @@ -213,8 +214,17 @@ void AssignmentContext::PopWhereContext() {

AssignmentChecker::~AssignmentChecker() {}

SemanticsContext &AssignmentChecker::context() {
return context_.value().context();
}

AssignmentChecker::AssignmentChecker(SemanticsContext &context)
: context_{new AssignmentContext{context}} {}

void AssignmentChecker::Enter(
const parser::OpenMPDeclareReductionConstruct &x) {
context().set_location(x.source);
}
void AssignmentChecker::Enter(const parser::AssignmentStmt &x) {
context_.value().Analyze(x);
}
Expand Down
3 changes: 3 additions & 0 deletions flang/lib/Semantics/assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AssignmentChecker : public virtual BaseChecker {
public:
explicit AssignmentChecker(SemanticsContext &);
~AssignmentChecker();
void Enter(const parser::OpenMPDeclareReductionConstruct &x);
void Enter(const parser::AssignmentStmt &);
void Enter(const parser::PointerAssignmentStmt &);
void Enter(const parser::WhereStmt &);
Expand All @@ -54,6 +55,8 @@ class AssignmentChecker : public virtual BaseChecker {
void Enter(const parser::OpenACCLoopConstruct &);
void Leave(const parser::OpenACCLoopConstruct &);

SemanticsContext &context();

private:
common::Indirection<AssignmentContext> context_;
};
Expand Down
81 changes: 67 additions & 14 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "check-omp-structure.h"
#include "definable.h"
#include "resolve-names-utils.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/type.h"
Expand Down Expand Up @@ -3390,6 +3391,17 @@ bool OmpStructureChecker::CheckReductionOperator(
break;
}
}
// User-defined operators are OK if there has been a declared reduction
// for that. We mangle those names to store the user details.
if (const auto *definedOp{std::get_if<parser::DefinedOpName>(&dOpr.u)}) {
std::string mangled = MangleDefinedOperator(definedOp->v.symbol->name());
const Scope &scope = definedOp->v.symbol->owner();
if (const Symbol *symbol = scope.FindSymbol(mangled)) {
if (symbol->detailsIf<UserReductionDetails>()) {
return true;
}
}
}
context_.Say(source, "Invalid reduction operator in %s clause."_err_en_US,
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
return false;
Expand All @@ -3403,8 +3415,7 @@ bool OmpStructureChecker::CheckReductionOperator(
valid =
llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName);
if (!valid) {
auto *misc{name->symbol->detailsIf<MiscDetails>()};
valid = misc && misc->kind() == MiscDetails::Kind::ConstructName;
valid = name->symbol->detailsIf<UserReductionDetails>();
}
}
if (!valid) {
Expand Down Expand Up @@ -3485,8 +3496,20 @@ void OmpStructureChecker::CheckReductionObjects(
}
}

static bool CheckSymbolSupportsType(const Scope &scope,
const parser::CharBlock &name, const DeclTypeSpec &type) {
if (const auto *symbol{scope.FindSymbol(name)}) {
if (const auto *reductionDetails{
symbol->detailsIf<UserReductionDetails>()}) {
return reductionDetails->SupportsType(type);
}
}
return false;
}

static bool IsReductionAllowedForType(
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type) {
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
const Scope &scope, SemanticsContext &context) {
auto isLogical{[](const DeclTypeSpec &type) -> bool {
return type.category() == DeclTypeSpec::Logical;
}};
Expand All @@ -3506,27 +3529,39 @@ static bool IsReductionAllowedForType(
case parser::DefinedOperator::IntrinsicOperator::Multiply:
case parser::DefinedOperator::IntrinsicOperator::Add:
case parser::DefinedOperator::IntrinsicOperator::Subtract:
return type.IsNumeric(TypeCategory::Integer) ||
if (type.IsNumeric(TypeCategory::Integer) ||
type.IsNumeric(TypeCategory::Real) ||
type.IsNumeric(TypeCategory::Complex);
type.IsNumeric(TypeCategory::Complex))
return true;
break;
Comment on lines +3532 to +3536
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the same change required for the isLogical check below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.


case parser::DefinedOperator::IntrinsicOperator::AND:
case parser::DefinedOperator::IntrinsicOperator::OR:
case parser::DefinedOperator::IntrinsicOperator::EQV:
case parser::DefinedOperator::IntrinsicOperator::NEQV:
return isLogical(type);
if (isLogical(type)) {
return true;
}
break;

// Reduction identifier is not in OMP5.2 Table 5.2
default:
DIE("This should have been caught in CheckIntrinsicOperator");
return false;
}
}
return true;
parser::CharBlock name{MakeNameFromOperator(*intrinsicOp, context)};
return CheckSymbolSupportsType(scope, name, type);
} else if (const auto *definedOp{
std::get_if<parser::DefinedOpName>(&dOpr.u)}) {
return CheckSymbolSupportsType(
scope, MangleDefinedOperator(definedOp->v.symbol->name()), type);
}
DIE("Intrinsic Operator not found - parsing gone wrong?");
}};

auto checkDesignator{[&](const parser::ProcedureDesignator &procD) {
const parser::Name *name{std::get_if<parser::Name>(&procD.u)};
CHECK(name && name->symbol);
if (name && name->symbol) {
const SourceName &realName{name->symbol->GetUltimate().name()};
// OMP5.2: The type [...] of a list item that appears in a
Expand All @@ -3535,18 +3570,35 @@ static bool IsReductionAllowedForType(
// IAND: arguments must be integers: F2023 16.9.100
// IEOR: arguments must be integers: F2023 16.9.106
// IOR: arguments must be integers: F2023 16.9.111
return type.IsNumeric(TypeCategory::Integer);
if (type.IsNumeric(TypeCategory::Integer)) {
return true;
}
} else if (realName == "max" || realName == "min") {
// MAX: arguments must be integer, real, or character:
// F2023 16.9.135
// MIN: arguments must be integer, real, or character:
// F2023 16.9.141
return type.IsNumeric(TypeCategory::Integer) ||
type.IsNumeric(TypeCategory::Real) || isCharacter(type);
if (type.IsNumeric(TypeCategory::Integer) ||
type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
return true;
}
}

// If we get here, it may be a user declared reduction, so check
// if the symbol has UserReductionDetails, and if so, the type is
// supported.
if (const auto *reductionDetails{
name->symbol->detailsIf<UserReductionDetails>()}) {
return reductionDetails->SupportsType(type);
}

// We also need to check for mangled names (max, min, iand, ieor and ior)
// and then check if the type is there.
parser::CharBlock mangledName{MangleSpecialFunctions(name->source)};
return CheckSymbolSupportsType(scope, mangledName, type);
}
// TODO: user defined reduction operators. Just allow everything for now.
return true;
// Everything else is "not matching type".
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will return false; be executed after the DIE?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this used to be an assert, and when changing to die, I didn't remove the line below.

}};

return common::visit(
Expand All @@ -3561,7 +3613,8 @@ void OmpStructureChecker::CheckReductionObjectTypes(

for (auto &[symbol, source] : symbols) {
if (auto *type{symbol->GetType()}) {
if (!IsReductionAllowedForType(ident, *type)) {
const auto &scope{context_.FindScope(symbol->name())};
if (!IsReductionAllowedForType(ident, *type, scope, context_)) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
Expand Down
23 changes: 23 additions & 0 deletions flang/lib/Semantics/mod-file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ void ModFileWriter::PutEntity(llvm::raw_ostream &os, const Symbol &symbol) {
[&](const ObjectEntityDetails &) { PutObjectEntity(os, symbol); },
[&](const ProcEntityDetails &) { PutProcEntity(os, symbol); },
[&](const TypeParamDetails &) { PutTypeParam(os, symbol); },
[&](const UserReductionDetails &) { PutUserReduction(os, symbol); },
[&](const auto &) {
common::die("PutEntity: unexpected details: %s",
DetailsToString(symbol.details()).c_str());
Expand Down Expand Up @@ -1035,6 +1036,28 @@ void ModFileWriter::PutTypeParam(llvm::raw_ostream &os, const Symbol &symbol) {
os << '\n';
}

void ModFileWriter::PutUserReduction(
llvm::raw_ostream &os, const Symbol &symbol) {
const auto &details{symbol.get<UserReductionDetails>()};
// The module content for a OpenMP Declare Reduction is the OpenMP
// declaration. There may be multiple declarations.
// Decls are pointers, so do not use a referene.
for (const auto decl : details.GetDeclList()) {
common::visit( //
common::visitors{//
[&](const parser::OpenMPDeclareReductionConstruct *d) {
Unparse(os, *d);
},
[&](const parser::OmpMetadirectiveDirective *m) {
Unparse(os, *m);
},
[&](const auto &) {
DIE("Unknown OpenMP DECLARE REDUCTION content");
}},
decl);
}
}

void PutInit(llvm::raw_ostream &os, const Symbol &symbol, const MaybeExpr &init,
const parser::Expr *unanalyzed) {
if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Semantics/mod-file.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ModFileWriter {
void PutDerivedType(const Symbol &, const Scope * = nullptr);
void PutDECStructure(const Symbol &, const Scope * = nullptr);
void PutTypeParam(llvm::raw_ostream &, const Symbol &);
void PutUserReduction(llvm::raw_ostream &, const Symbol &);
void PutSubprogram(const Symbol &);
void PutGeneric(const Symbol &);
void PutUse(const Symbol &);
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Semantics/resolve-names-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,11 @@ struct SymbolAndTypeMappings;
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
Scope &newScope, SymbolAndTypeMappings * = nullptr);

parser::CharBlock MakeNameFromOperator(
const parser::DefinedOperator::IntrinsicOperator &op,
SemanticsContext &context);
parser::CharBlock MangleSpecialFunctions(const parser::CharBlock &name);
std::string MangleDefinedOperator(const parser::CharBlock &name);

} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_RESOLVE_NAMES_H_
Loading