[mlir] NFC - Add a builder to vector.transpose

Summary: Also expose some more vector ops to EDSCs.

Differential Revision: https://reviews.llvm.org/D80333
This commit is contained in:
Nicolas Vasilache 2020-05-21 05:12:31 -04:00
parent b2a485e37e
commit 941005f51a
4 changed files with 33 additions and 0 deletions

View File

@ -20,9 +20,11 @@ using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;
using vector_extract = ValueBuilder<vector::ExtractOp>;
using vector_matmul = ValueBuilder<vector::MatmulOp>;
using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
using vector_print = OperationBuilder<vector::PrintOp>;
using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
using vector_transpose = ValueBuilder<vector::TransposeOp>;
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;

View File

@ -1385,6 +1385,9 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value vector, "
"ArrayRef<int64_t> transp">];
let extraClassDeclaration = [{
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
@ -1393,6 +1396,7 @@ def Vector_TransposeOp :
return result().getType().cast<VectorType>();
}
void getTransp(SmallVectorImpl<int64_t> &results);
static StringRef getTranspAttrName() { return "transp"; }
}];
let assemblyFormat = [{
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)

View File

@ -358,8 +358,23 @@ public:
/// Emits a `load` when converting to a Value.
operator Value() const { return Load(value, indices); }
/// Returns the base memref.
Value getBase() const { return value; }
/// Returns the underlying memref.
MemRefType getMemRefType() const {
return value.getType().template cast<MemRefType>();
}
/// Returns the underlying MemRef elemental type cast as `T`.
template <typename T>
T getElementalTypeAs() const {
return value.getType()
.template cast<MemRefType>()
.getElementType()
.template cast<T>();
}
/// Arithmetic operator overloadings.
Value operator+(Value e);
Value operator-(Value e);

View File

@ -1713,6 +1713,18 @@ static LogicalResult verify(TupleOp op) { return success(); }
// TransposeOp
//===----------------------------------------------------------------------===//
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
Value vector, ArrayRef<int64_t> transp) {
VectorType vt = vector.getType().cast<VectorType>();
SmallVector<int64_t, 4> transposedShape(vt.getRank());
for (unsigned i = 0; i < transp.size(); ++i)
transposedShape[i] = vt.getShape()[transp[i]];
result.addOperands(vector);
result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
}
// Eliminates transpose operations, which produce values identical to their
// input values. This happens when the dimensions of the input vector remain in
// their original order after the transpose operation.