diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h index 6b5c4be7b2f4..3a3551ddc3eb 100644 --- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h @@ -20,9 +20,11 @@ using vector_insert = ValueBuilder; using vector_fma = ValueBuilder; using vector_extract = ValueBuilder; using vector_matmul = ValueBuilder; +using vector_outerproduct = ValueBuilder; using vector_print = OperationBuilder; using vector_transfer_read = ValueBuilder; using vector_transfer_write = OperationBuilder; +using vector_transpose = ValueBuilder; using vector_type_cast = ValueBuilder; using vector_insert = ValueBuilder; using vector_fma = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 575b99d51c97..264c8ad034c8 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1385,6 +1385,9 @@ def Vector_TransposeOp : [c, f] ] ``` }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value vector, " + "ArrayRef transp">]; let extraClassDeclaration = [{ VectorType getVectorType() { return vector().getType().cast(); @@ -1393,6 +1396,7 @@ def Vector_TransposeOp : return result().getType().cast(); } void getTransp(SmallVectorImpl &results); + static StringRef getTranspAttrName() { return "transp"; } }]; let assemblyFormat = [{ $vector `,` $transp attr-dict `:` type($vector) `to` type($result) diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index c1437892f6f6..a6045db3d998 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -358,8 +358,23 @@ public: /// Emits a `load` when converting to a Value. operator Value() const { return Load(value, indices); } + /// Returns the base memref. Value getBase() const { return value; } + /// Returns the underlying memref. + MemRefType getMemRefType() const { + return value.getType().template cast(); + } + + /// Returns the underlying MemRef elemental type cast as `T`. + template + T getElementalTypeAs() const { + return value.getType() + .template cast() + .getElementType() + .template cast(); + } + /// Arithmetic operator overloadings. Value operator+(Value e); Value operator-(Value e); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index ca07ee140774..5439233c96b1 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1713,6 +1713,18 @@ static LogicalResult verify(TupleOp op) { return success(); } // TransposeOp //===----------------------------------------------------------------------===// +void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, + Value vector, ArrayRef transp) { + VectorType vt = vector.getType().cast(); + SmallVector transposedShape(vt.getRank()); + for (unsigned i = 0; i < transp.size(); ++i) + transposedShape[i] = vt.getShape()[transp[i]]; + + result.addOperands(vector); + result.addTypes(VectorType::get(transposedShape, vt.getElementType())); + result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); +} + // Eliminates transpose operations, which produce values identical to their // input values. This happens when the dimensions of the input vector remain in // their original order after the transpose operation.