[flang][runtime] Fix edge-case FP input bugs

Blanks are allowed in more places than I allowed for, and
"NAN(foobar)" is allowed to have any parenthesis-balanced
characters in parentheses.

Update: Fix up old sanity test, then avoid usage of "limit" when null.

Differential Revision: https://reviews.llvm.org/D124294
This commit is contained in:
Peter Klausler 2022-04-13 09:56:42 -07:00
parent 4683a2effa
commit f1dbf8e4ad
3 changed files with 65 additions and 21 deletions

View File

@ -408,19 +408,37 @@ BigRadixFloatingPointNumber<PREC, LOG10RADIX>::ConvertToBinary(
} else { } else {
// Could not parse a decimal floating-point number. p has been // Could not parse a decimal floating-point number. p has been
// advanced over any leading spaces. // advanced over any leading spaces.
if (toupper(p[0]) == 'N' && toupper(p[1]) == 'A' && toupper(p[2]) == 'N') { if ((!limit || limit >= p + 3) && toupper(p[0]) == 'N' &&
toupper(p[1]) == 'A' && toupper(p[2]) == 'N') {
// NaN // NaN
p += 3; p += 3;
if ((!limit || p < limit) && *p == '(') {
int depth{1};
do {
++p;
if (limit && p >= limit) {
// Invalid input
return {Real{NaN()}, Invalid};
} else if (*p == '(') {
++depth;
} else if (*p == ')') {
--depth;
}
} while (depth > 0);
++p;
}
return {Real{NaN()}}; return {Real{NaN()}};
} else { } else {
// Try to parse Inf, maybe with a sign // Try to parse Inf, maybe with a sign
const char *q{p}; const char *q{p};
isNegative_ = *q == '-'; if (!limit || q < limit) {
if (*q == '-' || *q == '+') { isNegative_ = *q == '-';
++q; if (isNegative_ || *q == '+') {
++q;
}
} }
if (toupper(q[0]) == 'I' && toupper(q[1]) == 'N' && if ((!limit || limit >= q + 3) && toupper(q[0]) == 'I' &&
toupper(q[2]) == 'F') { toupper(q[1]) == 'N' && toupper(q[2]) == 'F') {
p = q + 3; p = q + 3;
return {Real{Infinity()}}; return {Real{Infinity()}};
} else { } else {

View File

@ -176,9 +176,19 @@ static int ScanRealInput(char *buffer, int bufferSize, IoStatementState &io,
} }
} }
if (next && *next == '(') { // NaN(...) if (next && *next == '(') { // NaN(...)
while (next && *next != ')') { Put('(');
int depth{1};
do {
next = io.NextInField(remaining, edit); next = io.NextInField(remaining, edit);
} if (!next) {
break;
} else if (*next == '(') {
++depth;
} else if (*next == ')') {
--depth;
}
Put(*next);
} while (depth > 0);
} }
exponent = 0; exponent = 0;
} else if (first == decimal || (first >= '0' && first <= '9') || } else if (first == decimal || (first >= '0' && first <= '9') ||
@ -225,7 +235,7 @@ static int ScanRealInput(char *buffer, int bufferSize, IoStatementState &io,
exponent = -edit.modes.scale; exponent = -edit.modes.scale;
if (next && if (next &&
(*next == '-' || *next == '+' || (*next >= '0' && *next <= '9') || (*next == '-' || *next == '+' || (*next >= '0' && *next <= '9') ||
(bzMode && (*next == ' ' || *next == '\t')))) { *next == ' ' || *next == '\t')) {
bool negExpo{*next == '-'}; bool negExpo{*next == '-'};
if (negExpo || *next == '+') { if (negExpo || *next == '+') {
next = io.NextInField(remaining, edit); next = io.NextInField(remaining, edit);
@ -233,8 +243,10 @@ static int ScanRealInput(char *buffer, int bufferSize, IoStatementState &io,
for (exponent = 0; next; next = io.NextInField(remaining, edit)) { for (exponent = 0; next; next = io.NextInField(remaining, edit)) {
if (*next >= '0' && *next <= '9') { if (*next >= '0' && *next <= '9') {
exponent = 10 * exponent + *next - '0'; exponent = 10 * exponent + *next - '0';
} else if (bzMode && (*next == ' ' || *next == '\t')) { } else if (*next == ' ' || *next == '\t') {
exponent = 10 * exponent; if (bzMode) {
exponent = 10 * exponent;
}
} else { } else {
break; break;
} }
@ -328,11 +340,19 @@ static bool TryFastPathRealInput(
if (converted.flags & decimal::Invalid) { if (converted.flags & decimal::Invalid) {
return false; return false;
} }
if (edit.digits.value_or(0) != 0 && if (edit.digits.value_or(0) != 0) {
std::memchr(str, '.', p - str) == nullptr) { // Edit descriptor is Fw.d (or other) with d != 0, which
// No explicit decimal point, and edit descriptor is Fw.d (or other) // implies scaling
// with d != 0, which implies scaling. const char *q{str};
return false; for (; q < limit; ++q) {
if (*q == '.' || *q == 'n' || *q == 'N') {
break;
}
}
if (q == limit) {
// No explicit decimal point, and not NaN/Inf.
return false;
}
} }
for (; p < limit && (*p == ' ' || *p == '\t'); ++p) { for (; p < limit && (*p == ' ' || *p == '\t'); ++p) {
} }
@ -422,6 +442,10 @@ bool EditCommonRealInput(IoStatementState &io, const DataEdit &edit, void *n) {
converted.flags = static_cast<enum decimal::ConversionResultFlags>( converted.flags = static_cast<enum decimal::ConversionResultFlags>(
converted.flags | decimal::Inexact); converted.flags | decimal::Inexact);
} }
if (*p) { // unprocessed junk after value
io.GetIoErrorHandler().SignalError(IostatBadRealInput);
return false;
}
*reinterpret_cast<decimal::BinaryFloatingPointNumber<binaryPrecision> *>(n) = *reinterpret_cast<decimal::BinaryFloatingPointNumber<binaryPrecision> *>(n) =
converted.binary; converted.binary;
// Set FP exception flags // Set FP exception flags

View File

@ -61,13 +61,15 @@ void testReadback(float x, int flags) {
if (!(x == x)) { if (!(x == x)) {
if (y == y || *p != '\0' || (rflags & Invalid)) { if (y == y || *p != '\0' || (rflags & Invalid)) {
u.x = y; u.x = y;
failed(x) << " (NaN) " << flags << ": -> '" << result.str << "' -> 0x"; (failed(x) << " (NaN) " << flags << ": -> '" << result.str << "' -> 0x")
failed(x).write_hex(u.u) << " '" << p << "' " << rflags << '\n'; .write_hex(u.u)
<< " '" << p << "' " << rflags << '\n';
} }
} else if (x != y || *p != '\0' || (rflags & Invalid)) { } else if (x != y || *p != '\0' || (rflags & Invalid)) {
u.x = y; u.x = x;
failed(x) << ' ' << flags << ": -> '" << result.str << "' -> 0x"; (failed(x) << ' ' << flags << ": -> '" << result.str << "' -> 0x")
failed(x).write_hex(u.u) << " '" << p << "' " << rflags << '\n'; .write_hex(u.u)
<< " '" << p << "' " << rflags << '\n';
} }
} }
} }