forked from OSchip/llvm-project
[mlir][shape] Use IndexElementsAttr in Shape dialect.
Summary: Index is the proper type for storing shapes when constant folding, so this fixes the previous code (which was using i64). Differential Revision: https://reviews.llvm.org/D80600
This commit is contained in:
parent
9546d8b108
commit
25132b36a8
|
@ -102,7 +102,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
|
|||
%1 = shape.const_shape [1, 2, 3]
|
||||
```
|
||||
}];
|
||||
let arguments = (ins I64ElementsAttr:$shape);
|
||||
let arguments = (ins IndexElementsAttr:$shape);
|
||||
let results = (outs Shape_ShapeType:$result);
|
||||
|
||||
// TODO: Move this to main so that all shape ops implement these.
|
||||
|
@ -206,13 +206,8 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
|
|||
let builders = [
|
||||
// Builder that allows passing a simple integer instead of an IntegerAttr.
|
||||
OpBuilder<
|
||||
[{
|
||||
OpBuilder &builder, OperationState &result,
|
||||
Value shape, int64_t dim
|
||||
}],
|
||||
[{
|
||||
build(builder, result, shape, builder.getI64IntegerAttr(dim));
|
||||
}]
|
||||
[{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}],
|
||||
[{build(builder, result, shape, builder.getI64IntegerAttr(dim));}]
|
||||
>
|
||||
];
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
|
||||
return nullptr;
|
||||
Builder builder(getContext());
|
||||
return builder.getI64TensorAttr(resultShape);
|
||||
return builder.getIndexTensorAttr(resultShape);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -215,7 +215,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
|
|||
ints.push_back(attr.getInt());
|
||||
}
|
||||
Builder &builder = parser.getBuilder();
|
||||
result.addAttribute("shape", builder.getI64TensorAttr(ints));
|
||||
result.addAttribute("shape", builder.getIndexTensorAttr(ints));
|
||||
|
||||
result.types.push_back(ShapeType::get(builder.getContext()));
|
||||
return success();
|
||||
|
@ -257,7 +257,7 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
|
|||
for (auto attr : operands)
|
||||
extents.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
Builder builder(getContext());
|
||||
return builder.getI64TensorAttr(extents);
|
||||
return builder.getIndexTensorAttr(extents);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -281,14 +281,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
|
|||
// TODO: Constant fold this to some kind of constant error.
|
||||
if (dimToGet >= (uint64_t)elements.getNumElements())
|
||||
return nullptr;
|
||||
// This is a little inconvenient because getValue returns an IntegerAttr
|
||||
// that is not of IndexType, but the result here needs to be of
|
||||
// IndexType.
|
||||
// TODO: Make ConstShapeOp hold an tensor of index instead of i64.
|
||||
Builder builder(getContext());
|
||||
return builder.getIntegerAttr(
|
||||
builder.getIndexType(),
|
||||
elements.getValue<IntegerAttr>({dimToGet}).getInt());
|
||||
return elements.getValue({dimToGet});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -309,7 +302,7 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
|
|||
if (!type || !type.hasStaticShape())
|
||||
return nullptr;
|
||||
Builder builder(getContext());
|
||||
return builder.getI64TensorAttr(type.getShape());
|
||||
return builder.getIndexTensorAttr(type.getShape());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -343,8 +336,8 @@ LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
|
|||
if (splitPoint < 0)
|
||||
splitPoint += shape.size();
|
||||
Builder builder(operands[0].getContext());
|
||||
results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint)));
|
||||
results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint)));
|
||||
results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
|
||||
results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -373,7 +366,7 @@ OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
|
|||
resultShape.append(lhsShape.begin(), lhsShape.end());
|
||||
resultShape.append(rhsShape.begin(), rhsShape.end());
|
||||
Builder builder(getContext());
|
||||
return builder.getI64TensorAttr(resultShape);
|
||||
return builder.getIndexTensorAttr(resultShape);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -15,7 +15,7 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
// CHECK: shape.const_shape [2, 3]
|
||||
// CHECK: shape.const_shape [4, 5]
|
||||
%c2 = constant 2 : i32
|
||||
%0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
|
||||
|
@ -28,7 +28,7 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
// CHECK: shape.const_shape [2, 3, 4]
|
||||
// CHECK: shape.const_shape [5]
|
||||
%c-1 = constant -1 : i32
|
||||
%0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
func @f() -> (!shape.shape, !shape.shape) {
|
||||
// CHECK: shape.split_at
|
||||
%c5 = constant 5 : i32
|
||||
%0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue