@@ -72,8 +72,9 @@ isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
72
72
73
73
namespace polyhedral {
74
74
75
+ using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
75
76
using IteratorMapsType =
76
- std::unordered_map<isl::id, isl::pw_multi_aff , isl::IslIdIslHash>;
77
+ std::unordered_map<isl::id, IteratorMapType , isl::IslIdIslHash>;
77
78
78
79
using IteratorLLVMValueMapType =
79
80
std::unordered_map<isl::id, llvm::Value*, isl::IslIdIslHash>;
@@ -96,14 +97,6 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
96
97
return llvm::ConstantInt::get (llvm::Type::getInt64Ty (llvmCtx), v, true );
97
98
}
98
99
99
- isl::aff extractAff (isl::pw_multi_aff pma) {
100
- isl::PMA pma_ (pma);
101
- CHECK_EQ (pma_.size (), 1 );
102
- isl::MA ma (pma_[0 ].second );
103
- CHECK_EQ (ma.size (), 1 );
104
- return ma[0 ];
105
- }
106
-
107
100
int64_t IslExprToSInt (isl::ast_expr e) {
108
101
CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
109
102
assert (sizeof (long ) <= 8 ); // long is assumed to fit to 64bits
@@ -214,7 +207,7 @@ static constexpr int kOptLevel = 3;
214
207
215
208
class CodeGen_TC : public Halide ::Internal::CodeGen_X86 {
216
209
public:
217
- const isl::pw_multi_aff * iteratorMap_;
210
+ const IteratorMapType * iteratorMap_;
218
211
CodeGen_TC (Target t) : CodeGen_X86(t) {}
219
212
220
213
using CodeGen_X86::codegen;
@@ -249,6 +242,11 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
249
242
return std::move (module);
250
243
}
251
244
245
+ // Convert an isl AST expression into an llvm::Value.
246
+ // Only expressions that consist of a pure identifier or
247
+ // a pure integer constant are currently supported.
248
+ llvm::Value* getValue (isl::ast_expr expr);
249
+
252
250
protected:
253
251
using CodeGen_X86::visit;
254
252
void visit (const Halide::Internal::Call* call) override {
@@ -272,44 +270,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
272
270
}
273
271
}
274
272
void visit (const Halide::Internal::Variable* op) override {
275
- auto aff = halide2isl::makeIslAffFromExpr (
276
- iteratorMap_->get_space ().range (), Halide::Expr (op));
277
-
278
- auto subscriptPma = isl::pw_aff (aff).pullback (*iteratorMap_);
279
- auto subscriptAff = extractAff (subscriptPma);
280
-
281
- // sanity checks
282
- CHECK_EQ (subscriptAff.dim (isl::dim_type::div ), 0 );
283
- CHECK_EQ (subscriptAff.dim (isl::dim_type::out), 1 );
284
- for (int d = 0 ; d < subscriptAff.dim (isl::dim_type::param); ++d) {
285
- auto v = subscriptAff.get_coefficient_val (isl::dim_type::param, d);
286
- CHECK (v.is_zero ());
287
- }
288
-
289
- llvm::Optional<int > posOne;
290
- int sum = 0 ;
291
- for (int d = 0 ; d < subscriptAff.dim (isl::dim_type::in); ++d) {
292
- auto v = subscriptAff.get_coefficient_val (isl::dim_type::in, d);
293
- CHECK (v.is_zero () or v.is_one ());
294
- if (v.is_zero ()) {
295
- continue ;
296
- }
297
- ++sum;
298
- posOne = d;
299
- }
300
- CHECK_LE (sum, 1 );
301
-
302
- if (sum == 0 ) {
303
- value =
304
- getLLVMConstantSignedInt64 (toSInt (subscriptAff.get_constant_val ()));
305
- return ;
306
- }
307
- CHECK (posOne);
308
-
309
- std::string name (
310
- isl_aff_get_dim_name (subscriptAff.get (), isl_dim_in, *posOne));
311
-
312
- value = sym_get (name);
273
+ value = getValue (iteratorMap_->at (op->name ));
313
274
}
314
275
315
276
public:
@@ -361,6 +322,21 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
361
322
}
362
323
};
363
324
325
+ llvm::Value* CodeGen_TC::getValue (isl::ast_expr expr) {
326
+ switch (isl_ast_expr_get_type (expr.get ())) {
327
+ case isl_ast_expr_type::isl_ast_expr_id:
328
+ return sym_get (expr.get_id ().get_name ());
329
+ case isl_ast_expr_type::isl_ast_expr_int: {
330
+ auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
331
+ CHECK (val.is_int ());
332
+ return getLLVMConstantSignedInt64 (val.get_num_si ());
333
+ }
334
+ default :
335
+ LOG (FATAL) << " NYI" ;
336
+ return nullptr ;
337
+ }
338
+ }
339
+
364
340
class LLVMCodegen {
365
341
void collectTensor (const Halide::OutputImageParam& t) {
366
342
auto sizes =
@@ -638,22 +614,7 @@ class LLVMCodegen {
638
614
llvm::SmallVector<llvm::Value*, 5 > subscriptValues;
639
615
640
616
for (const auto & subscript : subscripts) {
641
- switch (isl_ast_expr_get_type (subscript.get ())) {
642
- case isl_ast_expr_type::isl_ast_expr_id: {
643
- subscriptValues.push_back (
644
- halide_cg.sym_get (subscript.get_id ().get_name ()));
645
- break ;
646
- }
647
- case isl_ast_expr_type::isl_ast_expr_int: {
648
- auto val = isl::manage (isl_ast_expr_get_val (subscript.get ()));
649
- CHECK_EQ (val.get_den_si (), 1 );
650
- subscriptValues.push_back (
651
- getLLVMConstantSignedInt64 (val.get_num_si ()));
652
- break ;
653
- }
654
- default :
655
- LOG (FATAL) << " NYI" ;
656
- }
617
+ subscriptValues.push_back (halide_cg.getValue (subscript));
657
618
}
658
619
659
620
auto destAddr = halide_cg.get_builder ().CreateInBoundsGEP (
@@ -703,34 +664,28 @@ IslCodegenRes codegenISL(const Scop& scop) {
703
664
const Scop& scop,
704
665
StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
705
666
auto expr = node.user_get_expr ();
706
- // We rename loop-related dimensions manually.
707
667
auto schedule = build.get_schedule ();
708
- auto scheduleSpace = build.get_schedule_space ();
709
668
auto scheduleMap = isl::map::from_union_map (schedule);
710
669
711
670
auto stmtId = expr.get_op_arg (0 ).get_id ();
712
671
// auto nodeId = isl::id(
713
672
// node.get_ctx(),
714
673
// std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
715
674
CHECK_EQ (0 , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
716
- CHECK_EQ (
717
- scheduleMap.dim (isl::dim_type::out),
718
- scheduleSpace.dim (isl::dim_type::set));
719
- for (int i = 0 ; i < scheduleSpace.dim (isl::dim_type::set); ++i) {
720
- scheduleMap = scheduleMap.set_dim_id (
721
- isl::dim_type::out,
722
- i,
723
- scheduleSpace.get_dim_id (isl::dim_type::set, i));
724
- }
725
675
auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
726
- iteratorMaps.emplace (stmtId, iteratorMap);
676
+ auto iterators = scop.halide .iterators .at (stmtId);
677
+ auto & stmtIteratorMap = iteratorMaps[stmtId];
678
+ for (int i = 0 ; i < iterators.size (); ++i) {
679
+ auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
680
+ stmtIteratorMap.emplace (iterators[i], expr);
681
+ }
727
682
auto & subscripts = stmtSubscripts[stmtId];
728
683
auto provide =
729
684
scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
730
685
for (auto e : provide->args ) {
731
686
const auto & map = iteratorMap;
732
- auto space = map.get_space ().range ();
733
- auto aff = halide2isl::makeIslAffFromExpr ( space, e);
687
+ auto space = map.get_space ().params ();
688
+ auto aff = scop. makeIslAffFromStmtExpr (stmtId, space, e);
734
689
auto pulled = isl::pw_aff (aff).pullback (map);
735
690
CHECK_EQ (pulled.n_piece (), 1 );
736
691
subscripts.push_back (build.expr_from (pulled));
0 commit comments