Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 17b4322

Browse files
Merge pull request #184 from facebookresearch/pr/codegen_llvm_names
emitLLVMKernel: avoid relying on isl set variable names
2 parents 5cddab8 + 5bc3e2c commit 17b4322

File tree

5 files changed

+109
-84
lines changed

5 files changed

+109
-84
lines changed

include/tc/core/halide2isl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t i);
5151
// does not correspond to a parameter or set dimension of the space.
5252
isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
5353

54+
typedef std::unordered_map<isl::id, std::vector<std::string>, isl::IslIdIslHash>
55+
IteratorMap;
5456
typedef std::unordered_map<isl::id, Halide::Internal::Stmt, isl::IslIdIslHash>
5557
StatementMap;
5658
typedef std::unordered_map<const Halide::Internal::IRNode*, isl::id> AccessMap;
@@ -73,6 +75,10 @@ struct ScheduleTreeAndAccesses {
7375
/// The correspondence between leaf Stmts and the statement ids
7476
/// refered to above.
7577
StatementMap statements;
78+
79+
/// The correspondence between statement ids and the outer loop iterators
80+
/// of the corresponding leaf Stmt.
81+
IteratorMap iterators;
7682
};
7783

7884
/// Make a schedule tree from a Halide Stmt, along with auxiliary data

