Refactor LinalgDialect::parseType to use the DialectAsmParser methods directly.

This simplifies the implementation, and removes the need to do explicit string manipulation. A utility method 'parseDimensionList' is added to the DialectAsmParser to simplify defining types and attributes that contain shapes.

PiperOrigin-RevId: 278020604
This commit is contained in:
River Riddle 2019-11-01 16:13:37 -07:00 committed by A. Unique TensorFlower
parent e94a8bfca8
commit 68cfc89a0d
4 changed files with 68 additions and 38 deletions

View File

@ -295,6 +295,19 @@ public:
return emitError(loc, "invalid kind of type specified");
return success();
}
/// Parse a 'x' separated dimension list. This populates the dimension list,
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
/// `?` otherwise.
///
/// dimension-list ::= (dimension `x`)*
/// dimension ::= `?` | integer
///
/// When `allowDynamic` is not set, this is used to parse:
///
/// static-dimension-list ::= (integer `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true) = 0;
};
} // end namespace mlir

View File

@ -109,54 +109,46 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
}
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
StringRef spec = parser.getFullSymbolSpec();
StringRef origSpec = spec;
// Parse the main keyword for the type.
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
else if (spec.consume_front("buffer")) {
if (spec.consume_front("<") && spec.consume_back(">")) {
StringRef sizeSpec, typeSpec;
std::tie(sizeSpec, typeSpec) = spec.split('x');
if (typeSpec.empty()) {
emitError(loc, "expected 'x' followed by element type");
return Type();
}
// Check for '?'
int64_t bufferSize = -1;
if (!sizeSpec.consume_front("?")) {
if (sizeSpec.consumeInteger(10, bufferSize)) {
emitError(loc, "expected buffer size to be an unsigned integer");
return Type();
}
}
if (!sizeSpec.empty()) {
emitError(loc, "unexpected token '") << sizeSpec << "'";
}
typeSpec = typeSpec.trim();
auto t = mlir::parseType(typeSpec, context);
if (!t) {
emitError(loc, "invalid type specification: '") << typeSpec << "'";
return Type();
}
return (bufferSize == -1 ? BufferType::get(getContext(), t)
: BufferType::get(getContext(), t, bufferSize));
// Handle 'range' types.
if (keyword == "range")
return RangeType::get(context);
// Handle 'buffer' types.
if (keyword == "buffer") {
llvm::SMLoc dimensionLoc;
SmallVector<int64_t, 1> size;
Type type;
if (parser.parseLess() || parser.getCurrentLocation(&dimensionLoc) ||
parser.parseDimensionList(size) || parser.parseType(type) ||
parser.parseGreater())
return Type();
if (size.size() != 1) {
parser.emitError(dimensionLoc, "expected single element in size list");
return Type();
}
return (size.front() == -1 ? BufferType::get(context, type)
: BufferType::get(context, type, size.front()));
}
return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
return Type();
}
/// BufferType prints as "buffer<element_type>".
/// BufferType prints as "buffer<size x element_type>".
static void print(BufferType bt, DialectAsmPrinter &os) {
os << "buffer<";
auto bs = bt.getBufferSize();
if (bs) {
if (Optional<int64_t> bs = bt.getBufferSize())
os << bs.getValue();
} else {
else
os << "?";
}
os << "x" << bt.getElementType() << ">";
}

View File

@ -633,6 +633,11 @@ public:
return success(static_cast<bool>(result));
}
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic) override {
return parser.parseDimensionListRanked(dimensions, allowDynamic);
}
private:
/// The full symbol specification.
StringRef fullSpec;

View File

@ -349,3 +349,23 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
linalg.yield %0: i1
}: memref<?xf32, (i)[off]->(off + i)>
}
// -----
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown
// -----
// expected-error @+1 {{expected single element in size list}}
!invalid_type = type !linalg.buffer<1x1xf32>
// -----
// expected-error @+1 {{expected '>'}}
!invalid_type = type !linalg<"buffer<1xf32">
// -----
// expected-error @+1 {{expected valid keyword}}
!invalid_type = type !linalg<"?">