forked from OSchip/llvm-project
[mlir] Fix unintentional mutation by VectorType/RankedTensorType::Builder dropDim
Differential Revision: https://reviews.llvm.org/D113933
This commit is contained in:
parent
0ccc44cec0
commit
789c88e80e
|
@ -283,12 +283,14 @@ public:
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Create a new RankedTensor by erasing a dim from shape @pos.
|
||||
RankedTensorType dropDim(unsigned pos) {
|
||||
/// Erase a dim from shape @pos.
|
||||
Builder &dropDim(unsigned pos) {
|
||||
assert(pos < shape.size() && "overflow");
|
||||
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
|
||||
newShape.erase(newShape.begin() + pos);
|
||||
return setShape(newShape);
|
||||
if (storage.empty())
|
||||
storage.append(shape.begin(), shape.end());
|
||||
storage.erase(storage.begin() + pos);
|
||||
shape = {storage.data(), storage.size()};
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator RankedTensorType() {
|
||||
|
@ -297,6 +299,8 @@ public:
|
|||
|
||||
private:
|
||||
ArrayRef<int64_t> shape;
|
||||
// Owning shape data for copy-on-write operations.
|
||||
SmallVector<int64_t> storage;
|
||||
Type elementType;
|
||||
Attribute encoding;
|
||||
};
|
||||
|
@ -327,23 +331,29 @@ public:
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Create a new VectorType by erasing a dim from shape @pos.
|
||||
/// Erase a dim from shape @pos.
|
||||
Builder &dropDim(unsigned pos) {
|
||||
assert(pos < shape.size() && "overflow");
|
||||
if (storage.empty())
|
||||
storage.append(shape.begin(), shape.end());
|
||||
storage.erase(storage.begin() + pos);
|
||||
shape = {storage.data(), storage.size()};
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In the particular case where the vector has a single dimension that we
|
||||
/// drop, return the scalar element type.
|
||||
// TODO: unify once we have a VectorType that supports 0-D.
|
||||
Type dropDim(unsigned pos) {
|
||||
assert(pos < shape.size() && "overflow");
|
||||
if (shape.size() == 1)
|
||||
operator Type() {
|
||||
if (shape.empty())
|
||||
return elementType;
|
||||
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
|
||||
newShape.erase(newShape.begin() + pos);
|
||||
return setShape(newShape);
|
||||
return VectorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
operator VectorType() { return VectorType::get(shape, elementType); }
|
||||
|
||||
private:
|
||||
ArrayRef<int64_t> shape;
|
||||
// Owning shape data for copy-on-write operations.
|
||||
SmallVector<int64_t> storage;
|
||||
Type elementType;
|
||||
};
|
||||
|
||||
|
|
|
@ -876,9 +876,12 @@ struct DownscaleSizeOneWindowed2DConvolution final
|
|||
// Get new shapes and types for all operands by removing the size-1
|
||||
// dimension.
|
||||
using RTTBuilder = RankedTensorType::Builder;
|
||||
auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
|
||||
auto newFilterType = RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
|
||||
auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
|
||||
RankedTensorType newInputType =
|
||||
RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
|
||||
RankedTensorType newFilterType =
|
||||
RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
|
||||
RankedTensorType newOutputType =
|
||||
RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
|
||||
|
||||
// Rank-reduce operands.
|
||||
Location loc = convOp.getLoc();
|
||||
|
@ -948,9 +951,12 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
|
|||
// Get new shapes and types for all operands by removing the size-1
|
||||
// dimension.
|
||||
using RTTBuilder = RankedTensorType::Builder;
|
||||
auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
|
||||
auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
|
||||
auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
|
||||
RankedTensorType newInputType =
|
||||
RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
|
||||
RankedTensorType newKernelType =
|
||||
RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
|
||||
RankedTensorType newOutputType =
|
||||
RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
|
||||
|
||||
// Rank-reduce operands.
|
||||
Location loc = convOp.getLoc();
|
||||
|
|
|
@ -94,15 +94,16 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
|
|||
}
|
||||
// Unroll leading dimensions.
|
||||
VectorType vType = lowType.cast<VectorType>();
|
||||
auto resType = VectorType::Builder(type).dropDim(index).cast<VectorType>();
|
||||
Type resType = VectorType::Builder(type).dropDim(index);
|
||||
auto resVectorType = resType.cast<VectorType>();
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, resType, rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
|
||||
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
|
||||
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
||||
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
|
||||
result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
|
||||
posAttr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue