@@ -714,6 +714,160 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
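+// The loop below may execute zero times, so the hoisted read-only op is
+// guarded by an scf.if on the loop bounds, with ub.poison on the else path.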
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
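+// Same as above, but the read-only op has two results: the guarding scf.if
+// yields one ub.poison per result type on the untaken branch.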
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
+func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK-NEXT: ub.poison : f32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read#0 : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
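+// Ops whose operands are themselves hoisted are hoisted as well; each
+// speculated op gets its own bounds-check scf.if guard.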
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
+  // CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
+  // CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
+  // CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_0 = arith.addi %always_speculate, %only_read_0 : i32
+    %only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_1 = arith.addi %add_0, %only_read_1 : i32
+    %i_cast = arith.index_cast %i : index to i32
+    %sum = arith.addi %add_1, %i_cast : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
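+// A memory write elsewhere in the loop body blocks hoisting of the read-only
+// op; only the always-speculatable op is moved out.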
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    %write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
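+// The blocking write is detected even when nested in inner regions (an
+// scf.for containing an scf.if); the read-only op stays in the loop.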
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.for %j = %lb to %ub step %step {
+      %eq42 = arith.cmpi eq, %j, %ind_42 : index
+      scf.if %eq42 {
+        %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+      }
+    }
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
 // -----
 
 func.func @speculate_tensor_dim_unknown_rank_unknown_dim(