17
17
18
18
#include < algorithm>
19
19
#include < numeric>
20
+ #include < tuple>
20
21
#include < unordered_set>
21
22
22
23
#include " tc/core/constants.h"
@@ -228,7 +229,20 @@ isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
228
229
return space;
229
230
}
230
231
231
- isl::map extractAccess (
232
+ // Extract a tagged affine access relation from Halide IR.
233
+ // The relation is tagged with a unique identifier, i.e. it lives in the space
234
+ // [D[...] -> __tc_ref_#[]] -> A[]
235
+ // where # is a unique sequential number, D is the statement identifier
236
+ // extracted from "domain" and A is the tensor identifier constructed from
237
+ // "tensor". The "accesses" map is updated to keep track of the Halide IR nodes
238
+ // in which a particular reference # appeared.
239
+ // Returns the access relation and a flag indicating whether this relation is
240
+ // exact or not. The relation is overapproximated (that is, not exact) if it
241
+ // represents a non-affine access, for example, an access with indirection such
242
+ // as O(Index(i)) = 42. In such an overapproximated access relation, dimensions
243
+ // that correspond to affine subscripts are still exact while those that
244
+ // correspond to non-affine subscripts are not constrained.
245
+ std::pair<isl::map, bool > extractAccess (
232
246
isl::set domain,
233
247
const IRNode* op,
234
248
const std::string& tensor,
@@ -257,6 +271,7 @@ isl::map extractAccess(
257
271
isl::map map =
258
272
isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
259
273
274
+ bool exact = true ;
260
275
for (size_t i = 0 ; i < args.size (); i++) {
261
276
// Then add one equality constraint per dimension to encode the
262
277
// point in the allocation actually read/written for each point in
@@ -268,47 +283,64 @@ isl::map extractAccess(
268
283
isl::pw_aff (isl::local_space (rangeSpace), isl::dim_type::set, i);
269
284
// ... equals the coordinate accessed as a function of the domain.
270
285
auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace, args[i]);
271
- if (!domainPoint.is_null ()) {
286
+ if (!domainPoint) {
287
+ exact = false ;
288
+ } else {
272
289
map = map.intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
273
290
}
274
291
}
275
292
276
- return map;
293
+ return std::make_pair ( map, exact) ;
277
294
}
278
295
279
- std::pair< isl::union_map, isl::union_map>
296
+ std::tuple<isl::union_map, isl::union_map, isl::union_map>
280
297
extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
281
298
class FindAccesses : public IRGraphVisitor {
282
299
using IRGraphVisitor::visit;
283
300
284
301
void visit (const Call* op) override {
285
302
IRGraphVisitor::visit (op);
286
303
if (op->call_type == Call::Halide || op->call_type == Call::Image) {
287
- reads = reads.unite (
288
- extractAccess (domain, op, op->name , op->args , accesses));
304
+ // Read relations can be safely overapproximated.
305
+ isl::map read ;
306
+ std::tie (read , std::ignore) =
307
+ extractAccess (domain, op, op->name , op->args , accesses);
308
+ reads = reads.unite (read );
289
309
}
290
310
}
291
311
292
312
void visit (const Provide* op) override {
293
313
IRGraphVisitor::visit (op);
294
- writes =
295
- writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
314
+
315
+ // If the write access relation is not exact, we consider that any
316
+ // element _may_ be written by the statement. If it is exact, then we
317
+ // can guarantee that all the elements specified by the relation _must_
318
+ // be written and any previously stored value will be killed.
319
+ isl::map write ;
320
+ bool exact;
321
+ std::tie (write , exact) =
322
+ extractAccess (domain, op, op->name , op->args , accesses);
323
+ if (exact) {
324
+ mustWrites = mustWrites.unite (write );
325
+ }
326
+ mayWrites = mayWrites.unite (write );
296
327
}
297
328
298
329
const isl::set& domain;
299
330
AccessMap* accesses;
300
331
301
332
public:
302
- isl::union_map reads, writes ;
333
+ isl::union_map reads, mayWrites, mustWrites ;
303
334
304
335
FindAccesses (const isl::set& domain, AccessMap* accesses)
305
336
: domain(domain),
306
337
accesses (accesses),
307
338
reads(isl::union_map::empty(domain.get_space())),
308
- writes(isl::union_map::empty(domain.get_space())) {}
339
+ mayWrites(isl::union_map::empty(domain.get_space())),
340
+ mustWrites(isl::union_map::empty(domain.get_space())) {}
309
341
} finder(domain, accesses);
310
342
s.accept(&finder);
311
- return { finder.reads , finder.writes } ;
343
+ return std::make_tuple( finder.reads, finder.mayWrites, finder.mustWrites) ;
312
344
}
313
345
314
346
/*
@@ -333,7 +365,8 @@ isl::schedule makeScheduleTreeHelper(
333
365
isl::set set,
334
366
std::vector<std::string>& outer,
335
367
isl::union_map* reads,
336
- isl::union_map* writes,
368
+ isl::union_map* mayWrites,
369
+ isl::union_map* mustWrites,
337
370
AccessMap* accesses,
338
371
StatementMap* statements,
339
372
IteratorMap* iterators) {
@@ -379,7 +412,8 @@ isl::schedule makeScheduleTreeHelper(
379
412
set,
380
413
outerNext,
381
414
reads,
382
- writes,
415
+ mayWrites,
416
+ mustWrites,
383
417
accesses,
384
418
statements,
385
419
iterators);
@@ -412,7 +446,15 @@ isl::schedule makeScheduleTreeHelper(
412
446
std::vector<isl::schedule> schedules;
413
447
for (Stmt s : stmts) {
414
448
schedules.push_back (makeScheduleTreeHelper (
415
- s, set, outer, reads, writes, accesses, statements, iterators));
449
+ s,
450
+ set,
451
+ outer,
452
+ reads,
453
+ mayWrites,
454
+ mustWrites,
455
+ accesses,
456
+ statements,
457
+ iterators));
416
458
}
417
459
schedule = schedules[0 ].sequence (schedules[1 ]);
418
460
@@ -427,23 +469,25 @@ isl::schedule makeScheduleTreeHelper(
427
469
isl::set domain = set.set_tuple_id (id);
428
470
schedule = isl::schedule::from_domain (domain);
429
471
430
- isl::union_map newReads, newWrites ;
431
- std::tie (newReads, newWrites ) =
472
+ isl::union_map newReads, newMayWrites, newMustWrites ;
473
+ std::tie (newReads, newMayWrites, newMustWrites ) =
432
474
halide2isl::extractAccesses (domain, op, accesses);
433
475
434
476
*reads = reads->unite (newReads);
435
- *writes = writes->unite (newWrites);
477
+ *mayWrites = mayWrites->unite (newMayWrites);
478
+ *mustWrites = mustWrites->unite (newMustWrites);
436
479
437
480
} else {
438
481
LOG (FATAL) << " Unhandled Halide stmt: " << s;
439
482
}
440
483
return schedule;
441
- };
484
+ }
442
485
443
486
ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
444
487
ScheduleTreeAndAccesses result;
445
488
446
- result.writes = result.reads = isl::union_map::empty (paramSpace);
489
+ result.mayWrites = result.mustWrites = result.reads =
490
+ isl::union_map::empty (paramSpace);
447
491
448
492
// Walk the IR building a schedule tree
449
493
std::vector<std::string> outer;
@@ -452,7 +496,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
452
496
isl::set::universe (paramSpace),
453
497
outer,
454
498
&result.reads ,
455
- &result.writes ,
499
+ &result.mayWrites ,
500
+ &result.mustWrites ,
456
501
&result.accesses ,
457
502
&result.statements ,
458
503
&result.iterators );
0 commit comments