forked from OSchip/llvm-project
[shape] Basic constant folding.
- Implement a first constant fold for shape.shape_of (more ops coming in subsequent patches) - Implement the right builder interfaces for ShapeType and other types - Splits shape.constant into shape.const_size and shape.const_shape which plays better with dyn_cast and building vs one polymorphic op. Also, fix the RUN line in ops.mlir to properly verify round-tripping.
This commit is contained in:
parent
e4a9190ad7
commit
d1ad267a56
|
@ -40,10 +40,13 @@ def ShapeDialect : Dialect {
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let cppNamespace = "shape";
|
let cppNamespace = "shape";
|
||||||
|
|
||||||
|
let hasConstantMaterializer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ComponentType : DialectType<ShapeDialect,
|
def Shape_ComponentType : DialectType<ShapeDialect,
|
||||||
CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type"> {
|
CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::shape::ComponentType>()"> {
|
||||||
let typeDescription = [{
|
let typeDescription = [{
|
||||||
`shape.element_type` represents the element type of the ShapedType. It may
|
`shape.element_type` represents the element type of the ShapedType. It may
|
||||||
be unknown, error or regular element type supported by ShapedType.
|
be unknown, error or regular element type supported by ShapedType.
|
||||||
|
@ -51,7 +54,8 @@ def Shape_ComponentType : DialectType<ShapeDialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ElementType : DialectType<ShapeDialect,
|
def Shape_ElementType : DialectType<ShapeDialect,
|
||||||
CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type"> {
|
CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::shape::ElementType>()"> {
|
||||||
let typeDescription = [{
|
let typeDescription = [{
|
||||||
`shape.element_type` represents the element type of the ShapedType. It may
|
`shape.element_type` represents the element type of the ShapedType. It may
|
||||||
be unknown, error or regular element type supported by ShapedType.
|
be unknown, error or regular element type supported by ShapedType.
|
||||||
|
@ -59,7 +63,8 @@ def Shape_ElementType : DialectType<ShapeDialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ShapeType : DialectType<ShapeDialect,
|
def Shape_ShapeType : DialectType<ShapeDialect,
|
||||||
CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> {
|
CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> {
|
||||||
let typeDescription = [{
|
let typeDescription = [{
|
||||||
`shape.type` represents either an unranked shape, a ranked shape with
|
`shape.type` represents either an unranked shape, a ranked shape with
|
||||||
possibly unknown dimensions or an invalid shape. The rank is of type
|
possibly unknown dimensions or an invalid shape. The rank is of type
|
||||||
|
@ -77,7 +82,8 @@ def Shape_ShapeType : DialectType<ShapeDialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_SizeType : DialectType<ShapeDialect,
|
def Shape_SizeType : DialectType<ShapeDialect,
|
||||||
CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size"> {
|
CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::shape::SizeType>()"> {
|
||||||
let typeDescription = [{
|
let typeDescription = [{
|
||||||
`shape.size` represents a non-negative integer with support for being
|
`shape.size` represents a non-negative integer with support for being
|
||||||
unknown and invalid.
|
unknown and invalid.
|
||||||
|
@ -89,7 +95,9 @@ def Shape_SizeType : DialectType<ShapeDialect,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ValueShapeType : DialectType<ShapeDialect,
|
def Shape_ValueShapeType : DialectType<ShapeDialect,
|
||||||
CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> {
|
CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape">,
|
||||||
|
BuildableType<"::mlir::shape::ValueShapeType::get($_builder.getContext())">
|
||||||
|
{
|
||||||
let typeDescription = [{
|
let typeDescription = [{
|
||||||
`shape.value_shape` represents the value produced by an operation (this
|
`shape.value_shape` represents the value produced by an operation (this
|
||||||
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
|
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
|
||||||
|
@ -146,27 +154,46 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
|
||||||
let results = (outs Shape_ShapeType:$result);
|
let results = (outs Shape_ShapeType:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_ConstantOp : Shape_Op<"constant", []> {
|
def Shape_ConstShapeOp : Shape_Op<"const_shape",
|
||||||
let summary = "Creates a shape constant";
|
[ConstantLike,
|
||||||
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
|
let summary = "Creates a constant of !shape.shape type.";
|
||||||
let description = [{
|
let description = [{
|
||||||
An operation that builds a size or shape from integer or array attribute.
|
Creates a !shape.shape with rank given by the length of `shape` and with
|
||||||
It allows for creating dynamically valued shapes by using `?` for unknown
|
dimension sizes given by the values of `shape`.
|
||||||
values. A constant shape specified with `*` will return an unranked shape.
|
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
%x = shape.constant 10 : !shape.size
|
%0 = shape.const_shape []
|
||||||
|
%1 = shape.const_shape [1, 2, 3]
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
let arguments = (ins I64ElementsAttr:$shape);
|
||||||
// TODO(jpienaar): Change to a more specialized attribute that would
|
let results = (outs Shape_ShapeType:$result);
|
||||||
// encapsulate the unknown parsing while using denser packing.
|
|
||||||
let arguments = (ins AnyAttr:$value);
|
|
||||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
|
||||||
|
|
||||||
// TODO: Move this to main so that all shape ops implement these.
|
// TODO: Move this to main so that all shape ops implement these.
|
||||||
let printer = [{ return ::print(p, *this); }];
|
let printer = [{ return ::print(p, *this); }];
|
||||||
let verifier = [{ return ::verify(*this); }];
|
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Shape_ConstSizeOp : Shape_Op<"const_size",
|
||||||
|
[ConstantLike,
|
||||||
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
|
let summary = "Creates a constant of !shape.size type.";
|
||||||
|
let description = [{
|
||||||
|
Creates a !shape.size type representing the constant size given by `value`.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%x = shape.const_size 10
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins IndexAttr:$value);
|
||||||
|
let results = (outs Shape_SizeType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict $value";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
|
def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
|
||||||
|
@ -291,6 +318,8 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
|
||||||
|
|
||||||
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
|
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
|
||||||
let results = (outs Shape_ShapeType:$result);
|
let results = (outs Shape_ShapeType:$result);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
|
def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
@ -29,6 +30,19 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
|
||||||
allowUnknownOperations();
|
allowUnknownOperations();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
|
||||||
|
Attribute value, Type type,
|
||||||
|
Location loc) {
|
||||||
|
if (auto shapeType = type.dyn_cast<ShapeType>()) {
|
||||||
|
return builder.create<ConstShapeOp>(loc, type,
|
||||||
|
value.cast<DenseIntElementsAttr>());
|
||||||
|
}
|
||||||
|
if (auto sizeType = type.dyn_cast<SizeType>()) {
|
||||||
|
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse a type registered to this dialect.
|
/// Parse a type registered to this dialect.
|
||||||
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
|
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
|
||||||
StringRef keyword;
|
StringRef keyword;
|
||||||
|
@ -74,37 +88,79 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Constant*Op
|
// ConstShapeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, ConstantOp &op) {
|
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
|
||||||
p << "shape.constant ";
|
p << "shape.const_shape ";
|
||||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
|
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
|
||||||
|
p << "[";
|
||||||
if (op.getAttrs().size() > 1)
|
interleaveComma(op.shape().getValues<int64_t>(), p,
|
||||||
p << ' ';
|
[&](int64_t i) { p << i; });
|
||||||
p.printAttributeWithoutType(op.value());
|
p << "]";
|
||||||
p << " : " << op.getType();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseConstantOp(OpAsmParser &parser,
|
static ParseResult parseConstShapeOp(OpAsmParser &parser,
|
||||||
OperationState &result) {
|
OperationState &result) {
|
||||||
Attribute valueAttr;
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes))
|
if (parser.parseOptionalAttrDict(result.attributes))
|
||||||
return failure();
|
return failure();
|
||||||
Type i64Type = parser.getBuilder().getIntegerType(64);
|
// We piggy-back on ArrayAttr parsing, though we don't internally store the
|
||||||
if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
|
// shape as an ArrayAttr.
|
||||||
|
// TODO: Implement custom parser and maybe make syntax a bit more concise.
|
||||||
|
Attribute extentsRaw;
|
||||||
|
SmallVector<NamedAttribute, 6> dummy;
|
||||||
|
if (parser.parseAttribute(extentsRaw, "dummy", dummy))
|
||||||
return failure();
|
return failure();
|
||||||
|
auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
|
||||||
Type type;
|
if (!extentsArray)
|
||||||
if (parser.parseColonType(type))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
SmallVector<int64_t, 6> ints;
|
||||||
|
for (Attribute extent : extentsArray) {
|
||||||
|
IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
|
||||||
|
if (!attr)
|
||||||
|
return failure();
|
||||||
|
ints.push_back(attr.getInt());
|
||||||
|
}
|
||||||
|
Builder &builder = parser.getBuilder();
|
||||||
|
result.addAttribute("shape", builder.getI64TensorAttr(ints));
|
||||||
|
|
||||||
// Add the attribute type to the list.
|
result.types.push_back(ShapeType::get(builder.getContext()));
|
||||||
return parser.addTypeToList(type, result.types);
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verify(ConstantOp &op) { return success(); }
|
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
|
||||||
|
|
||||||
|
LogicalResult ConstShapeOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(ShapeType::get(context));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstSizeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ConstSizeOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(SizeType::get(context));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ShapeOfOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
|
||||||
|
auto type = getOperand().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!type || !type.hasStaticShape())
|
||||||
|
return nullptr;
|
||||||
|
Builder builder(getContext());
|
||||||
|
return builder.getI64TensorAttr(type.getShape());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// SplitAtOp
|
// SplitAtOp
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @f
|
||||||
|
func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
|
||||||
|
// CHECK: shape.const_shape [2, 3, 4]
|
||||||
|
%0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
|
||||||
|
return %0 : !shape.shape
|
||||||
|
}
|
|
@ -1,8 +1,8 @@
|
||||||
// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
|
// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
// CHECK-LABEL: shape_num_elements
|
// CHECK-LABEL: shape_num_elements
|
||||||
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
|
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
|
||||||
%0 = shape.constant 0 : !shape.size
|
%0 = shape.const_size 0
|
||||||
%1 = "shape.reduce"(%shape, %0) ( {
|
%1 = "shape.reduce"(%shape, %0) ( {
|
||||||
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
|
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
|
||||||
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
|
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
|
||||||
|
@ -19,40 +19,46 @@ func @test_shape_num_elements_unknown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_shape_num_elements_fixed() {
|
func @test_shape_num_elements_fixed() {
|
||||||
%0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
|
%0 = shape.const_shape [1, 57, 92]
|
||||||
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
|
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
|
||||||
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
|
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_broadcastable_fixed() {
|
func @test_broadcastable_fixed() {
|
||||||
%0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
|
%0 = shape.const_shape [10, 1, 57, 92]
|
||||||
%1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
|
%1 = shape.const_shape [4, 57, 92]
|
||||||
%2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
%2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_shape_any_fixed() {
|
func @test_shape_any_fixed() {
|
||||||
%0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
|
%0 = shape.const_shape [4, 57, 92]
|
||||||
%1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
|
%1 = shape.const_shape [4, 57, 92]
|
||||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_shape_any_unknown() {
|
func @test_shape_any_unknown() {
|
||||||
%0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
|
%0 = shape.const_shape [4, -1, 92]
|
||||||
%1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
|
%1 = shape.const_shape [-1, 57, 92]
|
||||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_shape_any_fixed_mismatch() {
|
func @test_shape_any_fixed_mismatch() {
|
||||||
%0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
|
%0 = shape.const_shape [4, 57, 92]
|
||||||
%1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
|
%1 = shape.const_shape [2, 57, 92]
|
||||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_parse_const_shape() {
|
||||||
|
%0 = shape.const_shape []
|
||||||
|
%1 = shape.const_shape [1, 2, 3]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue