Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2f9871d

Browse files
committed
teach sema to extract iteration variables from LHS subtrees
Sema needs the list of iteration variables on the Comprehension LHS to differentiate between reduction and non-reduction variables, the former appearing only on the RHS. Original implementation assumes Comprehension LHS is a tensor whose indices are Idents and ignores more complex constructs. With indirection support, comprehensions like O(A(i)) = B(i) are possible but i is interpreted as a reduction dimension by Sema. Traverse indices of the LHS Tensor in Comprehension recursively, inspecting subtrees of Access and Apply trees and collecting all Idents.
1 parent e9d1dc4 commit 2f9871d

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

tc/lang/sema.h

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,17 +437,31 @@ struct Sema {
437437
return checkRangeConstraint(RangeConstraint(ref));
438438
}
439439
}
440+
441+
private:
442+
// Traverse the list of trees, recursively descending into arguments of APPLY
443+
// and ACCESS subtrees, and collect names and types of IDENT subtrees in
444+
// "index_env". Expects to be called on the indices of the LHS tensor.
445+
void registerLHSIndices(const ListView<TreeRef>& treeRefs) {
446+
for (const auto& treeRef : treeRefs) {
447+
if (treeRef->kind() == TK_IDENT) {
448+
std::string idx = Ident(treeRef).name();
449+
auto typ = indexType(treeRef);
450+
insert(index_env, Ident(treeRef), typ, true);
451+
} else if (treeRef->kind() == TK_APPLY) {
452+
registerLHSIndices(Apply(treeRef).arguments());
453+
} else if (treeRef->kind() == TK_ACCESS) {
454+
registerLHSIndices(Access(treeRef).arguments());
455+
}
456+
}
457+
}
458+
459+
public:
440460
TreeRef checkStmt(TreeRef stmt_) {
441461
auto stmt = Comprehension(stmt_);
442462

443463
// register index variables (non-reductions)
444-
for (const auto& index : stmt.indices()) {
445-
if (index->kind() == TK_IDENT) {
446-
std::string idx = Ident(index).name();
447-
auto typ = indexType(index);
448-
insert(index_env, Ident(index), typ, true);
449-
}
450-
}
464+
registerLHSIndices(stmt.indices());
451465

452466
// make dimension variables for each dimension of the output tensor
453467
std::string name = stmt.ident().name();
@@ -464,6 +478,7 @@ struct Sema {
464478

465479
// where clauses are checked _before_ the rhs because they
466480
// introduce let bindings that are in scope for the rhs
481+
//
467482
auto where_clauses_ = stmt.whereClauses().map(
468483
[&](TreeRef rc) { return checkWhereClause(rc); });
469484

0 commit comments

Comments
 (0)