Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 86a03ed

Browse files
committed
extract may/must writes from Halide IR
Previous commits introduced may/must writes in Scop and dependence analysis. Extract those from Halide IR. Change extractAccess to return a flag indicating whether the affine access relation is constructed is exact or not. Exact relations correspond to must writes since we statically know which tensor elements are written. Inexact relations overapproximate non-affine accesses and should be treated as may writes, assuming the tensor elements are not necessarily written.
1 parent 3581849 commit 86a03ed

File tree

3 files changed

+68
-23
lines changed

3 files changed

+68
-23
lines changed

include/tc/core/halide2isl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct ScheduleTreeAndAccesses {
6666
/// Union maps describing the reads and writes done. Uses the ids in
6767
/// the schedule tree to denote the containing Stmt, and tags each
6868
/// access with a unique reference id of the form __tc_ref_N.
69-
isl::union_map reads, writes;
69+
isl::union_map reads, mayWrites, mustWrites;
7070

7171
/// The correspondence between from Call and Provide nodes and the
7272
/// reference ids in the reads and writes maps.

src/core/halide2isl.cc

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <algorithm>
1919
#include <numeric>
20+
#include <tuple>
2021
#include <unordered_set>
2122

2223
#include "tc/core/constants.h"
@@ -228,7 +229,20 @@ isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
228229
return space;
229230
}
230231

231-
isl::map extractAccess(
232+
// Extract a tagged affine access relation from Halide IR.
233+
// The relation is tagged with a unique identifier, i.e. it lives in the space
234+
// [D[...] -> __tc_ref_#[]] -> A[]
235+
// where # is a unique sequential number, D is the statement identifier
236+
// extracted from "domain" and A is the tensor identifier constructed from
237+
// "tensor". "accesses" map is updated to keep track of the Halide IR nodes in
238+
// which a particular reference # appeared.
239+
// Returns the access relation and a flag indicating whether this relation is
240+
// exact or not. The relation is overapproximated (that is, not exact) if it
241+
// represents a non-affine access, for example, an access with indirection such
242+
// as O(Index(i)) = 42. In such overapproximated access relation, dimensions
243+
// that correspond to affine subscripts are still exact while those that
244+
// correspond to non-affine subscripts are not constrained.
245+
std::pair<isl::map, bool> extractAccess(
232246
isl::set domain,
233247
const IRNode* op,
234248
const std::string& tensor,
@@ -257,6 +271,7 @@ isl::map extractAccess(
257271
isl::map map =
258272
isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace));
259273

274+
bool exact = true;
260275
for (size_t i = 0; i < args.size(); i++) {
261276
// Then add one equality constraint per dimension to encode the
262277
// point in the allocation actually read/written for each point in
@@ -268,47 +283,64 @@ isl::map extractAccess(
268283
isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i);
269284
// ... equals the coordinate accessed as a function of the domain.
270285
auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]);
271-
if (!domainPoint.is_null()) {
286+
if (!domainPoint) {
287+
exact = false;
288+
} else {
272289
map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint));
273290
}
274291
}
275292

276-
return map;
293+
return std::make_pair(map, exact);
277294
}
278295

279-
std::pair<isl::union_map, isl::union_map>
296+
std::tuple<isl::union_map, isl::union_map, isl::union_map>
280297
extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
281298
class FindAccesses : public IRGraphVisitor {
282299
using IRGraphVisitor::visit;
283300

284301
void visit(const Call* op) override {
285302
IRGraphVisitor::visit(op);
286303
if (op->call_type == Call::Halide || op->call_type == Call::Image) {
287-
reads = reads.unite(
288-
extractAccess(domain, op, op->name, op->args, accesses));
304+
// Read relations can be safely overapproximated.
305+
isl::map read;
306+
std::tie(read, std::ignore) =
307+
extractAccess(domain, op, op->name, op->args, accesses);
308+
reads = reads.unite(read);
289309
}
290310
}
291311

292312
void visit(const Provide* op) override {
293313
IRGraphVisitor::visit(op);
294-
writes =
295-
writes.unite(extractAccess(domain, op, op->name, op->args, accesses));
314+
315+
// If the write access relation is not exact, we consider that any
316+
// element _may_ be written by the statement. If it is exact, then we
317+
// can guarantee that all the elements specified by the relation _must_
318+
// be written and any previously stored value will be killed.
319+
isl::map write;
320+
bool exact;
321+
std::tie(write, exact) =
322+
extractAccess(domain, op, op->name, op->args, accesses);
323+
if (exact) {
324+
mustWrites = mustWrites.unite(write);
325+
}
326+
mayWrites = mayWrites.unite(write);
296327
}
297328

298329
const isl::set& domain;
299330
AccessMap* accesses;
300331

301332
public:
302-
isl::union_map reads, writes;
333+
isl::union_map reads, mayWrites, mustWrites;
303334

304335
FindAccesses(const isl::set& domain, AccessMap* accesses)
305336
: domain(domain),
306337
accesses(accesses),
307338
reads(isl::union_map::empty(domain.get_space())),
308-
writes(isl::union_map::empty(domain.get_space())) {}
339+
mayWrites(isl::union_map::empty(domain.get_space())),
340+
mustWrites(isl::union_map::empty(domain.get_space())) {}
309341
} finder(domain, accesses);
310342
s.accept(&finder);
311-
return {finder.reads, finder.writes};
343+
return std::make_tuple(finder.reads, finder.mayWrites, finder.mustWrites);
312344
}
313345

314346
/*
@@ -333,7 +365,8 @@ isl::schedule makeScheduleTreeHelper(
333365
isl::set set,
334366
std::vector<std::string>& outer,
335367
isl::union_map* reads,
336-
isl::union_map* writes,
368+
isl::union_map* mayWrites,
369+
isl::union_map* mustWrites,
337370
AccessMap* accesses,
338371
StatementMap* statements,
339372
IteratorMap* iterators) {
@@ -379,7 +412,8 @@ isl::schedule makeScheduleTreeHelper(
379412
set,
380413
outerNext,
381414
reads,
382-
writes,
415+
mayWrites,
416+
mustWrites,
383417
accesses,
384418
statements,
385419
iterators);
@@ -412,7 +446,15 @@ isl::schedule makeScheduleTreeHelper(
412446
std::vector<isl::schedule> schedules;
413447
for (Stmt s : stmts) {
414448
schedules.push_back(makeScheduleTreeHelper(
415-
s, set, outer, reads, writes, accesses, statements, iterators));
449+
s,
450+
set,
451+
outer,
452+
reads,
453+
mayWrites,
454+
mustWrites,
455+
accesses,
456+
statements,
457+
iterators));
416458
}
417459
schedule = schedules[0].sequence(schedules[1]);
418460

@@ -427,23 +469,25 @@ isl::schedule makeScheduleTreeHelper(
427469
isl::set domain = set.set_tuple_id(id);
428470
schedule = isl::schedule::from_domain(domain);
429471

430-
isl::union_map newReads, newWrites;
431-
std::tie(newReads, newWrites) =
472+
isl::union_map newReads, newMayWrites, newMustWrites;
473+
std::tie(newReads, newMayWrites, newMustWrites) =
432474
halide2isl::extractAccesses(domain, op, accesses);
433475

434476
*reads = reads->unite(newReads);
435-
*writes = writes->unite(newWrites);
477+
*mayWrites = mayWrites->unite(newMayWrites);
478+
*mustWrites = mustWrites->unite(newMustWrites);
436479

437480
} else {
438481
LOG(FATAL) << "Unhandled Halide stmt: " << s;
439482
}
440483
return schedule;
441-
};
484+
}
442485

443486
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
444487
ScheduleTreeAndAccesses result;
445488

446-
result.writes = result.reads = isl::union_map::empty(paramSpace);
489+
result.mayWrites = result.mustWrites = result.reads =
490+
isl::union_map::empty(paramSpace);
447491

448492
// Walk the IR building a schedule tree
449493
std::vector<std::string> outer;
@@ -452,7 +496,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
452496
isl::set::universe(paramSpace),
453497
outer,
454498
&result.reads,
455-
&result.writes,
499+
&result.mayWrites,
500+
&result.mustWrites,
456501
&result.accesses,
457502
&result.statements,
458503
&result.iterators);

src/core/polyhedral/scop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ ScopUPtr Scop::makeScop(
6666
auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt);
6767
scop->scheduleTreeUPtr = std::move(tree.tree);
6868
scop->reads = tree.reads;
69-
scop->mayWrites = tree.writes;
70-
scop->mustWrites = isl::union_map::empty(scop->mayWrites.get_space());
69+
scop->mayWrites = tree.mayWrites;
70+
scop->mustWrites = tree.mustWrites;
7171
scop->halide.statements = std::move(tree.statements);
7272
scop->halide.accesses = std::move(tree.accesses);
7373
scop->halide.reductions = halide2isl::findReductions(components.stmt);

0 commit comments

Comments
 (0)