Skip to content

Commit 258e143

Browse files
committed
[OpenACC][CIR][NFC] Refactor to move 'loop' emit into its own file
The 'loop' emit for OpenACC is particularly complicated/involved, so it makes sense to be in its own file. This patch splits it out into its own file, as well as the clause emitter code (as loop is going to require that).
1 parent 985410f commit 258e143

File tree

4 files changed

+349
-297
lines changed

4 files changed

+349
-297
lines changed
+318
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Emit OpenACC clause nodes as CIR code.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include <type_traits>
14+
15+
#include "mlir/Dialect/OpenACC/OpenACC.h"
16+
namespace clang {
17+
// Simple type-trait that is 'true' when `ToTest` is exactly the same type as
// at least one of `Tys`, so we can tell whether to `if-constexpr` a bunch of
// stuff. Written as a single fold-expression instead of a recursive partial
// specialization; an empty candidate list yields 'false'. Marked 'inline'
// because this variable template is defined in a header.
template <typename ToTest, typename... Tys>
inline constexpr bool isOneOfTypes = (std::is_same_v<ToTest, Tys> || ...);
24+
25+
template <typename OpTy>
26+
class OpenACCClauseCIREmitter final
27+
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
28+
OpTy &operation;
29+
CIRGen::CIRGenFunction &cgf;
30+
CIRGen::CIRGenBuilderTy &builder;
31+
32+
// This is necessary since a few of the clauses emit differently based on the
33+
// directive kind they are attached to.
34+
OpenACCDirectiveKind dirKind;
35+
// TODO(cir): This source location should be able to go away once the NYI
36+
// diagnostics are gone.
37+
SourceLocation dirLoc;
38+
39+
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
40+
41+
void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
42+
lastDeviceTypeValues.clear();
43+
44+
llvm::for_each(clause.getArchitectures(),
45+
[this](const DeviceTypeArgument &arg) {
46+
lastDeviceTypeValues.push_back(
47+
decodeDeviceType(arg.getIdentifierInfo()));
48+
});
49+
}
50+
51+
void clauseNotImplemented(const OpenACCClause &c) {
52+
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
53+
}
54+
55+
mlir::Value createIntExpr(const Expr *intExpr) {
56+
mlir::Value expr = cgf.emitScalarExpr(intExpr);
57+
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());
58+
59+
mlir::IntegerType targetType = mlir::IntegerType::get(
60+
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
61+
intExpr->getType()->isSignedIntegerOrEnumerationType()
62+
? mlir::IntegerType::SignednessSemantics::Signed
63+
: mlir::IntegerType::SignednessSemantics::Unsigned);
64+
65+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
66+
exprLoc, targetType, expr);
67+
return conversionOp.getResult(0);
68+
}
69+
70+
// 'condition' as an OpenACC grammar production is used for 'if' and (some
71+
// variants of) 'self'. It needs to be emitted as a signless-1-bit value, so
72+
// this function emits the expression, then sets the unrealized conversion
73+
// cast correctly, and returns the completed value.
74+
mlir::Value createCondition(const Expr *condExpr) {
75+
mlir::Value condition = cgf.evaluateExprAsBool(condExpr);
76+
mlir::Location exprLoc = cgf.cgm.getLoc(condExpr->getBeginLoc());
77+
mlir::IntegerType targetType = mlir::IntegerType::get(
78+
&cgf.getMLIRContext(), /*width=*/1,
79+
mlir::IntegerType::SignednessSemantics::Signless);
80+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
81+
exprLoc, targetType, condition);
82+
return conversionOp.getResult(0);
83+
}
84+
85+
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
86+
// '*' case leaves no identifier-info, just a nullptr.
87+
if (!ii)
88+
return mlir::acc::DeviceType::Star;
89+
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
90+
.CaseLower("default", mlir::acc::DeviceType::Default)
91+
.CaseLower("host", mlir::acc::DeviceType::Host)
92+
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
93+
.CasesLower("nvidia", "acc_device_nvidia",
94+
mlir::acc::DeviceType::Nvidia)
95+
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
96+
}
97+
98+
public:
99+
OpenACCClauseCIREmitter(OpTy &operation, CIRGen::CIRGenFunction &cgf,
100+
CIRGen::CIRGenBuilderTy &builder,
101+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc)
102+
: operation(operation), cgf(cgf), builder(builder), dirKind(dirKind),
103+
dirLoc(dirLoc) {}
104+
105+
void VisitClause(const OpenACCClause &clause) {
106+
clauseNotImplemented(clause);
107+
}
108+
109+
void VisitDefaultClause(const OpenACCDefaultClause &clause) {
110+
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
111+
// operations listed in the rest of the arguments.
112+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
113+
mlir::acc::KernelsOp, mlir::acc::DataOp>) {
114+
switch (clause.getDefaultClauseKind()) {
115+
case OpenACCDefaultClauseKind::None:
116+
operation.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
117+
break;
118+
case OpenACCDefaultClauseKind::Present:
119+
operation.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
120+
break;
121+
case OpenACCDefaultClauseKind::Invalid:
122+
break;
123+
}
124+
} else {
125+
// TODO: When we've implemented this for everything, switch this to an
126+
// unreachable. Combined constructs remain.
127+
return clauseNotImplemented(clause);
128+
}
129+
}
130+
131+
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
132+
setLastDeviceTypeClause(clause);
133+
134+
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp,
135+
mlir::acc::ShutdownOp>) {
136+
llvm::for_each(
137+
clause.getArchitectures(), [this](const DeviceTypeArgument &arg) {
138+
operation.addDeviceType(builder.getContext(),
139+
decodeDeviceType(arg.getIdentifierInfo()));
140+
});
141+
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
142+
assert(!operation.getDeviceTypeAttr() && "already have device-type?");
143+
assert(clause.getArchitectures().size() <= 1);
144+
145+
if (!clause.getArchitectures().empty())
146+
operation.setDeviceType(
147+
decodeDeviceType(clause.getArchitectures()[0].getIdentifierInfo()));
148+
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
149+
mlir::acc::SerialOp, mlir::acc::KernelsOp,
150+
mlir::acc::DataOp>) {
151+
// Nothing to do here, these constructs don't have any IR for these, as
152+
// they just modify the other clauses IR. So setting of
153+
// `lastDeviceTypeValues` (done above) is all we need.
154+
} else {
155+
// TODO: When we've implemented this for everything, switch this to an
156+
// unreachable. update, data, loop, routine, combined constructs remain.
157+
return clauseNotImplemented(clause);
158+
}
159+
}
160+
161+
void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
162+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
163+
mlir::acc::KernelsOp>) {
164+
operation.addNumWorkersOperand(builder.getContext(),
165+
createIntExpr(clause.getIntExpr()),
166+
lastDeviceTypeValues);
167+
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::SerialOp>) {
168+
llvm_unreachable("num_workers not valid on serial");
169+
} else {
170+
// TODO: When we've implemented this for everything, switch this to an
171+
// unreachable. Combined constructs remain.
172+
return clauseNotImplemented(clause);
173+
}
174+
}
175+
176+
void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
177+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
178+
mlir::acc::KernelsOp>) {
179+
operation.addVectorLengthOperand(builder.getContext(),
180+
createIntExpr(clause.getIntExpr()),
181+
lastDeviceTypeValues);
182+
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::SerialOp>) {
183+
llvm_unreachable("vector_length not valid on serial");
184+
} else {
185+
// TODO: When we've implemented this for everything, switch this to an
186+
// unreachable. Combined constructs remain.
187+
return clauseNotImplemented(clause);
188+
}
189+
}
190+
191+
void VisitAsyncClause(const OpenACCAsyncClause &clause) {
192+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
193+
mlir::acc::KernelsOp, mlir::acc::DataOp>) {
194+
if (!clause.hasIntExpr())
195+
operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
196+
else
197+
operation.addAsyncOperand(builder.getContext(),
198+
createIntExpr(clause.getIntExpr()),
199+
lastDeviceTypeValues);
200+
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::WaitOp>) {
201+
// Wait doesn't have a device_type, so its handling here is slightly
202+
// different.
203+
if (!clause.hasIntExpr())
204+
operation.setAsync(true);
205+
else
206+
operation.getAsyncOperandMutable().append(
207+
createIntExpr(clause.getIntExpr()));
208+
} else {
209+
// TODO: When we've implemented this for everything, switch this to an
210+
// unreachable. Combined constructs remain. Data, enter data, exit data,
211+
// update, combined constructs remain.
212+
return clauseNotImplemented(clause);
213+
}
214+
}
215+
216+
void VisitSelfClause(const OpenACCSelfClause &clause) {
217+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
218+
mlir::acc::KernelsOp>) {
219+
if (clause.isEmptySelfClause()) {
220+
operation.setSelfAttr(true);
221+
} else if (clause.isConditionExprClause()) {
222+
assert(clause.hasConditionExpr());
223+
operation.getSelfCondMutable().append(
224+
createCondition(clause.getConditionExpr()));
225+
} else {
226+
llvm_unreachable("var-list version of self shouldn't get here");
227+
}
228+
} else {
229+
// TODO: When we've implemented this for everything, switch this to an
230+
// unreachable. If, combined constructs remain.
231+
return clauseNotImplemented(clause);
232+
}
233+
}
234+
235+
void VisitIfClause(const OpenACCIfClause &clause) {
236+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
237+
mlir::acc::KernelsOp, mlir::acc::InitOp,
238+
mlir::acc::ShutdownOp, mlir::acc::SetOp,
239+
mlir::acc::DataOp, mlir::acc::WaitOp>) {
240+
operation.getIfCondMutable().append(
241+
createCondition(clause.getConditionExpr()));
242+
} else {
243+
// 'if' applies to most of the constructs, but hold off on lowering them
244+
// until we can write tests/know what we're doing with codegen to make
245+
// sure we get it right.
246+
// TODO: When we've implemented this for everything, switch this to an
247+
// unreachable. Enter data, exit data, host_data, update, combined
248+
// constructs remain.
249+
return clauseNotImplemented(clause);
250+
}
251+
}
252+
253+
void VisitDeviceNumClause(const OpenACCDeviceNumClause &clause) {
254+
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
255+
mlir::acc::SetOp>) {
256+
operation.getDeviceNumMutable().append(
257+
createIntExpr(clause.getIntExpr()));
258+
} else {
259+
llvm_unreachable(
260+
"init, shutdown, set, are only valid device_num constructs");
261+
}
262+
}
263+
264+
void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
265+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
266+
mlir::acc::KernelsOp>) {
267+
llvm::SmallVector<mlir::Value> values;
268+
for (const Expr *E : clause.getIntExprs())
269+
values.push_back(createIntExpr(E));
270+
271+
operation.addNumGangsOperands(builder.getContext(), values,
272+
lastDeviceTypeValues);
273+
} else {
274+
// TODO: When we've implemented this for everything, switch this to an
275+
// unreachable. Combined constructs remain.
276+
return clauseNotImplemented(clause);
277+
}
278+
}
279+
280+
void VisitWaitClause(const OpenACCWaitClause &clause) {
281+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
282+
mlir::acc::KernelsOp, mlir::acc::DataOp>) {
283+
if (!clause.hasExprs()) {
284+
operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
285+
} else {
286+
llvm::SmallVector<mlir::Value> values;
287+
if (clause.hasDevNumExpr())
288+
values.push_back(createIntExpr(clause.getDevNumExpr()));
289+
for (const Expr *E : clause.getQueueIdExprs())
290+
values.push_back(createIntExpr(E));
291+
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
292+
values, lastDeviceTypeValues);
293+
}
294+
} else {
295+
// TODO: When we've implemented this for everything, switch this to an
296+
// unreachable. Enter data, exit data, update, Combined constructs remain.
297+
return clauseNotImplemented(clause);
298+
}
299+
}
300+
301+
void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
302+
if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
303+
operation.getDefaultAsyncMutable().append(
304+
createIntExpr(clause.getIntExpr()));
305+
} else {
306+
llvm_unreachable("set, is only valid device_num constructs");
307+
}
308+
}
309+
};
310+
311+
template <typename OpTy>
312+
auto makeClauseEmitter(OpTy &op, CIRGen::CIRGenFunction &cgf,
313+
CIRGen::CIRGenBuilderTy &builder,
314+
OpenACCDirectiveKind dirKind, SourceLocation dirLoc) {
315+
return OpenACCClauseCIREmitter<OpTy>(op, cgf, builder, dirKind, dirLoc);
316+
}
317+
318+
} // namespace clang

0 commit comments

Comments
 (0)