forked from OSchip/llvm-project
[mlir] NFC - Address post-commit comments
Address comments from https://reviews.llvm.org/D113745
which landed as aa37318067
This commit is contained in:
parent
388e8110db
commit
0e185ceafb
|
@ -87,13 +87,16 @@ bool canFoldIntoConsumerOp(CastOp castOp);
|
||||||
LogicalResult foldTensorCast(Operation *op);
|
LogicalResult foldTensorCast(Operation *op);
|
||||||
|
|
||||||
/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
|
/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
|
||||||
/// appropriate sizes to reduce the rank of `tensor` to `targetType`.
|
/// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
|
||||||
|
/// to that of `targetType`.
|
||||||
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
|
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
|
||||||
Value tensor,
|
Value tensor,
|
||||||
RankedTensorType targetType);
|
RankedTensorType targetType);
|
||||||
|
|
||||||
/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
|
/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
|
||||||
/// appropriate sizes to increase the rank of `tensor` to `dest`.
|
/// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
|
||||||
|
/// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
|
||||||
|
/// at the canonical [0 .. 0] position.
|
||||||
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
|
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
|
||||||
Value tensor, Value dest);
|
Value tensor, Value dest);
|
||||||
|
|
||||||
|
|
|
@ -203,62 +203,12 @@ public:
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// RankedTensorType
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// This is a builder type that keeps local references to arguments. Arguments
|
|
||||||
/// that are passed into the builder must out-live the builder.
|
|
||||||
class RankedTensorType::Builder {
|
|
||||||
public:
|
|
||||||
/// Build from another RankedTensorType.
|
|
||||||
explicit Builder(RankedTensorType other)
|
|
||||||
: shape(other.getShape()), elementType(other.getElementType()),
|
|
||||||
encoding(other.getEncoding()) {}
|
|
||||||
|
|
||||||
/// Build from scratch.
|
|
||||||
Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
|
|
||||||
: shape(shape), elementType(elementType), encoding(encoding) {}
|
|
||||||
|
|
||||||
Builder &setShape(ArrayRef<int64_t> newShape) {
|
|
||||||
shape = newShape;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder &setElementType(Type newElementType) {
|
|
||||||
elementType = newElementType;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
Builder &setEncoding(Attribute newEncoding) {
|
|
||||||
encoding = newEncoding;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new RankedTensor by erasing a dim from shape.
|
|
||||||
// Note: the newly created type has ownership of a new shape vector.
|
|
||||||
RankedTensorType dropDim(unsigned dim) {
|
|
||||||
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
|
|
||||||
newShape.erase(newShape.begin() + dim);
|
|
||||||
return setShape(newShape);
|
|
||||||
}
|
|
||||||
|
|
||||||
operator RankedTensorType() {
|
|
||||||
return RankedTensorType::get(shape, elementType, encoding);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
ArrayRef<int64_t> shape;
|
|
||||||
Type elementType;
|
|
||||||
Attribute encoding;
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MemRefType
|
// MemRefType
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// This is a builder type that keeps local references to arguments. Arguments
|
/// This is a builder type that keeps local references to arguments. Arguments
|
||||||
/// that are passed into the builder must out-live the builder.
|
/// that are passed into the builder must outlive the builder.
|
||||||
class MemRefType::Builder {
|
class MemRefType::Builder {
|
||||||
public:
|
public:
|
||||||
// Build from another MemRefType.
|
// Build from another MemRefType.
|
||||||
|
@ -304,6 +254,55 @@ private:
|
||||||
Attribute memorySpace;
|
Attribute memorySpace;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// RankedTensorType
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// This is a builder type that keeps local references to arguments. Arguments
|
||||||
|
/// that are passed into the builder must outlive the builder.
|
||||||
|
class RankedTensorType::Builder {
|
||||||
|
public:
|
||||||
|
/// Build from another RankedTensorType.
|
||||||
|
explicit Builder(RankedTensorType other)
|
||||||
|
: shape(other.getShape()), elementType(other.getElementType()),
|
||||||
|
encoding(other.getEncoding()) {}
|
||||||
|
|
||||||
|
/// Build from scratch.
|
||||||
|
Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
|
||||||
|
: shape(shape), elementType(elementType), encoding(encoding) {}
|
||||||
|
|
||||||
|
Builder &setShape(ArrayRef<int64_t> newShape) {
|
||||||
|
shape = newShape;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Builder &setElementType(Type newElementType) {
|
||||||
|
elementType = newElementType;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Builder &setEncoding(Attribute newEncoding) {
|
||||||
|
encoding = newEncoding;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new RankedTensorType by erasing a dim from shape.
|
||||||
|
RankedTensorType dropDim(unsigned dim) {
|
||||||
|
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
|
||||||
|
newShape.erase(newShape.begin() + dim);
|
||||||
|
return setShape(newShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
operator RankedTensorType() {
|
||||||
|
return RankedTensorType::get(shape, elementType, encoding);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ArrayRef<int64_t> shape;
|
||||||
|
Type elementType;
|
||||||
|
Attribute encoding;
|
||||||
|
};
|
||||||
|
|
||||||
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
|
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
|
||||||
/// `originalShape` with some `1` entries erased, return the set of indices
|
/// `originalShape` with some `1` entries erased, return the set of indices
|
||||||
/// that specifies which of the entries of `originalShape` are dropped to obtain
|
/// that specifies which of the entries of `originalShape` are dropped to obtain
|
||||||
|
|
|
@ -542,7 +542,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
|
||||||
];
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
/// This is a builder type that keeps local references to arguments.
|
/// This is a builder type that keeps local references to arguments.
|
||||||
/// Arguments that are passed into the builder must out-live the builder.
|
/// Arguments that are passed into the builder must outlive the builder.
|
||||||
class Builder;
|
class Builder;
|
||||||
|
|
||||||
/// [deprecated] Returns the memory space in old raw integer representation.
|
/// [deprecated] Returns the memory space in old raw integer representation.
|
||||||
|
@ -703,7 +703,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
|
||||||
];
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
/// This is a builder type that keeps local references to arguments.
|
/// This is a builder type that keeps local references to arguments.
|
||||||
/// Arguments that are passed into the builder must out-live the builder.
|
/// Arguments that are passed into the builder must outlive the builder.
|
||||||
class Builder;
|
class Builder;
|
||||||
}];
|
}];
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
|
|
Loading…
Reference in New Issue