21
21
#include " mlir/IR/DialectImplementation.h"
22
22
#include " mlir/IR/DialectResourceBlobManager.h"
23
23
#include " mlir/IR/IntegerSet.h"
24
+ #include " llvm/ADT/APFloat.h"
24
25
#include " llvm/ADT/StringExtras.h"
25
26
#include " llvm/Support/Endian.h"
27
+ #include < cmath>
26
28
#include < optional>
27
29
28
30
using namespace mlir ;
@@ -121,14 +123,16 @@ Attribute Parser::parseAttribute(Type type) {
121
123
122
124
// Parse floating point and integer attributes.
123
125
case Token::floatliteral:
126
+ case Token::kw_inf:
127
+ case Token::kw_nan:
124
128
return parseFloatAttr (type, /* isNegative=*/ false );
125
129
case Token::integer:
126
130
return parseDecOrHexAttr (type, /* isNegative=*/ false );
127
131
case Token::minus: {
128
132
consumeToken (Token::minus);
129
133
if (getToken ().is (Token::integer))
130
134
return parseDecOrHexAttr (type, /* isNegative=*/ true );
131
- if (getToken ().is (Token::floatliteral))
135
+ if (getToken ().isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan ))
132
136
return parseFloatAttr (type, /* isNegative=*/ true );
133
137
134
138
return (emitWrongTokenError (
@@ -342,21 +346,25 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
342
346
343
347
// / Parse a float attribute.
344
348
Attribute Parser::parseFloatAttr (Type type, bool isNegative) {
345
- auto val = getToken ().getFloatingPointValue ();
346
- if (!val)
347
- return (emitError (" floating point value too large for attribute" ), nullptr );
348
- consumeToken (Token::floatliteral);
349
+ const Token tok = getToken ();
350
+ consumeToken ();
349
351
if (!type) {
350
352
// Default to F64 when no type is specified.
351
353
if (!consumeIf (Token::colon))
352
354
type = builder.getF64Type ();
353
355
else if (!(type = parseType ()))
354
356
return nullptr ;
355
357
}
356
- if (!isa<FloatType>(type))
357
- return (emitError (" floating point value not valid for specified type" ),
358
+ auto floatType = dyn_cast<FloatType>(type);
359
+ if (!floatType)
360
+ return (emitError (tok.getLoc (),
361
+ " floating point value not valid for specified type" ),
358
362
nullptr );
359
- return FloatAttr::get (type, isNegative ? -*val : *val);
363
+ std::optional<APFloat> apResult;
364
+ if (failed (parseFloatFromLiteral (apResult, tok, isNegative,
365
+ floatType.getFloatSemantics ())))
366
+ return Attribute ();
367
+ return FloatAttr::get (floatType, *apResult);
360
368
}
361
369
362
370
// / Construct an APint from a parsed value, a known attribute type and
@@ -622,7 +630,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
622
630
}
623
631
624
632
// Check to see if floating point values were parsed.
625
- if (token.is (Token::floatliteral)) {
633
+ if (token.isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan )) {
626
634
return p.emitError (tokenLoc)
627
635
<< " expected integer elements, but parsed floating-point" ;
628
636
}
@@ -729,6 +737,8 @@ ParseResult TensorLiteralParser::parseElement() {
729
737
// Parse a boolean element.
730
738
case Token::kw_true:
731
739
case Token::kw_false:
740
+ case Token::kw_inf:
741
+ case Token::kw_nan:
732
742
case Token::floatliteral:
733
743
case Token::integer:
734
744
storage.emplace_back (/* isNegative=*/ false , p.getToken ());
@@ -738,7 +748,8 @@ ParseResult TensorLiteralParser::parseElement() {
738
748
// Parse a signed integer or a negative floating-point element.
739
749
case Token::minus:
740
750
p.consumeToken (Token::minus);
741
- if (!p.getToken ().isAny (Token::floatliteral, Token::integer))
751
+ if (!p.getToken ().isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan,
752
+ Token::integer))
742
753
return p.emitError (" expected integer or floating point literal" );
743
754
storage.emplace_back (/* isNegative=*/ true , p.getToken ());
744
755
p.consumeToken ();
0 commit comments