forked from OSchip/llvm-project
Refactor LinalgDialect::parseType to use the DialectAsmParser methods directly.
This simplifies the implementation, and removes the need to do explicit string manipulation. A utility method 'parseDimensionList' is added to the DialectAsmParser to simplify defining types and attributes that contain shapes. PiperOrigin-RevId: 278020604
This commit is contained in:
parent
e94a8bfca8
commit
68cfc89a0d
|
@ -295,6 +295,19 @@ public:
|
|||
return emitError(loc, "invalid kind of type specified");
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a 'x' separated dimension list. This populates the dimension list,
|
||||
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
|
||||
/// `?` otherwise.
|
||||
///
|
||||
/// dimension-list ::= (dimension `x`)*
|
||||
/// dimension ::= `?` | integer
|
||||
///
|
||||
/// When `allowDynamic` is not set, this is used to parse:
|
||||
///
|
||||
/// static-dimension-list ::= (integer `x`)*
|
||||
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
|
||||
bool allowDynamic = true) = 0;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -109,54 +109,46 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
|
|||
}
|
||||
|
||||
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
|
||||
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
StringRef spec = parser.getFullSymbolSpec();
|
||||
StringRef origSpec = spec;
|
||||
// Parse the main keyword for the type.
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
MLIRContext *context = getContext();
|
||||
if (spec == "range")
|
||||
return RangeType::get(getContext());
|
||||
else if (spec.consume_front("buffer")) {
|
||||
if (spec.consume_front("<") && spec.consume_back(">")) {
|
||||
StringRef sizeSpec, typeSpec;
|
||||
std::tie(sizeSpec, typeSpec) = spec.split('x');
|
||||
if (typeSpec.empty()) {
|
||||
emitError(loc, "expected 'x' followed by element type");
|
||||
return Type();
|
||||
}
|
||||
// Check for '?'
|
||||
int64_t bufferSize = -1;
|
||||
if (!sizeSpec.consume_front("?")) {
|
||||
if (sizeSpec.consumeInteger(10, bufferSize)) {
|
||||
emitError(loc, "expected buffer size to be an unsigned integer");
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
if (!sizeSpec.empty()) {
|
||||
emitError(loc, "unexpected token '") << sizeSpec << "'";
|
||||
}
|
||||
|
||||
typeSpec = typeSpec.trim();
|
||||
auto t = mlir::parseType(typeSpec, context);
|
||||
if (!t) {
|
||||
emitError(loc, "invalid type specification: '") << typeSpec << "'";
|
||||
return Type();
|
||||
}
|
||||
return (bufferSize == -1 ? BufferType::get(getContext(), t)
|
||||
: BufferType::get(getContext(), t, bufferSize));
|
||||
// Handle 'range' types.
|
||||
if (keyword == "range")
|
||||
return RangeType::get(context);
|
||||
|
||||
// Handle 'buffer' types.
|
||||
if (keyword == "buffer") {
|
||||
llvm::SMLoc dimensionLoc;
|
||||
SmallVector<int64_t, 1> size;
|
||||
Type type;
|
||||
if (parser.parseLess() || parser.getCurrentLocation(&dimensionLoc) ||
|
||||
parser.parseDimensionList(size) || parser.parseType(type) ||
|
||||
parser.parseGreater())
|
||||
return Type();
|
||||
|
||||
if (size.size() != 1) {
|
||||
parser.emitError(dimensionLoc, "expected single element in size list");
|
||||
return Type();
|
||||
}
|
||||
|
||||
return (size.front() == -1 ? BufferType::get(context, type)
|
||||
: BufferType::get(context, type, size.front()));
|
||||
}
|
||||
return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
|
||||
return Type();
|
||||
}
|
||||
|
||||
/// BufferType prints as "buffer<element_type>".
|
||||
/// BufferType prints as "buffer<size x element_type>".
|
||||
static void print(BufferType bt, DialectAsmPrinter &os) {
|
||||
os << "buffer<";
|
||||
auto bs = bt.getBufferSize();
|
||||
if (bs) {
|
||||
if (Optional<int64_t> bs = bt.getBufferSize())
|
||||
os << bs.getValue();
|
||||
} else {
|
||||
else
|
||||
os << "?";
|
||||
}
|
||||
os << "x" << bt.getElementType() << ">";
|
||||
}
|
||||
|
||||
|
|
|
@ -633,6 +633,11 @@ public:
|
|||
return success(static_cast<bool>(result));
|
||||
}
|
||||
|
||||
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
|
||||
bool allowDynamic) override {
|
||||
return parser.parseDimensionListRanked(dimensions, allowDynamic);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The full symbol specification.
|
||||
StringRef fullSpec;
|
||||
|
|
|
@ -349,3 +349,23 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
|
|||
linalg.yield %0: i1
|
||||
}: memref<?xf32, (i)[off]->(off + i)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown Linalg type}}
|
||||
!invalid_type = type !linalg.unknown
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected single element in size list}}
|
||||
!invalid_type = type !linalg.buffer<1x1xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '>'}}
|
||||
!invalid_type = type !linalg<"buffer<1xf32">
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected valid keyword}}
|
||||
!invalid_type = type !linalg<"?">
|
||||
|
|
Loading…
Reference in New Issue