Resolving buffer operand of linalg.view doesnt have the information

about the buffer size. This is needed to resolve the operand
correctly. Add that information to view op
serialization/deserialization

Also modify the parsing of buffer type by splitting at 'x' to
side-step issues with StringRef number parsing.

PiperOrigin-RevId: 256188319
This commit is contained in:
Mahesh Ravishankar 2019-07-02 10:13:34 -07:00 committed by jpienaar
parent 5c4ae813ee
commit 25094e90bd
8 changed files with 84 additions and 48 deletions

View File

@ -205,6 +205,12 @@ public:
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;
/// Parse a '->' token if present
virtual ParseResult parseOptionalArrow() = 0;
/// Parse a `:` token.
virtual ParseResult parseColon() = 0;

View File

@ -490,12 +490,19 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
Type bType, type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
parser->parseColon() || parser->parseType(bType) ||
parser->parseArrow() || parser->parseType(type)) {
return failure();
}
BufferType bufferType = bType.dyn_cast<BufferType>();
if (!bufferType) {
return parser->emitError(parser->getNameLoc(), "buffer type expected");
}
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
@ -504,10 +511,7 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range indexings";
return failure(
parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) ||
@ -517,7 +521,7 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
@ -527,7 +531,7 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
interleave(
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
[&]() { *p << ", "; });
*p << "] : " << getType();
*p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
}
///////////////////// Operations defined with Tablegen /////////////////////////

View File

@ -26,6 +26,7 @@
#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@ -115,24 +116,32 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
return RangeType::get(getContext());
else if (spec.consume_front("buffer")) {
if (spec.consume_front("<") && spec.consume_back(">")) {
StringRef sizeSpec, typeSpec;
std::tie(sizeSpec, typeSpec) = spec.split('x');
if (typeSpec.empty()) {
emitError(loc, "expected 'x' followed by element type");
return Type();
}
// Check for '?'
int64_t bufferSize = -1;
if (!spec.consume_front("?")) {
unsigned long long parsedBufferSize = 0;
if (spec.consumeInteger(10, parsedBufferSize)) {
if (!sizeSpec.consume_front("?")) {
if (sizeSpec.consumeInteger(10, bufferSize)) {
emitError(loc, "expected buffer size to be an unsigned integer");
return Type();
}
bufferSize = static_cast<int64_t>(parsedBufferSize);
}
if (!spec.consume_front("x")) {
emitError(loc, "missing x in buffer type descrition : ") << spec;
if (!sizeSpec.empty()) {
emitError(loc, "unexpected token '") << sizeSpec << "'";
}
typeSpec = typeSpec.trim();
auto t = mlir::parseType(typeSpec, context);
if (!t) {
emitError(loc, "invalid type specification: '") << typeSpec << "'";
return Type();
}
if (auto t = mlir::parseType(spec, context))
return (bufferSize == -1
? BufferType::get(getContext(), t)
: BufferType::get(getContext(), t, bufferSize));
return (bufferSize == -1 ? BufferType::get(getContext(), t)
: BufferType::get(getContext(), t, bufferSize));
}
} else if (spec.consume_front("view")) {
if (spec.consume_front("<") && spec.consume_back(">")) {

View File

@ -3199,6 +3199,16 @@ public:
// Token Parsing
//===--------------------------------------------------------------------===//
/// Parse a `->` token.
ParseResult parseArrow() override {
return parser.parseToken(Token::arrow, "expected '->'");
}
/// Parses a `->` if present.
ParseResult parseOptionalArrow() override {
return success(parser.consumeIf(Token::arrow));
}
/// Parse a `:` token.
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");

View File

@ -22,7 +22,7 @@ func @range(%arg0: index) {
// CHECK-NEXT: %5 = llvm.insertvalue %1, %4[2] : !llvm<"{ i64, i64, i64 }">
func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
return
}
// CHECK-LABEL: func @view(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) {
@ -41,7 +41,7 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
// CHECK-NEXT: %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.view<?x?x?xf32>
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> !linalg.view<?x?x?xf32>
return
}
// CHECK-LABEL: func @view3d(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">, %arg2: !llvm<"{ i64, i64, i64 }">, %arg3: !llvm<"{ i64, i64, i64 }">) {
@ -56,7 +56,7 @@ func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.
// CHECK-NEXT: %16 = llvm.insertvalue %15, %12[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
return
}

View File

@ -10,16 +10,16 @@ func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
%K = linalg.range %c0:%arg3:%c1 : !linalg.range
%A = linalg.view %arg0[%I, %K] : !linalg.view<?x?xf32>
%B = linalg.view %arg0[%K, %J] : !linalg.view<?x?xf32>
%C = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
%A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
return
}
// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
// CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
@ -38,16 +38,16 @@ func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
%2 = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
%3 = linalg.view %arg0[%J] : !linalg.view<?xf32>
%4 = linalg.view %arg0[%I] : !linalg.view<?xf32>
%2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
linalg.matvec(%2, %3, %4) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
return
}
// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
// CHECK: linalg.for %i0 = %c0 to %[[M]] step %c1 {
@ -63,16 +63,16 @@ func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%1 = linalg.view %arg0[%I] : !linalg.view<?xf32>
%2 = linalg.view %arg0[%I] : !linalg.view<?xf32>
%3 = linalg.view %arg0[] : !linalg.view<f32>
%1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
linalg.dot(%1, %2, %3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
return
}
// CHECK-LABEL: func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view<f32>
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?xf32>
// CHECK: linalg.for %i0 = %c0 to %[[K]] step %c1 {
// CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0] : !linalg.view<?xf32>

View File

@ -34,7 +34,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
%3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
@ -46,7 +46,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
@ -157,3 +157,10 @@ func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
// CHECK-LABEL: func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
// CHECK: %c0 = constant 0 : index
// CHECK: %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>
func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
%c0 = linalg.buffer_alloc : !linalg.buffer<17xf32>
%c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> !linalg.view<?xf32>
return
}

View File

@ -8,7 +8,7 @@ func @fill_f32(%arg0 : !linalg.buffer<?xf32>, %f : f32) {
%c1 = constant 1 : index
%s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
%V = linalg.view %arg0[%R] : !linalg.view<?xf32>
%V = linalg.view %arg0[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
affine.for %i0 = 0 to %s {
linalg.store %f, %V[%i0] : !linalg.view<?xf32>
}
@ -34,9 +34,9 @@ func @dot() -> f32 {
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%R = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%R] : !linalg.view<?xf32>
%B = linalg.view %bB[%R] : !linalg.view<?xf32>
%C = linalg.view %bC[] : !linalg.view<f32>
%A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
%C = linalg.view %bC[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
%res = linalg.load %C[] : !linalg.view<f32>
@ -68,9 +68,9 @@ func @matmul() -> f32 {
%M = linalg.range %c0:%c10:%c1 : !linalg.range
%N = linalg.range %c0:%c10:%c1 : !linalg.range
%K = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%M, %K] : !linalg.view<?x?xf32>
%B = linalg.view %bB[%K, %N] : !linalg.view<?x?xf32>
%C = linalg.view %bC[%M, %N] : !linalg.view<?x?xf32>
%A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
%C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
%res = linalg.load %C[%c6, %c7] : !linalg.view<?x?xf32>