forked from OSchip/llvm-project
[MLIR][Linalg] Handle Attribute in InitTensorOp
In some cases, the result of an initTensorOp may have an attribute. However, the Attribute was not passed to `inferResultType`, failing the verifier. Therefore, propagate the Attribute to `inferResultType`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117192
This commit is contained in:
parent
ab3f100bec
commit
ca2ac2bb14
mlir
include/mlir/Dialect/Linalg/IR
lib/Dialect/Linalg/IR
test/Dialect/Linalg
|
@ -67,7 +67,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
|
||||||
|
|
||||||
// Infer the shape of the result tensor given the static shapes
|
// Infer the shape of the result tensor given the static shapes
|
||||||
// and element type of the result tensor.
|
// and element type of the result tensor.
|
||||||
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
|
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
|
||||||
|
Attribute encoding = {});
|
||||||
|
|
||||||
// Return true if the size of the tensor is dynamic at `idx`
|
// Return true if the size of the tensor is dynamic at `idx`
|
||||||
bool isDynamicSize(unsigned idx) {
|
bool isDynamicSize(unsigned idx) {
|
||||||
|
|
|
@ -906,8 +906,8 @@ static LogicalResult verify(InitTensorOp op) {
|
||||||
return op->emitError("expected ")
|
return op->emitError("expected ")
|
||||||
<< resultType.getRank() << " sizes values";
|
<< resultType.getRank() << " sizes values";
|
||||||
|
|
||||||
Type expectedType =
|
Type expectedType = InitTensorOp::inferResultType(
|
||||||
InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
|
staticSizes, resultType.getElementType(), resultType.getEncoding());
|
||||||
if (resultType != expectedType) {
|
if (resultType != expectedType) {
|
||||||
return op.emitError("specified type ")
|
return op.emitError("specified type ")
|
||||||
<< resultType << " does not match the inferred type "
|
<< resultType << " does not match the inferred type "
|
||||||
|
@ -917,8 +917,8 @@ static LogicalResult verify(InitTensorOp op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
|
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
|
||||||
Type elementType) {
|
Type elementType, Attribute encoding) {
|
||||||
return RankedTensorType::get(staticSizes, elementType);
|
return RankedTensorType::get(staticSizes, elementType, encoding);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -435,15 +435,18 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#attr = {"foo"}
|
||||||
func @init_tensor(%arg0 : index, %arg1 : index)
|
func @init_tensor(%arg0 : index, %arg1 : index)
|
||||||
{
|
{
|
||||||
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
||||||
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
|
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
|
||||||
|
%2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @init_tensor
|
// CHECK-LABEL: func @init_tensor
|
||||||
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
||||||
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
|
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
|
||||||
|
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue