Skip to content

[mlir][Parser] Add nan and inf keywords #116176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
#include <cmath>
#include <optional>

using namespace mlir;
Expand Down Expand Up @@ -121,14 +123,16 @@ Attribute Parser::parseAttribute(Type type) {

// Parse floating point and integer attributes.
case Token::floatliteral:
case Token::kw_inf:
case Token::kw_nan:
return parseFloatAttr(type, /*isNegative=*/false);
case Token::integer:
return parseDecOrHexAttr(type, /*isNegative=*/false);
case Token::minus: {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, /*isNegative=*/true);
if (getToken().is(Token::floatliteral))
if (getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan))
return parseFloatAttr(type, /*isNegative=*/true);

return (emitWrongTokenError(
Expand Down Expand Up @@ -342,21 +346,25 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {

/// Parse a float attribute.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
auto val = getToken().getFloatingPointValue();
if (!val)
return (emitError("floating point value too large for attribute"), nullptr);
consumeToken(Token::floatliteral);
const Token tok = getToken();
consumeToken();
if (!type) {
// Default to F64 when no type is specified.
if (!consumeIf(Token::colon))
type = builder.getF64Type();
else if (!(type = parseType()))
return nullptr;
}
if (!isa<FloatType>(type))
return (emitError("floating point value not valid for specified type"),
auto floatType = dyn_cast<FloatType>(type);
if (!floatType)
return (emitError(tok.getLoc(),
"floating point value not valid for specified type"),
nullptr);
return FloatAttr::get(type, isNegative ? -*val : *val);
std::optional<APFloat> apResult;
if (failed(parseFloatFromLiteral(apResult, tok, isNegative,
floatType.getFloatSemantics())))
return Attribute();
return FloatAttr::get(floatType, *apResult);
}

/// Construct an APint from a parsed value, a known attribute type and
Expand Down Expand Up @@ -622,7 +630,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
}

// Check to see if floating point values were parsed.
if (token.is(Token::floatliteral)) {
if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}
Expand Down Expand Up @@ -729,6 +737,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a boolean element.
case Token::kw_true:
case Token::kw_false:
case Token::kw_inf:
case Token::kw_nan:
case Token::floatliteral:
case Token::integer:
storage.emplace_back(/*isNegative=*/false, p.getToken());
Expand All @@ -738,7 +748,8 @@ ParseResult TensorLiteralParser::parseElement() {
// Parse a signed integer or a negative floating-point element.
case Token::minus:
p.consumeToken(Token::minus);
if (!p.getToken().isAny(Token::floatliteral, Token::integer))
if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
Token::integer))
return p.emitError("expected integer or floating point literal");
storage.emplace_back(/*isNegative=*/true, p.getToken());
p.consumeToken();
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,33 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
// Check for inf keyword.
if (tok.is(Token::kw_inf)) {
if (!APFloat::semanticsHasInf(semantics))
return emitError(tok.getLoc())
<< "floating point type does not support infinity";
result = APFloat::getInf(semantics, isNegative);
return success();
}

// Check for NaN keyword.
if (tok.is(Token::kw_nan)) {
if (!APFloat::semanticsHasNaN(semantics))
return emitError(tok.getLoc())
<< "floating point type does not support NaN";
result = APFloat::getNaN(semantics, isNegative);
return success();
}

// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
if (!val)
return emitError(tok.getLoc()) << "floating point value too large";
if (std::fpclassify(*val) == FP_ZERO &&
!APFloat::semanticsHasZero(semantics))
return emitError(tok.getLoc())
<< "floating point type does not support zero";

result.emplace(isNegative ? -*val : *val);
bool unused;
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(index)
TOK_KEYWORD(inf)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mod)
TOK_KEYWORD(nan)
TOK_KEYWORD(none)
TOK_KEYWORD(offset)
TOK_KEYWORD(size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func.func @vecdim_reduction(%in: memref<256x512xf32>, %out: memref<256xf32>) {
// -----

func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
%cst = arith.constant 0x7F800000 : f32
%cst = arith.constant inf : f32
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
Expand All @@ -57,7 +57,7 @@ func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>
// -----

func.func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
%cst = arith.constant 0xFF800000 : f32
%cst = arith.constant -inf : f32
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
%inf = arith.constant inf : f32
%0 = arith.minimumf %c0, %arg0 : f32
%1 = arith.minimumf %arg0, %arg0 : f32
%2 = arith.minimumf %inf, %arg0 : f32
Expand All @@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
%-inf = arith.constant -inf : f32
%0 = arith.maximumf %c0, %arg0 : f32
%1 = arith.maximumf %arg0, %arg0 : f32
%2 = arith.maximumf %-inf, %arg0 : f32
Expand All @@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
%inf = arith.constant inf : f32
%0 = arith.minnumf %c0, %arg0 : f32
%1 = arith.minnumf %arg0, %arg0 : f32
%2 = arith.minnumf %inf, %arg0 : f32
Expand All @@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
%-inf = arith.constant -inf : f32
%0 = arith.maxnumf %c0, %arg0 : f32
%1 = arith.maxnumf %arg0, %arg0 : f32
%2 = arith.maxnumf %-inf, %arg0 : f32
Expand Down Expand Up @@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
%nan = arith.constant 0x7fffffff : f32
%nan = arith.constant nan : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, te
// CHECK: linalg.generic
// CHECK: linalg.generic
func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
%cst = arith.constant 0xFF800000 : f32
%cst = arith.constant -inf : f32
%4 = tensor.empty(%arg1, %arg1) : tensor<?x?xf32>
%5 = linalg.generic {
indexing_maps = [#map0, #map1],
Expand Down
6 changes: 2 additions & 4 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
// CHECK: return %arg0
// CHECK-NOT: "tosa.clamp"
// 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16>
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf16>) -> tensor<4xf16>
return %0 : tensor<4xf16>
}

Expand All @@ -91,8 +90,7 @@ func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: return %arg0
// CHECK-NOT: "tosa.clamp"
// 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32>
%0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func.func @reciprocal_div_infinity() -> tensor<f32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
// CHECK-NOT: tosa.reciprocal
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
%0 = "tosa.const"() {value = dense<inf> : tensor<f32>} : () -> tensor<f32>
%1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/IR/attribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f128
float_attr = 2. : f128
} : () -> ()
"test.float_attrs"() {
// Note: nan/inf are printed back in hexadecimal form because a given type may
// have multiple nan/inf bit patterns.
// CHECK: float_attr = 0x7FC00000 : f32
float_attr = nan : f32
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0x7C : f8E4M3
float_attr = nan : f8E4M3
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0xFFC00000 : f32
float_attr = -nan : f32
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0xFC : f8E4M3
float_attr = -nan : f8E4M3
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0x7F800000 : f32
float_attr = inf : f32
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0x78 : f8E4M3
float_attr = inf : f8E4M3
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0xFF800000 : f32
float_attr = -inf : f32
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 0xF8 : f8E4M3
float_attr = -inf : f8E4M3
} : () -> ()
return
}

// -----

// The parser must reject the `nan` keyword for a float type whose semantics
// have no NaN encoding (checked via APFloat::semanticsHasNaN in
// parseFloatFromLiteral); f4E2M1FN is such a type.
func.func @float_nan_unsupported() {
"test.float_attrs"() {
// expected-error @below{{floating point type does not support NaN}}
float_attr = nan : f4E2M1FN
} : () -> ()
}

// -----

// The parser must reject the `inf` keyword for a float type whose semantics
// have no infinity encoding (checked via APFloat::semanticsHasInf in
// parseFloatFromLiteral); f4E2M1FN is such a type.
func.func @float_inf_unsupported() {
"test.float_attrs"() {
// expected-error @below{{floating point type does not support infinity}}
float_attr = inf : f4E2M1FN
} : () -> ()
}

// -----

//===----------------------------------------------------------------------===//
// Test integer attributes
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ module attributes {transform.with_named_sequence} {

func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant -inf : f32
%0 = tensor.empty() : tensor<30xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
%2 = linalg.generic {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ module attributes {transform.with_named_sequence} {

func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant -inf : f32
%0 = tensor.empty() : tensor<30xf32>
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
%2 = linalg.generic {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Transforms/constant-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ func.func @cmpf_nan() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
// CHECK-LABEL: func @cmpf_inf
func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
%c42 = arith.constant 42. : f32
%cpinf = arith.constant 0x7F800000 : f32
%cpinf = arith.constant inf : f32
// CHECK-DAG: [[F:%.+]] = arith.constant false
// CHECK-DAG: [[T:%.+]] = arith.constant true
// CHECK-NEXT: return [[F]],
Expand Down
Loading
Loading