Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 31cb2a8

Browse files
committed
Allow computed expressions on the left-hand-side
1 parent d26f79a commit 31cb2a8

File tree

8 files changed

+170
-48
lines changed

8 files changed

+170
-48
lines changed

include/tc/core/libraries.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,14 @@ namespace c {
3232

3333
constexpr auto types = R"C(
3434
// Halide type handling
35-
typedef int int32;
36-
typedef long int64;
35+
typedef signed char int8;
36+
typedef unsigned char uint8;
37+
typedef signed short int16;
38+
typedef unsigned short uint16;
39+
typedef signed int int32;
40+
typedef unsigned int uint32;
41+
typedef signed long int64;
42+
typedef unsigned long uint64;
3743
typedef float float32;
3844
typedef double float64;
3945
)C";
@@ -81,16 +87,16 @@ float fmodf ( float x, float y );
8187
//float frexpf ( float x, int* nptr );
8288
float hypotf ( float x, float y );
8389
//int ilogbf ( float x );
84-
//__RETURN_TYPE isfinite ( float a );
85-
//__RETURN_TYPE isinf ( float a );
86-
//__RETURN_TYPE isnan ( float a );
90+
//__RETURN_TYPE isfinite ( float a );
91+
//__RETURN_TYPE isinf ( float a );
92+
//__RETURN_TYPE isnan ( float a );
8793
float j0f ( float x );
8894
float j1f ( float x );
8995
//float jnf ( int n, float x );
9096
//float ldexpf ( float x, int exp );
9197
float lgammaf ( float x );
92-
//long long int llrintf ( float x );
93-
//long long int llroundf ( float x );
98+
//long long int llrintf ( float x );
99+
//long long int llroundf ( float x );
94100
float log10f ( float x );
95101
float log1pf ( float x );
96102
float log2f ( float x );
@@ -120,7 +126,7 @@ float roundf ( float x );
120126
float rsqrtf ( float x );
121127
//float scalblnf ( float x, long int n );
122128
//float scalbnf ( float x, int n );
123-
//__RETURN_TYPE signbit ( float a );
129+
//__RETURN_TYPE signbit ( float a );
124130
//void sincosf ( float x, float* sptr, float* cptr );
125131
//void sincospif ( float x, float* sptr, float* cptr );
126132
float sinf ( float x );

