[MLIR][Linalg] Handle Attribute in InitTensorOp

In some cases, the result of an initTensorOp may have an attribute.
However, the Attribute was not passed to `inferResultType`, failing the
verifier. Therefore, propagate the Attribute to `inferResultType`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117192
This commit is contained in:
Lorenzo Chelini 2022-01-13 08:43:29 +01:00
parent ab3f100bec
commit ca2ac2bb14
3 changed files with 9 additions and 5 deletions

View File

@ -67,7 +67,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
// Infer the shape of the result tensor given the static shapes
// and element type of the result tensor.
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
Attribute encoding = {});
// Return true if the size of the tensor is dynamic at `idx`
bool isDynamicSize(unsigned idx) {

View File

@ -906,8 +906,8 @@ static LogicalResult verify(InitTensorOp op) {
return op->emitError("expected ")
<< resultType.getRank() << " sizes values";
Type expectedType =
InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
Type expectedType = InitTensorOp::inferResultType(
staticSizes, resultType.getElementType(), resultType.getEncoding());
if (resultType != expectedType) {
return op.emitError("specified type ")
<< resultType << " does not match the inferred type "
@ -917,8 +917,8 @@ static LogicalResult verify(InitTensorOp op) {
}
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
Type elementType) {
return RankedTensorType::get(staticSizes, elementType);
Type elementType, Attribute encoding) {
return RankedTensorType::get(staticSizes, elementType, encoding);
}
namespace {

View File

@ -435,15 +435,18 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
// -----
#attr = {"foo"}
func @init_tensor(%arg0 : index, %arg1 : index)
{
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
%2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>
return
}
// CHECK-LABEL: func @init_tensor
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}>
// -----