include/tc/core/polyhedral/scop.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,17 @@ struct Scop {
358358
// Assumes such argument exists.
359359
const Halide::OutputImageParam& findArgument(isl::id id) const;
360360

361+
// Make an affine function from a Halide Expr that is defined
362+
// over the instance set of the statement with identifier "stmtId" and
363+
// with parameters specified by "paramSpace". Return a
364+
// null isl::aff if the expression is not affine. Fail if any
365+
// of the variables does not correspond to a parameter or
366+
// an instance identifier of the statement.
367+
isl::aff makeIslAffFromStmtExpr(
368+
isl::id stmtId,
369+
isl::space paramSpace,
370+
const Halide::Expr& e) const;
371+
361372
// Promote a tensor reference group to a storage of a given "kind",
362373
// inserting the copy
363374
// statements below the given node. Inserts an Extension node below the give
@@ -419,9 +430,10 @@ struct Scop {
419430
std::unordered_map<isl::id, Halide::Internal::Stmt, isl::IslIdIslHash>
420431
statements;
421432
std::unordered_map<const Halide::Internal::IRNode*, isl::id> accesses;
433+
halide2isl::IteratorMap iterators;
422434
} halide;
423435

424-
// Poyhedral IR
436+
// Polyhedral IR
425437
//
426438
// The domain is collected from the root of the ScheduleTree; no redundant
427439
// state is kept.

src/core/halide2isl.cc

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,33 @@ struct ScheduleTreeAndDomain {
316316
isl::union_set domain;
317317
};
318318

319+
/*
320+
* Helper function for extracting a schedule tree from a Halide Stmt,
321+
* recursively descending over the Stmt.
322+
* "s" is the current position in the recursive descent.
323+
* "set" describes the bounds on the outer loop iterators.
324+
* "outer" contains the names of the outer loop iterators
325+
* from outermost to innermost.
326+
* Return the schedule tree corresponding to the subtree at "s",
327+
* along with a separated out domain.
328+
*
329+
* "reads" and "writes" collect the accesses found along the way.
330+
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
331+
* (for the writes) to the corresponding tag in the access relations.
332+
* "statements" collects the mapping from instance set tuple identifiers
333+
* to the corresponding Provide node.
334+
* "iterators" collects the mapping from instance set tuple identifiers
335+
* to the corresponding outer loop iterator names, from outermost to innermost.
336+
*/
319337
ScheduleTreeAndDomain makeScheduleTreeHelper(
320338
const Stmt& s,
321339
isl::set set,
340+
std::vector<std::string>& outer,
322341
isl::union_map* reads,
323342
isl::union_map* writes,
324343
AccessMap* accesses,
325-
StatementMap* statements) {
344+
StatementMap* statements,
345+
IteratorMap* iterators) {
326346
ScheduleTreeAndDomain result;
327347
if (auto op = s.as<For>()) {
328348
// Add one additional dimension to our set of loop variables
@@ -358,8 +378,17 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
358378
}
359379

360380
// Recursively descend.
381+
auto outerNext = outer;
382+
outerNext.push_back(op->name);
361383
auto body = makeScheduleTreeHelper(
362-
op->body, set, reads, writes, accesses, statements);
384+
op->body,
385+
set,
386+
outerNext,
387+
reads,
388+
writes,
389+
accesses,
390+
statements,
391+
iterators);
363392

364393
// Create an affine function that defines an ordering for all
365394
// the statements in the body of this loop over the values of
@@ -405,8 +434,8 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
405434
// children.
406435
std::vector<ScheduleTreeUPtr> trees;
407436
for (Stmt s : stmts) {
408-
auto mem =
409-
makeScheduleTreeHelper(s, set, reads, writes, accesses, statements);
437+
auto mem = makeScheduleTreeHelper(
438+
s, set, outer, reads, writes, accesses, statements, iterators);
410439
ScheduleTreeUPtr filter;
411440
if (mem.tree) {
412441
// No statement instances are shared between the blocks, so we
@@ -438,6 +467,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
438467
size_t stmtIndex = statements->size();
439468
isl::id id(set.get_ctx(), kStatementLabel + std::to_string(stmtIndex));
440469
statements->emplace(id, op);
470+
iterators->emplace(id, outer);
441471
isl::set domain = set.set_tuple_id(id);
442472
result.domain = domain;
443473

@@ -460,13 +490,16 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
460490
result.writes = result.reads = isl::union_map::empty(paramSpace);
461491

462492
// Walk the IR building a schedule tree
493+
std::vector<std::string> outer;
463494
auto treeAndDomain = makeScheduleTreeHelper(
464495
s,
465496
isl::set::universe(paramSpace),
497+
outer,
466498
&result.reads,
467499
&result.writes,
468500
&result.accesses,
469-
&result.statements);
501+
&result.statements,
502+
&result.iterators);
470503

471504
// TODO: This fails if the stmt is just a Provide node, I'm not sure
472505
// what the schedule tree should look like in that case.

src/core/polyhedral/codegen_llvm.cc

Lines changed: 33 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
7272

7373
namespace polyhedral {
7474

75+
using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
7576
using IteratorMapsType =
76-
std::unordered_map<isl::id, isl::pw_multi_aff, isl::IslIdIslHash>;
77+
std::unordered_map<isl::id, IteratorMapType, isl::IslIdIslHash>;
7778

7879
using IteratorLLVMValueMapType =
7980
std::unordered_map<isl::id, llvm::Value*, isl::IslIdIslHash>;
@@ -96,14 +97,6 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
9697
return llvm::ConstantInt::get(llvm::Type::getInt64Ty(llvmCtx), v, true);
9798
}
9899

99-
isl::aff extractAff(isl::pw_multi_aff pma) {
100-
isl::PMA pma_(pma);
101-
CHECK_EQ(pma_.size(), 1);
102-
isl::MA ma(pma_[0].second);
103-
CHECK_EQ(ma.size(), 1);
104-
return ma[0];
105-
}
106-
107100
int64_t IslExprToSInt(isl::ast_expr e) {
108101
CHECK(isl_ast_expr_get_type(e.get()) == isl_ast_expr_type::isl_ast_expr_int);
109102
assert(sizeof(long) <= 8); // long is assumed to fit to 64bits
@@ -214,7 +207,7 @@ static constexpr int kOptLevel = 3;
214207

215208
class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
216209
public:
217-
const isl::pw_multi_aff* iteratorMap_;
210+
const IteratorMapType* iteratorMap_;
218211
CodeGen_TC(Target t) : CodeGen_X86(t) {}
219212

220213
using CodeGen_X86::codegen;
@@ -249,6 +242,11 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
249242
return std::move(module);
250243
}
251244

245+
// Convert an isl AST expression into an llvm::Value.
246+
// Only expressions that consist of a pure identifier or
247+
// a pure integer constant are currently supported.
248+
llvm::Value* getValue(isl::ast_expr expr);
249+
252250
protected:
253251
using CodeGen_X86::visit;
254252
void visit(const Halide::Internal::Call* call) override {
@@ -272,44 +270,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
272270
}
273271
}
274272
void visit(const Halide::Internal::Variable* op) override {
275-
auto aff = halide2isl::makeIslAffFromExpr(
276-
iteratorMap_->get_space().range(), Halide::Expr(op));
277-
278-
auto subscriptPma = isl::pw_aff(aff).pullback(*iteratorMap_);
279-
auto subscriptAff = extractAff(subscriptPma);
280-
281-
// sanity checks
282-
CHECK_EQ(subscriptAff.dim(isl::dim_type::div), 0);
283-
CHECK_EQ(subscriptAff.dim(isl::dim_type::out), 1);
284-
for (int d = 0; d < subscriptAff.dim(isl::dim_type::param); ++d) {
285-
auto v = subscriptAff.get_coefficient_val(isl::dim_type::param, d);
286-
CHECK(v.is_zero());
287-
}
288-
289-
llvm::Optional<int> posOne;
290-
int sum = 0;
291-
for (int d = 0; d < subscriptAff.dim(isl::dim_type::in); ++d) {
292-
auto v = subscriptAff.get_coefficient_val(isl::dim_type::in, d);
293-
CHECK(v.is_zero() or v.is_one());
294-
if (v.is_zero()) {
295-
continue;
296-
}
297-
++sum;
298-
posOne = d;
299-
}
300-
CHECK_LE(sum, 1);
301-
302-
if (sum == 0) {
303-
value =
304-
getLLVMConstantSignedInt64(toSInt(subscriptAff.get_constant_val()));
305-
return;
306-
}
307-
CHECK(posOne);
308-
309-
std::string name(
310-
isl_aff_get_dim_name(subscriptAff.get(), isl_dim_in, *posOne));
311-
312-
value = sym_get(name);
273+
value = getValue(iteratorMap_->at(op->name));
313274
}
314275

315276
public:
@@ -361,6 +322,21 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
361322
}
362323
};
363324

325+
llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
326+
switch (isl_ast_expr_get_type(expr.get())) {
327+
case isl_ast_expr_type::isl_ast_expr_id:
328+
return sym_get(expr.get_id().get_name());
329+
case isl_ast_expr_type::isl_ast_expr_int: {
330+
auto val = isl::manage(isl_ast_expr_get_val(expr.get()));
331+
CHECK(val.is_int());
332+
return getLLVMConstantSignedInt64(val.get_num_si());
333+
}
334+
default:
335+
LOG(FATAL) << "NYI";
336+
return nullptr;
337+
}
338+
}
339+
364340
class LLVMCodegen {
365341
void collectTensor(const Halide::OutputImageParam& t) {
366342
auto sizes =
@@ -638,22 +614,7 @@ class LLVMCodegen {
638614
llvm::SmallVector<llvm::Value*, 5> subscriptValues;
639615

640616
for (const auto& subscript : subscripts) {
641-
switch (isl_ast_expr_get_type(subscript.get())) {
642-
case isl_ast_expr_type::isl_ast_expr_id: {
643-
subscriptValues.push_back(
644-
halide_cg.sym_get(subscript.get_id().get_name()));
645-
break;
646-
}
647-
case isl_ast_expr_type::isl_ast_expr_int: {
648-
auto val = isl::manage(isl_ast_expr_get_val(subscript.get()));
649-
CHECK_EQ(val.get_den_si(), 1);
650-
subscriptValues.push_back(
651-
getLLVMConstantSignedInt64(val.get_num_si()));
652-
break;
653-
}
654-
default:
655-
LOG(FATAL) << "NYI";
656-
}
617+
subscriptValues.push_back(halide_cg.getValue(subscript));
657618
}
658619

659620
auto destAddr = halide_cg.get_builder().CreateInBoundsGEP(
@@ -703,34 +664,28 @@ IslCodegenRes codegenISL(const Scop& scop) {
703664
const Scop& scop,
704665
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
705666
auto expr = node.user_get_expr();
706-
// We rename loop-related dimensions manually.
707667
auto schedule = build.get_schedule();
708-
auto scheduleSpace = build.get_schedule_space();
709668
auto scheduleMap = isl::map::from_union_map(schedule);
710669

711670
auto stmtId = expr.get_op_arg(0).get_id();
712671
// auto nodeId = isl::id(
713672
// node.get_ctx(),
714673
// std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
715674
CHECK_EQ(0, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
716-
CHECK_EQ(
717-
scheduleMap.dim(isl::dim_type::out),
718-
scheduleSpace.dim(isl::dim_type::set));
719-
for (int i = 0; i < scheduleSpace.dim(isl::dim_type::set); ++i) {
720-
scheduleMap = scheduleMap.set_dim_id(
721-
isl::dim_type::out,
722-
i,
723-
scheduleSpace.get_dim_id(isl::dim_type::set, i));
724-
}
725675
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
726-
iteratorMaps.emplace(stmtId, iteratorMap);
676+
auto iterators = scop.halide.iterators.at(stmtId);
677+
auto& stmtIteratorMap = iteratorMaps[stmtId];
678+
for (int i = 0; i < iterators.size(); ++i) {
679+
auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
680+
stmtIteratorMap.emplace(iterators[i], expr);
681+
}
727682
auto& subscripts = stmtSubscripts[stmtId];
728683
auto provide =
729684
scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
730685
for (auto e : provide->args) {
731686
const auto& map = iteratorMap;
732-
auto space = map.get_space().range();
733-
auto aff = halide2isl::makeIslAffFromExpr(space, e);
687+
auto space = map.get_space().params();
688+
auto aff = scop.makeIslAffFromStmtExpr(stmtId, space, e);
734689
auto pulled = isl::pw_aff(aff).pullback(map);
735690
CHECK_EQ(pulled.n_piece(), 1);
736691
subscripts.push_back(build.expr_from(pulled));

src/core/polyhedral/scop.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ ScopUPtr Scop::makeScop(
7070
scop->halide.statements = std::move(tree.statements);
7171
scop->halide.accesses = std::move(tree.accesses);
7272
scop->halide.reductions = halide2isl::findReductions(components.stmt);
73+
scop->halide.iterators = std::move(tree.iterators);
7374

7475
// Set partial schedule tuples for proper comparison with ISL
7576
// schedules (needs DFSPreorder numbering). Just for testing.
@@ -557,5 +558,23 @@ const Halide::OutputImageParam& Scop::findArgument(isl::id id) const {
557558
return *halide.inputs.begin();
558559
}
559560

561+
isl::aff Scop::makeIslAffFromStmtExpr(
562+
isl::id stmtId,
563+
isl::space paramSpace,
564+
const Halide::Expr& e) const {
565+
auto ctx = stmtId.get_ctx();
566+
auto iterators = halide.iterators.at(stmtId);
567+
auto space = paramSpace.set_from_params();
568+
space = space.add_dims(isl::dim_type::set, iterators.size());
569+
// Set the names of the set dimensions of "space" for use
570+
// by halide2isl::makeIslAffFromExpr.
571+
for (int i = 0; i < iterators.size(); ++i) {
572+
isl::id id(ctx, iterators[i]);
573+
space = space.set_dim_id(isl::dim_type::set, i, id);
574+
}
575+
space = space.set_tuple_id(isl::dim_type::set, stmtId);
576+
return halide2isl::makeIslAffFromExpr(space, e);
577+
}
578+
560579
} // namespace polyhedral
561580
} // namespace tc

0 commit comments

Comments
 (0)