include/tc/lang/parser.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,15 @@ struct Parser {
138138
TreeRef parseExpList() {
139139
return parseList('(', ',', ')', [&](int i) { return parseExp(); });
140140
}
141+
TreeRef parseOptionalExpList() {
142+
TreeRef list = nullptr;
143+
if (L.cur().kind == '(') {
144+
list = parseExpList();
145+
} else {
146+
list = List::create(L.cur().range, {});
147+
}
148+
return list;
149+
}
141150
TreeRef parseIdentList() {
142151
return parseList('(', ',', ')', [&](int i) { return parseIdent(); });
143152
}
@@ -213,7 +222,7 @@ struct Parser {
213222
}
214223
TreeRef parseStmt() {
215224
auto ident = parseIdent();
216-
TreeRef list = parseOptionalIdentList();
225+
TreeRef list = parseOptionalExpList();
217226
auto assign = parseAssignment();
218227
auto rhs = parseExp();
219228
TreeRef equivalent_statement = parseEquivalent();

include/tc/lang/sema.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,11 @@ struct Sema {
442442

443443
// register index variables (non-reductions)
444444
for (const auto& index : stmt.indices()) {
445-
std::string idx = index.name();
446-
auto typ = indexType(index);
447-
insert(index_env, index, typ, true);
445+
if (index->kind() == TK_IDENT) {
446+
std::string idx = Ident(index).name();
447+
auto typ = indexType(index);
448+
insert(index_env, Ident(index), typ, true);
449+
}
448450
}
449451

450452
// make dimension variables for each dimension of the output tensor
@@ -465,6 +467,9 @@ struct Sema {
465467
auto where_clauses_ = stmt.whereClauses().map(
466468
[&](TreeRef rc) { return checkWhereClause(rc); });
467469

470+
auto indices_ = stmt.indices().map(
471+
[&](TreeRef idx) { return checkExp(idx, true); });
472+
468473
TreeRef rhs_ = checkExp(stmt.rhs(), true);
469474
TreeRef scalar_type = typeOfExpr(rhs_);
470475

@@ -525,7 +530,7 @@ struct Sema {
525530
TreeRef result = Comprehension::create(
526531
stmt.range(),
527532
stmt.ident(),
528-
stmt.indices(),
533+
indices_,
529534
stmt.assignment(),
530535
rhs_,
531536
where_clauses_,

include/tc/lang/tree_views.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ struct Comprehension : public TreeView {
386386
Ident ident() const {
387387
return Ident(subtree(0));
388388
}
389-
ListView<Ident> indices() const {
390-
return ListView<Ident>(subtree(1));
389+
ListView<TreeRef> indices() const {
390+
return ListView<TreeRef>(subtree(1));
391391
}
392392
// kind == '=', TK_PLUS_EQ, TK_PLUS_EQ_B, etc.
393393
TreeRef assignment() const {

src/core/tc2halide.cc

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ Expr translateExpr(
216216
}
217217
}
218218

219-
vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
219+
vector<const Variable*> unboundVariables(const vector<Expr>& lhs, Expr rhs) {
220220
class FindUnboundVariables : public IRVisitor {
221221
using IRVisitor::visit;
222222

@@ -241,14 +241,19 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
241241
set<string> visited;
242242

243243
public:
244-
FindUnboundVariables(const vector<Var>& lhs) {
245-
for (auto v : lhs) {
246-
bound.push(v.name());
244+
FindUnboundVariables(const vector<Expr>& lhs) {
245+
for (auto e : lhs) {
246+
if (const Variable *v = e.as<Variable>()) {
247+
bound.push(v->name);
248+
}
247249
}
248250
}
249251
vector<const Variable*> result;
250252
} finder(lhs);
251253
rhs.accept(&finder);
254+
for (auto e : lhs) {
255+
e.accept(&finder);
256+
}
252257
return finder.result;
253258
}
254259

@@ -507,22 +512,31 @@ void translateComprehension(
507512
f = Function(c.ident().name());
508513
(*funcs)[c.ident().name()] = f;
509514
}
515+
516+
// we currently inline all of the let bindings generated in where clauses
517+
// in the future we may consider using Halide Let bindings when they
518+
// are supported later
519+
map<string, Expr> lets;
520+
510521
// Function is the internal Halide IR type for a pipeline
511522
// stage. Func is the front-end class that wraps it. Here it's
512523
// convenient to use both.
513524
Func func(f);
514525

515-
vector<Var> lhs;
516-
vector<Expr> lhs_as_exprs;
517-
for (lang::Ident id : c.indices()) {
518-
lhs.push_back(Var(id.name()));
519-
lhs_as_exprs.push_back(lhs.back());
526+
vector<Expr> lhs;
527+
vector<Var> lhs_vars;
528+
bool total_definition = true;
529+
for (lang::TreeRef idx : c.indices()) {
530+
Expr e = translateExpr(idx, params, *funcs, lets);
531+
if (const Variable *op = e.as<Variable>()) {
532+
lhs_vars.push_back(Var(op->name));
533+
} else {
534+
total_definition = false;
535+
lhs_vars.push_back(Var());
536+
}
537+
lhs.push_back(e);
520538
}
521539

522-
// we currently inline all of the let bindings generated in where clauses
523-
// in the future we may consider using Halide Let bindings when they
524-
// are supported later
525-
map<string, Expr> lets;
526540
for (auto wc : c.whereClauses()) {
527541
if (wc->kind() == lang::TK_LET) {
528542
auto let = lang::Let(wc);
@@ -546,9 +560,8 @@ void translateComprehension(
546560
auto setupIdentity = [&](const Expr& identity, bool zero) {
547561
if (!f.has_pure_definition()) {
548562
added_implicit_initialization = true;
549-
func(lhs) = (zero) ? identity
550-
: undef(rhs.type()); // undef causes the original value
551-
// to remain in input arrays
563+
// undef causes the original value to remain in input arrays
564+
func(lhs_vars) = (zero) ? identity : undef(rhs.type());
552565
}
553566
};
554567

@@ -587,6 +600,9 @@ void translateComprehension(
587600
break;
588601

589602
case '=':
603+
if (!total_definition) {
604+
setupIdentity(rhs, false);
605+
}
590606
break;
591607
default:
592608
throw lang::ErrorReport(c) << "Unimplemented reduction "
@@ -618,9 +634,10 @@ void translateComprehension(
618634
for (auto& exp : all_exprs) {
619635
exp = bindParams.mutate(exp);
620636
}
621-
622-
// TODO: When the LHS incorporates general expressions we'll need to
623-
// bind params there too.
637+
for (auto &e : lhs) {
638+
e = bindParams.mutate(e);
639+
all_exprs.push_back(e);
640+
}
624641

625642
// Do forward bounds inference -- construct an expression that says
626643
// this expression never reads out of bounds on its inputs, and
@@ -660,19 +677,34 @@ void translateComprehension(
660677
// (e.g. an in-place stencil)?. The .bound directive will use the
661678
// bounds of the last stage for all stages.
662679

663-
// Does a tensor have a single bound, or can its bounds shrink over
664-
// time? Solve for a single bound for now.
665-
666-
for (Var v : lhs) {
667-
if (!solution.contains(v.name())) {
668-
throw lang::ErrorReport(c)
680+
// Set the bounds to be the union of the boxes written to by every
681+
// comprehension touching the tensor.
682+
for (size_t i = 0; i < lhs.size(); i++) {
683+
Expr e = lhs[i];
684+
if (const Variable *v = e.as<Variable>()) {
685+
if (!solution.contains(v->name)) {
686+
throw lang::ErrorReport(c)
669687
<< "Free variable " << v
670688
<< " was not solved in range inference. May not be used right-hand side";
689+
}
690+
}
691+
692+
Interval in = bounds_of_expr_in_scope(e, solution);
693+
if (!in.is_bounded()) {
694+
throw lang::ErrorReport(c.indices()[i])
695+
<< "Left-hand side expression is unbounded";
696+
}
697+
in.min = cast<int>(in.min);
698+
in.max = cast<int>(in.max);
699+
700+
map<string, Interval> &b = (*bounds)[f];
701+
string dim_name = f.dimensions() ? f.args()[i] : lhs_vars[i].name();
702+
auto old = b.find(dim_name);
703+
if (old != b.end()) {
704+
// Take the union with any existing bounds
705+
in.include(old->second);
671706
}
672-
// TODO: We're enforcing a single bound across all comprehensions
673-
// for now. We should really check later ones are equal to earlier
674-
// ones instead of just clobbering.
675-
(*bounds)[f][v.name()] = solution.get(v.name());
707+
b[dim_name] = in;
676708
}
677709

678710
// Free variables that appear on the rhs but not the lhs are
@@ -703,6 +735,9 @@ void translateComprehension(
703735
for (auto v : unbound) {
704736
Expr rv = Variable::make(Int(32), v->name, domain);
705737
rhs = substitute(v->name, rv, rhs);
738+
for (Expr &e : lhs) {
739+
e = substitute(v->name, rv, e);
740+
}
706741
}
707742
rdom = RDom(domain);
708743
}
@@ -718,9 +753,12 @@ void translateComprehension(
718753
}
719754
}
720755
while (!lhs.empty()) {
721-
loop_nest.push_back(lhs.back());
756+
if (const Variable *v = lhs.back().as<Variable>()) {
757+
loop_nest.push_back(Var(v->name));
758+
}
722759
lhs.pop_back();
723760
}
761+
stage.reorder(loop_nest);
724762

725763
if (added_implicit_initialization) {
726764
// Also reorder reduction initializations to the TC convention
@@ -734,7 +772,6 @@ void translateComprehension(
734772
}
735773

736774
func.compute_root();
737-
stage.reorder(loop_nest);
738775
}
739776

740777
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {

src/lang/tc_format.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& s, const Param& p) {
6060
}
6161

6262
std::ostream& operator<<(std::ostream& s, const Comprehension& comp) {
63-
s << comp.ident() << "(" << comp.indices() << ") "
64-
<< kindToToken(comp.assignment()->kind()) << " ";
63+
s << comp.ident() << "(";
64+
showList(s, comp.indices(), showExpr);
65+
s << ") " << kindToToken(comp.assignment()->kind()) << " ";
6566
showExpr(s, comp.rhs());
6667
if (!comp.whereClauses().empty())
6768
throw std::runtime_error("Printing of where clauses is not supported yet");

test/test_execution_engine.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ def concat(float(M, N) A, float(M, N) B) -> (O1) {
145145
outputs);
146146
}
147147

148+
TEST_F(ATenCompilationUnitTest, Concat2) {
149+
at::Tensor a = at::CUDA(at::kFloat).rand({32, 16});
150+
at::Tensor b = at::CUDA(at::kFloat).rand({32, 16});
151+
std::vector<at::Tensor> inputs = {a, b};
152+
std::vector<at::Tensor> outputs;
153+
154+
Check(
155+
R"(
156+
def concat(float(M, N) A, float(M, N) B) -> (O1) {
157+
O1(n, 0, m) = A(m, n)
158+
O1(n, 1, m) = B(m, n)
159+
}
160+
)",
161+
"concat",
162+
tc::CudaMappingOptions::makeNaiveCudaMappingOptions(),
163+
inputs,
164+
outputs);
165+
}
166+
148167
TEST_F(ATenCompilationUnitTest, Indexing) {
149168
at::Tensor a = at::CUDA(at::kFloat).rand({3, 4});
150169
at::Tensor b = at::CUDA(at::kInt).ones({2});

test/test_tc_mapper.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,51 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
521521
checkFun);
522522
}
523523

524+
525+
TEST_F(TcMapperTest, Histogram) {
526+
const int N = 17, M = 82;
527+
at::Tensor I =
528+
at::CUDA(at::kFloat).rand({N, M}).mul_(256).floor_().toType(at::kByte);
529+
std::vector<at::Tensor> inputs = {I};
530+
std::vector<at::Tensor> outputs;
531+
532+
static constexpr auto TC = R"TC(
533+
def fun(uint8(N, M) I) -> (O) {
534+
O(I(i, j)) +=! 1
535+
}
536+
)TC";
537+
538+
auto checkFun = [=](const std::vector<at::Tensor>& inputs,
539+
std::vector<at::Tensor>& outputs) {
540+
at::Tensor I = inputs[0].toBackend(at::kCPU);
541+
at::Tensor O = outputs[0].toBackend(at::kCPU);
542+
auto IAccessor = I.accessor<uint8_t, 2>();
543+
auto OAccessor = O.accessor<int, 1>();
544+
int sum = 0;
545+
for (int i = 0; i < 256; i++) {
546+
sum += OAccessor[i];
547+
}
548+
CHECK_EQ(sum, N * M);
549+
550+
for (int i = 0; i < N; i++) {
551+
for (int j = 0; j < M; j++) {
552+
OAccessor[IAccessor[i][j]]--;
553+
}
554+
}
555+
556+
for (int i = 0; i < 256; i++) {
557+
CHECK_EQ(OAccessor[i], 0);
558+
}
559+
};
560+
Check(
561+
TC,
562+
"fun",
563+
tc::CudaMappingOptions::makeNaiveCudaMappingOptions(),
564+
inputs,
565+
checkFun);
566+
}
567+
568+
524569
TEST_F(TcMapperTest, DISABLED_SpatialBatchNormalization) {
525570
N = 32;
526571
at::Tensor eps = at::CUDA(at::kFloat).rand({});

0 commit comments

Comments
 (0)