[mlir] NFC - Address post-commit comments

Address comments from https://reviews.llvm.org/D113745
which landed as aa37318067
This commit is contained in:
Nicolas Vasilache 2021-11-12 14:58:03 +00:00
parent 388e8110db
commit 0e185ceafb
3 changed files with 57 additions and 55 deletions

View File

@ -87,13 +87,16 @@ bool canFoldIntoConsumerOp(CastOp castOp);
LogicalResult foldTensorCast(Operation *op); LogicalResult foldTensorCast(Operation *op);
/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
/// appropriate sizes to reduce the rank of `tensor` to `targetType`. /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
/// to that of `targetType`.
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
Value tensor, Value tensor,
RankedTensorType targetType); RankedTensorType targetType);
/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
/// appropriate sizes to increase the rank of `tensor` to `dest`. /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
/// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
/// at the canonical [0 .. 0] position.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
Value tensor, Value dest); Value tensor, Value dest);

View File

@ -203,62 +203,12 @@ public:
namespace mlir { namespace mlir {
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must out-live the builder.
class RankedTensorType::Builder {
public:
/// Build from another RankedTensorType.
explicit Builder(RankedTensorType other)
: shape(other.getShape()), elementType(other.getElementType()),
encoding(other.getEncoding()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
: shape(shape), elementType(elementType), encoding(encoding) {}
Builder &setShape(ArrayRef<int64_t> newShape) {
shape = newShape;
return *this;
}
Builder &setElementType(Type newElementType) {
elementType = newElementType;
return *this;
}
Builder &setEncoding(Attribute newEncoding) {
encoding = newEncoding;
return *this;
}
/// Create a new RankedTensor by erasing a dim from shape.
// Note: the newly created type has ownership of a new shape vector.
RankedTensorType dropDim(unsigned dim) {
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
newShape.erase(newShape.begin() + dim);
return setShape(newShape);
}
operator RankedTensorType() {
return RankedTensorType::get(shape, elementType, encoding);
}
private:
ArrayRef<int64_t> shape;
Type elementType;
Attribute encoding;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MemRefType // MemRefType
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments /// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must out-live the builder. /// that are passed into the builder must outlive the builder.
class MemRefType::Builder { class MemRefType::Builder {
public: public:
// Build from another MemRefType. // Build from another MemRefType.
@ -304,6 +254,55 @@ private:
Attribute memorySpace; Attribute memorySpace;
}; };
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
class RankedTensorType::Builder {
public:
/// Build from another RankedTensorType.
explicit Builder(RankedTensorType other)
: shape(other.getShape()), elementType(other.getElementType()),
encoding(other.getEncoding()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
: shape(shape), elementType(elementType), encoding(encoding) {}
Builder &setShape(ArrayRef<int64_t> newShape) {
shape = newShape;
return *this;
}
Builder &setElementType(Type newElementType) {
elementType = newElementType;
return *this;
}
Builder &setEncoding(Attribute newEncoding) {
encoding = newEncoding;
return *this;
}
/// Create a new RankedTensorType by erasing a dim from shape.
RankedTensorType dropDim(unsigned dim) {
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
newShape.erase(newShape.begin() + dim);
return setShape(newShape);
}
operator RankedTensorType() {
return RankedTensorType::get(shape, elementType, encoding);
}
private:
ArrayRef<int64_t> shape;
Type elementType;
Attribute encoding;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices /// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain /// that specifies which of the entries of `originalShape` are dropped to obtain

View File

@ -542,7 +542,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
]; ];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// This is a builder type that keeps local references to arguments. /// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must out-live the builder. /// Arguments that are passed into the builder must outlive the builder.
class Builder; class Builder;
/// [deprecated] Returns the memory space in old raw integer representation. /// [deprecated] Returns the memory space in old raw integer representation.
@ -703,7 +703,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
]; ];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// This is a builder type that keeps local references to arguments. /// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must out-live the builder. /// Arguments that are passed into the builder must outlive the builder.
class Builder; class Builder;
}]; }];
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;