Revert "Revert "[mlir] Reuse the code between `getMixed*s()` funcs in ViewLikeInterface.cpp.""

This reverts commit e78d7637fb.

Differential Revision: https://reviews.llvm.org/D130706
This commit is contained in:
Alexander Belyaev 2022-07-31 21:44:24 +02:00
parent e78d7637fb
commit 68b0aaad56
3 changed files with 61 additions and 35 deletions

View File

@ -31,6 +31,12 @@ struct Range {
class OffsetSizeAndStrideOpInterface;
/// Return a vector of OpFoldResults given the special value
/// that indicates whether of the value is dynamic or not.
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
ValueRange dynamicValues,
int64_t dynamicValueIndicator);
/// Return a vector of all the static or dynamic offsets of the op from provided
/// external static and dynamic offsets.
SmallVector<OpFoldResult, 4> getMixedOffsets(OffsetSizeAndStrideOpInterface op,
@ -49,6 +55,13 @@ SmallVector<OpFoldResult, 4> getMixedStrides(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticStrides,
ValueRange strides);
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues,
const int64_t dynamicValueIndicator);
/// Decompose a vector of mixed static or dynamic strides/offsets into the
/// corresponding pair of arrays. This is the inverse function of
/// `getMixedStrides` and `getMixedOffsets`.

View File

@ -237,7 +237,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
}]
>,
StaticInterfaceMethod<
/*desc=*/"Return constant that indicates the offset is dynamic",
/*retTy=*/"int64_t",
/*methodName=*/"getDynamicOffsetIndicator",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
>,
StaticInterfaceMethod<
/*desc=*/"Return constant that indicates the size is dynamic",
/*retTy=*/"int64_t",
/*methodName=*/"getDynamicSizeIndicator",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }]
>,
StaticInterfaceMethod<
/*desc=*/"Return constant that indicates the stride is dynamic",
/*retTy=*/"int64_t",
/*methodName=*/"getDynamicStrideIndicator",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
>,
InterfaceMethod<
/*desc=*/[{
Assert the offset `idx` is a static constant and return its value.

View File

@ -180,61 +180,50 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
}
SmallVector<OpFoldResult, 4>
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticOffsets, ValueRange offsets) {
mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
int64_t dynamicValueIndicator) {
SmallVector<OpFoldResult, 4> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
unsigned count = static_cast<unsigned>(staticOffsets.size());
unsigned count = static_cast<unsigned>(staticValues.size());
for (unsigned idx = 0; idx < count; ++idx) {
if (op.isDynamicOffset(idx))
res.push_back(offsets[numDynamic++]);
else
res.push_back(staticOffsets[idx]);
APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
res.push_back(value.getSExtValue() == dynamicValueIndicator
? OpFoldResult{dynamicValues[numDynamic++]}
: OpFoldResult{staticValues[idx]});
}
return res;
}
SmallVector<OpFoldResult, 4>
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticOffsets, ValueRange offsets) {
return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
}
SmallVector<OpFoldResult, 4>
mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
ValueRange sizes) {
SmallVector<OpFoldResult, 4> res;
unsigned numDynamic = 0;
unsigned count = static_cast<unsigned>(staticSizes.size());
for (unsigned idx = 0; idx < count; ++idx) {
if (op.isDynamicSize(idx))
res.push_back(sizes[numDynamic++]);
else
res.push_back(staticSizes[idx]);
}
return res;
return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
}
SmallVector<OpFoldResult, 4>
mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticStrides, ValueRange strides) {
SmallVector<OpFoldResult, 4> res;
unsigned numDynamic = 0;
unsigned count = static_cast<unsigned>(staticStrides.size());
for (unsigned idx = 0; idx < count; ++idx) {
if (op.isDynamicStride(idx))
res.push_back(strides[numDynamic++]);
else
res.push_back(staticStrides[idx]);
}
return res;
return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
}
static std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedImpl(OpBuilder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues,
const int64_t dynamicValuePlaceholder) {
std::pair<ArrayAttr, SmallVector<Value>>
mlir::decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues,
const int64_t dynamicValueIndicator) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
} else {
staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
staticValues.push_back(dynamicValueIndicator);
dynamicValues.push_back(it.get<Value>());
}
}
@ -243,11 +232,12 @@ decomposeMixedImpl(OpBuilder &b,
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
return decomposeMixedValues(b, mixedValues,
ShapedType::kDynamicStrideOrOffset);
}
std::pair<ArrayAttr, SmallVector<Value>>
mlir::decomposeMixedSizes(OpBuilder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues) {
return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize);
}