forked from OSchip/llvm-project
[MLIR][Linalg] Handle Attribute in InitTensorOp
In some cases, the result of an initTensorOp may have an attribute. However, the Attribute was not passed to `inferResultType`, failing the verifier. Therefore, propagate the Attribute to `inferResultType`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117192
This commit is contained in:
parent
ab3f100bec
commit
ca2ac2bb14
|
@ -67,7 +67,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
|
|||
|
||||
// Infer the shape of the result tensor given the static shapes
|
||||
// and element type of the result tensor.
|
||||
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
|
||||
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
|
||||
Attribute encoding = {});
|
||||
|
||||
// Return true if the size of the tensor is dynamic at `idx`
|
||||
bool isDynamicSize(unsigned idx) {
|
||||
|
|
|
@ -906,8 +906,8 @@ static LogicalResult verify(InitTensorOp op) {
|
|||
return op->emitError("expected ")
|
||||
<< resultType.getRank() << " sizes values";
|
||||
|
||||
Type expectedType =
|
||||
InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
|
||||
Type expectedType = InitTensorOp::inferResultType(
|
||||
staticSizes, resultType.getElementType(), resultType.getEncoding());
|
||||
if (resultType != expectedType) {
|
||||
return op.emitError("specified type ")
|
||||
<< resultType << " does not match the inferred type "
|
||||
|
@ -917,8 +917,8 @@ static LogicalResult verify(InitTensorOp op) {
|
|||
}
|
||||
|
||||
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
|
||||
Type elementType) {
|
||||
return RankedTensorType::get(staticSizes, elementType);
|
||||
Type elementType, Attribute encoding) {
|
||||
return RankedTensorType::get(staticSizes, elementType, encoding);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -435,15 +435,18 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
|
|||
|
||||
// -----
|
||||
|
||||
#attr = {"foo"}
|
||||
func @init_tensor(%arg0 : index, %arg1 : index)
|
||||
{
|
||||
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
||||
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
|
||||
%2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @init_tensor
|
||||
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
|
||||
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
|
||||
// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}>
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue