forked from OSchip/llvm-project
Resolving buffer operand of linalg.view doesnt have the information
about the buffer size. This is needed to resolve the operand correctly. Add that information to view op serialization/deserialization Also modify the parsing of buffer type by splitting at 'x' to side-step issues with StringRef number parsing. PiperOrigin-RevId: 256188319
This commit is contained in:
parent
5c4ae813ee
commit
25094e90bd
|
@ -205,6 +205,12 @@ public:
|
|||
// Token Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parse a '->' token.
|
||||
virtual ParseResult parseArrow() = 0;
|
||||
|
||||
/// Parse a '->' token if present
|
||||
virtual ParseResult parseOptionalArrow() = 0;
|
||||
|
||||
/// Parse a `:` token.
|
||||
virtual ParseResult parseColon() = 0;
|
||||
|
||||
|
|
|
@ -490,12 +490,19 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
|
|||
OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
Type type;
|
||||
Type bType, type;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type))
|
||||
parser->parseColon() || parser->parseType(bType) ||
|
||||
parser->parseArrow() || parser->parseType(type)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
BufferType bufferType = bType.dyn_cast<BufferType>();
|
||||
if (!bufferType) {
|
||||
return parser->emitError(parser->getNameLoc(), "buffer type expected");
|
||||
}
|
||||
|
||||
ViewType viewType = type.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
|
@ -504,10 +511,7 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
|
|||
return parser->emitError(parser->getNameLoc(), "expected")
|
||||
<< viewType.getRank() << " range indexings";
|
||||
return failure(
|
||||
parser->resolveOperand(
|
||||
bufferInfo,
|
||||
BufferType::get(type.getContext(), viewType.getElementType()),
|
||||
result->operands) ||
|
||||
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
|
||||
result->operands)) ||
|
||||
|
@ -517,7 +521,7 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
|
|||
// A ViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
|
||||
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
|
@ -527,7 +531,7 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
|
|||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
|
||||
[&]() { *p << ", "; });
|
||||
*p << "] : " << getType();
|
||||
*p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
|
||||
}
|
||||
|
||||
///////////////////// Operations defined with Tablegen /////////////////////////
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/Parser.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -115,24 +116,32 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
|
|||
return RangeType::get(getContext());
|
||||
else if (spec.consume_front("buffer")) {
|
||||
if (spec.consume_front("<") && spec.consume_back(">")) {
|
||||
StringRef sizeSpec, typeSpec;
|
||||
std::tie(sizeSpec, typeSpec) = spec.split('x');
|
||||
if (typeSpec.empty()) {
|
||||
emitError(loc, "expected 'x' followed by element type");
|
||||
return Type();
|
||||
}
|
||||
// Check for '?'
|
||||
int64_t bufferSize = -1;
|
||||
if (!spec.consume_front("?")) {
|
||||
unsigned long long parsedBufferSize = 0;
|
||||
if (spec.consumeInteger(10, parsedBufferSize)) {
|
||||
if (!sizeSpec.consume_front("?")) {
|
||||
if (sizeSpec.consumeInteger(10, bufferSize)) {
|
||||
emitError(loc, "expected buffer size to be an unsigned integer");
|
||||
return Type();
|
||||
}
|
||||
bufferSize = static_cast<int64_t>(parsedBufferSize);
|
||||
}
|
||||
if (!spec.consume_front("x")) {
|
||||
emitError(loc, "missing x in buffer type descrition : ") << spec;
|
||||
if (!sizeSpec.empty()) {
|
||||
emitError(loc, "unexpected token '") << sizeSpec << "'";
|
||||
}
|
||||
|
||||
typeSpec = typeSpec.trim();
|
||||
auto t = mlir::parseType(typeSpec, context);
|
||||
if (!t) {
|
||||
emitError(loc, "invalid type specification: '") << typeSpec << "'";
|
||||
return Type();
|
||||
}
|
||||
if (auto t = mlir::parseType(spec, context))
|
||||
return (bufferSize == -1
|
||||
? BufferType::get(getContext(), t)
|
||||
: BufferType::get(getContext(), t, bufferSize));
|
||||
return (bufferSize == -1 ? BufferType::get(getContext(), t)
|
||||
: BufferType::get(getContext(), t, bufferSize));
|
||||
}
|
||||
} else if (spec.consume_front("view")) {
|
||||
if (spec.consume_front("<") && spec.consume_back(">")) {
|
||||
|
|
|
@ -3199,6 +3199,16 @@ public:
|
|||
// Token Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parse a `->` token.
|
||||
ParseResult parseArrow() override {
|
||||
return parser.parseToken(Token::arrow, "expected '->'");
|
||||
}
|
||||
|
||||
/// Parses a `->` if present.
|
||||
ParseResult parseOptionalArrow() override {
|
||||
return success(parser.consumeIf(Token::arrow));
|
||||
}
|
||||
|
||||
/// Parse a `:` token.
|
||||
ParseResult parseColon() override {
|
||||
return parser.parseToken(Token::colon, "expected ':'");
|
||||
|
|
|
@ -22,7 +22,7 @@ func @range(%arg0: index) {
|
|||
// CHECK-NEXT: %5 = llvm.insertvalue %1, %4[2] : !llvm<"{ i64, i64, i64 }">
|
||||
|
||||
func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @view(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) {
|
||||
|
@ -41,7 +41,7 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
|||
// CHECK-NEXT: %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
|
||||
func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.view<?x?x?xf32>
|
||||
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> !linalg.view<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @view3d(%arg0: !llvm<"{ float*, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">, %arg2: !llvm<"{ i64, i64, i64 }">, %arg3: !llvm<"{ i64, i64, i64 }">) {
|
||||
|
@ -56,7 +56,7 @@ func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.
|
|||
// CHECK-NEXT: %16 = llvm.insertvalue %15, %12[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
|
||||
|
||||
func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -10,16 +10,16 @@ func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
|
|||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
|
||||
%K = linalg.range %c0:%arg3:%c1 : !linalg.range
|
||||
%A = linalg.view %arg0[%I, %K] : !linalg.view<?x?xf32>
|
||||
%B = linalg.view %arg0[%K, %J] : !linalg.view<?x?xf32>
|
||||
%C = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
|
||||
%A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
|
||||
|
@ -38,16 +38,16 @@ func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
|
|||
%c1 = constant 1 : index
|
||||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
|
||||
%2 = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
|
||||
%3 = linalg.view %arg0[%J] : !linalg.view<?xf32>
|
||||
%4 = linalg.view %arg0[%I] : !linalg.view<?xf32>
|
||||
%2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
linalg.matvec(%2, %3, %4) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
|
||||
// CHECK: linalg.for %i0 = %c0 to %[[M]] step %c1 {
|
||||
|
@ -63,16 +63,16 @@ func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index
|
|||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%1 = linalg.view %arg0[%I] : !linalg.view<?xf32>
|
||||
%2 = linalg.view %arg0[%I] : !linalg.view<?xf32>
|
||||
%3 = linalg.view %arg0[] : !linalg.view<f32>
|
||||
%1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
|
||||
linalg.dot(%1, %2, %3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view<f32>
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?xf32>
|
||||
// CHECK: linalg.for %i0 = %c0 to %[[K]] step %c1 {
|
||||
// CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0] : !linalg.view<?xf32>
|
||||
|
|
|
@ -34,7 +34,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
|
|||
%0 = muli %arg0, %arg0 : index
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
|
||||
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
%3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
|
@ -46,7 +46,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
|
|||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
|
||||
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
|
@ -157,3 +157,10 @@ func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
|||
// CHECK-LABEL: func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
||||
// CHECK: %c0 = constant 0 : index
|
||||
// CHECK: %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>
|
||||
|
||||
func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
|
||||
%c0 = linalg.buffer_alloc : !linalg.buffer<17xf32>
|
||||
%c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
%c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ func @fill_f32(%arg0 : !linalg.buffer<?xf32>, %f : f32) {
|
|||
%c1 = constant 1 : index
|
||||
%s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
|
||||
%R = linalg.range %c0:%s:%c1 : !linalg.range
|
||||
%V = linalg.view %arg0[%R] : !linalg.view<?xf32>
|
||||
%V = linalg.view %arg0[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
affine.for %i0 = 0 to %s {
|
||||
linalg.store %f, %V[%i0] : !linalg.view<?xf32>
|
||||
}
|
||||
|
@ -34,9 +34,9 @@ func @dot() -> f32 {
|
|||
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
|
||||
|
||||
%R = linalg.range %c0:%c16:%c1 : !linalg.range
|
||||
%A = linalg.view %bA[%R] : !linalg.view<?xf32>
|
||||
%B = linalg.view %bB[%R] : !linalg.view<?xf32>
|
||||
%C = linalg.view %bC[] : !linalg.view<f32>
|
||||
%A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
|
||||
%C = linalg.view %bC[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
|
||||
|
||||
linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
|
||||
%res = linalg.load %C[] : !linalg.view<f32>
|
||||
|
@ -68,9 +68,9 @@ func @matmul() -> f32 {
|
|||
%M = linalg.range %c0:%c10:%c1 : !linalg.range
|
||||
%N = linalg.range %c0:%c10:%c1 : !linalg.range
|
||||
%K = linalg.range %c0:%c16:%c1 : !linalg.range
|
||||
%A = linalg.view %bA[%M, %K] : !linalg.view<?x?xf32>
|
||||
%B = linalg.view %bB[%K, %N] : !linalg.view<?x?xf32>
|
||||
%C = linalg.view %bC[%M, %N] : !linalg.view<?x?xf32>
|
||||
%A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
%C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
|
||||
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
|
||||
%res = linalg.load %C[%c6, %c7] : !linalg.view<?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue