Skip to content

Commit ba4cb5c

Browse files
[mlir][Parser] Add nan and inf keywords
1 parent 04de524 commit ba4cb5c

14 files changed

+134
-47
lines changed

mlir/lib/AsmParser/AttributeParser.cpp

+21-10
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
#include "mlir/IR/DialectImplementation.h"
2222
#include "mlir/IR/DialectResourceBlobManager.h"
2323
#include "mlir/IR/IntegerSet.h"
24+
#include "llvm/ADT/APFloat.h"
2425
#include "llvm/ADT/StringExtras.h"
2526
#include "llvm/Support/Endian.h"
27+
#include <cmath>
2628
#include <optional>
2729

2830
using namespace mlir;
@@ -121,14 +123,16 @@ Attribute Parser::parseAttribute(Type type) {
121123

122124
// Parse floating point and integer attributes.
123125
case Token::floatliteral:
126+
case Token::kw_inf:
127+
case Token::kw_nan:
124128
return parseFloatAttr(type, /*isNegative=*/false);
125129
case Token::integer:
126130
return parseDecOrHexAttr(type, /*isNegative=*/false);
127131
case Token::minus: {
128132
consumeToken(Token::minus);
129133
if (getToken().is(Token::integer))
130134
return parseDecOrHexAttr(type, /*isNegative=*/true);
131-
if (getToken().is(Token::floatliteral))
135+
if (getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan))
132136
return parseFloatAttr(type, /*isNegative=*/true);
133137

134138
return (emitWrongTokenError(
@@ -342,21 +346,25 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
342346

343347
/// Parse a float attribute.
344348
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
345-
auto val = getToken().getFloatingPointValue();
346-
if (!val)
347-
return (emitError("floating point value too large for attribute"), nullptr);
348-
consumeToken(Token::floatliteral);
349+
const Token tok = getToken();
350+
consumeToken();
349351
if (!type) {
350352
// Default to F64 when no type is specified.
351353
if (!consumeIf(Token::colon))
352354
type = builder.getF64Type();
353355
else if (!(type = parseType()))
354356
return nullptr;
355357
}
356-
if (!isa<FloatType>(type))
357-
return (emitError("floating point value not valid for specified type"),
358+
auto floatType = dyn_cast<FloatType>(type);
359+
if (!floatType)
360+
return (emitError(tok.getLoc(),
361+
"floating point value not valid for specified type"),
358362
nullptr);
359-
return FloatAttr::get(type, isNegative ? -*val : *val);
363+
std::optional<APFloat> apResult;
364+
if (failed(parseFloatFromLiteral(apResult, tok, isNegative,
365+
floatType.getFloatSemantics())))
366+
return Attribute();
367+
return FloatAttr::get(floatType, *apResult);
360368
}
361369

362370
/// Construct an APint from a parsed value, a known attribute type and
@@ -622,7 +630,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
622630
}
623631

624632
// Check to see if floating point values were parsed.
625-
if (token.is(Token::floatliteral)) {
633+
if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
626634
return p.emitError(tokenLoc)
627635
<< "expected integer elements, but parsed floating-point";
628636
}
@@ -729,6 +737,8 @@ ParseResult TensorLiteralParser::parseElement() {
729737
// Parse a boolean element.
730738
case Token::kw_true:
731739
case Token::kw_false:
740+
case Token::kw_inf:
741+
case Token::kw_nan:
732742
case Token::floatliteral:
733743
case Token::integer:
734744
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -738,7 +748,8 @@ ParseResult TensorLiteralParser::parseElement() {
738748
// Parse a signed integer or a negative floating-point element.
739749
case Token::minus:
740750
p.consumeToken(Token::minus);
741-
if (!p.getToken().isAny(Token::floatliteral, Token::integer))
751+
if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
752+
Token::integer))
742753
return p.emitError("expected integer or floating point literal");
743754
storage.emplace_back(/*isNegative=*/true, p.getToken());
744755
p.consumeToken();

mlir/lib/AsmParser/Parser.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,33 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
350350
ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
351351
const Token &tok, bool isNegative,
352352
const llvm::fltSemantics &semantics) {
353+
// Check for inf keyword.
354+
if (tok.is(Token::kw_inf)) {
355+
if (!APFloat::semanticsHasInf(semantics))
356+
return emitError(tok.getLoc())
357+
<< "floating point type does not support infinity";
358+
result = APFloat::getInf(semantics, isNegative);
359+
return success();
360+
}
361+
362+
// Check for NaN keyword.
363+
if (tok.is(Token::kw_nan)) {
364+
if (!APFloat::semanticsHasNaN(semantics))
365+
return emitError(tok.getLoc())
366+
<< "floating point type does not support NaN";
367+
result = APFloat::getNaN(semantics, isNegative);
368+
return success();
369+
}
370+
353371
// Check for a floating point value.
354372
if (tok.is(Token::floatliteral)) {
355373
auto val = tok.getFloatingPointValue();
356374
if (!val)
357375
return emitError(tok.getLoc()) << "floating point value too large";
376+
if (std::fpclassify(*val) == FP_ZERO &&
377+
!APFloat::semanticsHasZero(semantics))
378+
return emitError(tok.getLoc())
379+
<< "floating point type does not support zero";
358380

359381
result.emplace(isNegative ? -*val : *val);
360382
bool unused;

mlir/lib/AsmParser/TokenKinds.def

+2
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
111111
TOK_KEYWORD(for)
112112
TOK_KEYWORD(func)
113113
TOK_KEYWORD(index)
114+
TOK_KEYWORD(inf)
114115
TOK_KEYWORD(loc)
115116
TOK_KEYWORD(max)
116117
TOK_KEYWORD(memref)
117118
TOK_KEYWORD(min)
118119
TOK_KEYWORD(mod)
120+
TOK_KEYWORD(nan)
119121
TOK_KEYWORD(none)
120122
TOK_KEYWORD(offset)
121123
TOK_KEYWORD(size)

mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func.func @vecdim_reduction(%in: memref<256x512xf32>, %out: memref<256xf32>) {
3030
// -----
3131

3232
func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
33-
%cst = arith.constant 0x7F800000 : f32
33+
%cst = arith.constant inf : f32
3434
affine.for %i = 0 to 256 {
3535
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
3636
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
@@ -57,7 +57,7 @@ func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>
5757
// -----
5858

5959
func.func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
60-
%cst = arith.constant 0xFF800000 : f32
60+
%cst = arith.constant -inf : f32
6161
affine.for %i = 0 to 256 {
6262
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
6363
%ld = affine.load %in[%i, %j] : memref<256x512xf32>

mlir/test/Dialect/Arith/canonicalize.mlir

+5-5
Original file line numberDiff line numberDiff line change
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
18801880
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
18811881
// CHECK-NEXT: return %[[X]], %arg0, %arg0
18821882
%c0 = arith.constant 0.0 : f32
1883-
%inf = arith.constant 0x7F800000 : f32
1883+
%inf = arith.constant inf : f32
18841884
%0 = arith.minimumf %c0, %arg0 : f32
18851885
%1 = arith.minimumf %arg0, %arg0 : f32
18861886
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
18951895
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
18961896
// CHECK-NEXT: return %[[X]], %arg0, %arg0
18971897
%c0 = arith.constant 0.0 : f32
1898-
%-inf = arith.constant 0xFF800000 : f32
1898+
%-inf = arith.constant -inf : f32
18991899
%0 = arith.maximumf %c0, %arg0 : f32
19001900
%1 = arith.maximumf %arg0, %arg0 : f32
19011901
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
19101910
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
19111911
// CHECK-NEXT: return %[[X]], %arg0, %arg0
19121912
%c0 = arith.constant 0.0 : f32
1913-
%inf = arith.constant 0x7F800000 : f32
1913+
%inf = arith.constant inf : f32
19141914
%0 = arith.minnumf %c0, %arg0 : f32
19151915
%1 = arith.minnumf %arg0, %arg0 : f32
19161916
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
19251925
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
19261926
// CHECK-NEXT: return %[[X]], %arg0, %arg0
19271927
%c0 = arith.constant 0.0 : f32
1928-
%-inf = arith.constant 0xFF800000 : f32
1928+
%-inf = arith.constant -inf : f32
19291929
%0 = arith.maxnumf %c0, %arg0 : f32
19301930
%1 = arith.maxnumf %arg0, %arg0 : f32
19311931
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
20242024
// CHECK-DAG: %[[T:.*]] = arith.constant true
20252025
// CHECK-DAG: %[[F:.*]] = arith.constant false
20262026
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
2027-
%nan = arith.constant 0x7fffffff : f32
2027+
%nan = arith.constant nan : f32
20282028
%0 = arith.cmpf olt, %nan, %arg0 : f32
20292029
%1 = arith.cmpf olt, %arg0, %nan : f32
20302030
%2 = arith.cmpf ugt, %nan, %arg0 : f32

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, te
791791
// CHECK: linalg.generic
792792
// CHECK: linalg.generic
793793
func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
794-
%cst = arith.constant 0xFF800000 : f32
794+
%cst = arith.constant -inf : f32
795795
%4 = tensor.empty(%arg1, %arg1) : tensor<?x?xf32>
796796
%5 = linalg.generic {
797797
indexing_maps = [#map0, #map1],

mlir/test/Dialect/Tosa/canonicalize.mlir

+2-4
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
8080
func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
8181
// CHECK: return %arg0
8282
// CHECK-NOT: "tosa.clamp"
83-
// 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
84-
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16>
83+
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf16>) -> tensor<4xf16>
8584
return %0 : tensor<4xf16>
8685
}
8786

@@ -91,8 +90,7 @@ func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
9190
func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
9291
// CHECK: return %arg0
9392
// CHECK-NOT: "tosa.clamp"
94-
// 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
95-
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32>
93+
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf32>) -> tensor<4xf32>
9694
return %0 : tensor<4xf32>
9795
}
9896

mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func.func @reciprocal_div_infinity() -> tensor<f32> {
5858
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
5959
// CHECK-NOT: tosa.reciprocal
6060
// CHECK: return [[RES]]
61-
%0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
61+
%0 = "tosa.const"() {value = dense<inf> : tensor<f32>} : () -> tensor<f32>
6262
%1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
6363
return %1 : tensor<f32>
6464
}

mlir/test/IR/attribute.mlir

+54
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
108108
// CHECK: float_attr = 2.000000e+00 : f128
109109
float_attr = 2. : f128
110110
} : () -> ()
111+
"test.float_attrs"() {
112+
// Note: nan/inf are printed in binary format because there may be multiple
113+
// nan/inf representations.
114+
// CHECK: float_attr = 0x7FC00000 : f32
115+
float_attr = nan : f32
116+
} : () -> ()
117+
"test.float_attrs"() {
118+
// CHECK: float_attr = 0x7C : f8E4M3
119+
float_attr = nan : f8E4M3
120+
} : () -> ()
121+
"test.float_attrs"() {
122+
// CHECK: float_attr = 0xFFC00000 : f32
123+
float_attr = -nan : f32
124+
} : () -> ()
125+
"test.float_attrs"() {
126+
// CHECK: float_attr = 0xFC : f8E4M3
127+
float_attr = -nan : f8E4M3
128+
} : () -> ()
129+
"test.float_attrs"() {
130+
// CHECK: float_attr = 0x7F800000 : f32
131+
float_attr = inf : f32
132+
} : () -> ()
133+
"test.float_attrs"() {
134+
// CHECK: float_attr = 0x78 : f8E4M3
135+
float_attr = inf : f8E4M3
136+
} : () -> ()
137+
"test.float_attrs"() {
138+
// CHECK: float_attr = 0xFF800000 : f32
139+
float_attr = -inf : f32
140+
} : () -> ()
141+
"test.float_attrs"() {
142+
// CHECK: float_attr = 0xF8 : f8E4M3
143+
float_attr = -inf : f8E4M3
144+
} : () -> ()
111145
return
112146
}
113147

148+
// -----
149+
150+
func.func @float_nan_unsupported() {
151+
"test.float_attrs"() {
152+
// expected-error @below{{floating point type does not support NaN}}
153+
float_attr = nan : f4E2M1FN
154+
} : () -> ()
155+
}
156+
157+
// -----
158+
159+
func.func @float_inf_unsupported() {
160+
"test.float_attrs"() {
161+
// expected-error @below{{floating point type does not support infinity}}
162+
float_attr = inf : f4E2M1FN
163+
} : () -> ()
164+
}
165+
166+
// -----
167+
114168
//===----------------------------------------------------------------------===//
115169
// Test integer attributes
116170
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ module attributes {transform.with_named_sequence} {
470470

471471
func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
472472
%cst = arith.constant 0.000000e+00 : f32
473-
%cst_0 = arith.constant 0xFF800000 : f32
473+
%cst_0 = arith.constant -inf : f32
474474
%0 = tensor.empty() : tensor<30xf32>
475475
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
476476
%2 = linalg.generic {

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ module attributes {transform.with_named_sequence} {
101101

102102
func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
103103
%cst = arith.constant 0.000000e+00 : f32
104-
%cst_0 = arith.constant 0xFF800000 : f32
104+
%cst_0 = arith.constant -inf : f32
105105
%0 = tensor.empty() : tensor<30xf32>
106106
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
107107
%2 = linalg.generic {

mlir/test/Transforms/constant-fold.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ func.func @cmpf_nan() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
752752
// CHECK-LABEL: func @cmpf_inf
753753
func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
754754
%c42 = arith.constant 42. : f32
755-
%cpinf = arith.constant 0x7F800000 : f32
755+
%cpinf = arith.constant inf : f32
756756
// CHECK-DAG: [[F:%.+]] = arith.constant false
757757
// CHECK-DAG: [[T:%.+]] = arith.constant true
758758
// CHECK-NEXT: return [[F]],

0 commit comments

Comments
 (0)