forked from OSchip/llvm-project
[mlir] NFC - Add a builder to vector.transpose
Summary: Also expose some more vector ops to EDSCs. Differential Revision: https://reviews.llvm.org/D80333
This commit is contained in:
parent
b2a485e37e
commit
941005f51a
|
@ -20,9 +20,11 @@ using vector_insert = ValueBuilder<vector::InsertOp>;
|
|||
using vector_fma = ValueBuilder<vector::FMAOp>;
|
||||
using vector_extract = ValueBuilder<vector::ExtractOp>;
|
||||
using vector_matmul = ValueBuilder<vector::MatmulOp>;
|
||||
using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
|
||||
using vector_print = OperationBuilder<vector::PrintOp>;
|
||||
using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
|
||||
using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
|
||||
using vector_transpose = ValueBuilder<vector::TransposeOp>;
|
||||
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
|
||||
using vector_insert = ValueBuilder<vector::InsertOp>;
|
||||
using vector_fma = ValueBuilder<vector::FMAOp>;
|
||||
|
|
|
@ -1385,6 +1385,9 @@ def Vector_TransposeOp :
|
|||
[c, f] ]
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value vector, "
|
||||
"ArrayRef<int64_t> transp">];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
return vector().getType().cast<VectorType>();
|
||||
|
@ -1393,6 +1396,7 @@ def Vector_TransposeOp :
|
|||
return result().getType().cast<VectorType>();
|
||||
}
|
||||
void getTransp(SmallVectorImpl<int64_t> &results);
|
||||
static StringRef getTranspAttrName() { return "transp"; }
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
|
||||
|
|
|
@ -358,8 +358,23 @@ public:
|
|||
/// Emits a `load` when converting to a Value.
|
||||
operator Value() const { return Load(value, indices); }
|
||||
|
||||
/// Returns the base memref.
|
||||
Value getBase() const { return value; }
|
||||
|
||||
/// Returns the underlying memref.
|
||||
MemRefType getMemRefType() const {
|
||||
return value.getType().template cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the underlying MemRef elemental type cast as `T`.
|
||||
template <typename T>
|
||||
T getElementalTypeAs() const {
|
||||
return value.getType()
|
||||
.template cast<MemRefType>()
|
||||
.getElementType()
|
||||
.template cast<T>();
|
||||
}
|
||||
|
||||
/// Arithmetic operator overloadings.
|
||||
Value operator+(Value e);
|
||||
Value operator-(Value e);
|
||||
|
|
|
@ -1713,6 +1713,18 @@ static LogicalResult verify(TupleOp op) { return success(); }
|
|||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value vector, ArrayRef<int64_t> transp) {
|
||||
VectorType vt = vector.getType().cast<VectorType>();
|
||||
SmallVector<int64_t, 4> transposedShape(vt.getRank());
|
||||
for (unsigned i = 0; i < transp.size(); ++i)
|
||||
transposedShape[i] = vt.getShape()[transp[i]];
|
||||
|
||||
result.addOperands(vector);
|
||||
result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
|
||||
result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
|
||||
}
|
||||
|
||||
// Eliminates transpose operations, which produce values identical to their
|
||||
// input values. This happens when the dimensions of the input vector remain in
|
||||
// their original order after the transpose operation.
|
||||
|
|
Loading…
Reference in New Issue