[mlir] Add a subtensor operation

This revision introduces a `subtensor` op, which is the counterpart of `subview` for a tensor operand. This also refactors the relevant pieces to allow reusing the `subview` implementation where appropriate.

This operation will be used to implement tiling for Linalg on tensors.
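A rough sketch of the intended tiling-style usage, assuming hypothetical values `%t`, `%i`, `%j` and a 4x4 tile size (the syntax follows the op definition introduced below):

```
// Extract a 4x4 tile of %t at dynamic offsets (%i, %j) with unit strides.
%tile = subtensor %t[%i, %j][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
```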
Nicolas Vasilache 2020-10-02 05:32:35 -04:00
parent 670e60c023
commit e3de249a4c
11 changed files with 564 additions and 299 deletions

View File

@ -185,7 +185,7 @@ struct ProcInfo {
Value nprocs;
};
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
@ -216,7 +216,7 @@ struct GenerateLoopNest {
AffineIndexedValue, StdIndexedValue>::type;
static void
doit(ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions> = None);

View File

@ -33,6 +33,17 @@ class Builder;
class FuncOp;
class OpBuilder;
/// Auxiliary range data structure to unpack the offset, size and stride
/// operands of the SubViewOp / SubTensorOp into a list of triples.
/// Such a list of triples is sometimes more convenient to manipulate.
struct Range {
Value offset;
Value size;
Value stride;
};
raw_ostream &operator<<(raw_ostream &os, Range &range);
#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
@ -300,8 +311,6 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
SmallVectorImpl<Value> &operands,
unsigned &numDims);
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that

View File

@ -2706,11 +2706,214 @@ def SubIOp : IntArithmeticOp<"subi"> {
// SubViewOp
//===----------------------------------------------------------------------===//
def SubViewOp : Std_Op<"subview", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect,
]> {
class BaseOpWithOffsetSizesAndStrides<string mnemonic, list<OpTrait> traits = []> :
Std_Op<mnemonic,
!listconcat(traits, [NoSideEffect, AttrSizedOperandSegments])> {
let builders = [
// Build an op with mixed static and dynamic entries.
OpBuilder<
"Value source, ArrayRef<int64_t> staticOffsets, "
"ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
"ValueRange offsets, ValueRange sizes, ValueRange strides, "
"ArrayRef<NamedAttribute> attrs = {}">,
// Build an op with all dynamic entries.
OpBuilder<
"Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
"ArrayRef<NamedAttribute> attrs = {}">
];
code extraBaseClassDeclaration = [{
/// Returns the number of dynamic offset operands.
int64_t getNumOffsets() { return llvm::size(offsets()); }
/// Returns the number of dynamic size operands.
int64_t getNumSizes() { return llvm::size(sizes()); }
/// Returns the number of dynamic stride operands.
int64_t getNumStrides() { return llvm::size(strides()); }
/// Returns the dynamic sizes of this operation if specified.
operand_range getDynamicSizes() { return sizes(); }
/// Returns in `staticStrides` the static value of the stride
/// operands. Returns failure() if the static value of the stride
/// operands could not be retrieved.
LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
if (!strides().empty())
return failure();
staticStrides.reserve(static_strides().size());
for (auto s : static_strides().getAsValueRange<IntegerAttr>())
staticStrides.push_back(s.getZExtValue());
return success();
}
/// Return the list of Range (i.e. offset, size, stride). Each
/// Range entry contains either the dynamic value or a ConstantIndexOp
/// constructed with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
/// Return the offsets as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc) {
unsigned dynamicIdx = 1;
return llvm::to_vector<4>(llvm::map_range(
static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticOffset = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamicStrideOrOffset(staticOffset))
return getOperand(dynamicIdx++);
else
return b.create<ConstantOp>(
loc, b.getIndexType(), b.getIndexAttr(staticOffset));
}));
}
/// Return the sizes as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc) {
unsigned dynamicIdx = 1 + offsets().size();
return llvm::to_vector<4>(llvm::map_range(
static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticSize = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamic(staticSize))
return getOperand(dynamicIdx++);
else
return b.create<ConstantOp>(
loc, b.getIndexType(), b.getIndexAttr(staticSize));
}));
}
/// Return the strides as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed with
/// `b` at location `loc`
SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc) {
unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
return llvm::to_vector<4>(llvm::map_range(
static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticStride = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamicStrideOrOffset(staticStride))
return getOperand(dynamicIdx++);
else
return b.create<ConstantOp>(
loc, b.getIndexType(), b.getIndexAttr(staticStride));
}));
}
/// Return the rank of the source ShapedType.
unsigned getSourceRank() {
return source().getType().cast<ShapedType>().getRank();
}
/// Return the rank of the result ShapedType.
unsigned getResultRank() { return getType().getRank(); }
/// Return true if the offset `idx` is dynamic.
bool isDynamicOffset(unsigned idx) {
APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
}
/// Return true if the size `idx` is dynamic.
bool isDynamicSize(unsigned idx) {
APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
return ShapedType::isDynamic(v.getSExtValue());
}
/// Return true if the stride `idx` is dynamic.
bool isDynamicStride(unsigned idx) {
APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
}
/// Assert the offset `idx` is a static constant and return its value.
int64_t getStaticOffset(unsigned idx) {
assert(!isDynamicOffset(idx) && "expected static offset");
APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
return v.getSExtValue();
}
/// Assert the size `idx` is a static constant and return its value.
int64_t getStaticSize(unsigned idx) {
assert(!isDynamicSize(idx) && "expected static size");
APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
return v.getSExtValue();
}
/// Assert the stride `idx` is a static constant and return its value.
int64_t getStaticStride(unsigned idx) {
assert(!isDynamicStride(idx) && "expected static stride");
APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
return v.getSExtValue();
}
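/// Helper: counts how many of the first `idx` entries of `attr` are dynamic.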
unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr,
llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
return std::count_if(
attr.getValue().begin(), attr.getValue().begin() + idx,
[&](Attribute attr) {
return isDynamic(attr.cast<IntegerAttr>().getInt());
});
}
/// Assert the offset `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicOffset(unsigned idx) {
assert(isDynamicOffset(idx) && "expected dynamic offset");
auto numDynamic =
getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
ShapedType::isDynamicStrideOrOffset, idx);
return 1 + numDynamic;
}
/// Assert the size `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicSize(unsigned idx) {
assert(isDynamicSize(idx) && "expected dynamic size");
auto numDynamic = getNumDynamicEntriesUpToIdx(
static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
return 1 + offsets().size() + numDynamic;
}
/// Assert the stride `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicStride(unsigned idx) {
assert(isDynamicStride(idx) && "expected dynamic stride");
auto numDynamic =
getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
ShapedType::isDynamicStrideOrOffset, idx);
return 1 + offsets().size() + sizes().size() + numDynamic;
}
/// Assert the offset `idx` is dynamic and return its value.
Value getDynamicOffset(unsigned idx) {
return getOperand(getIndexOfDynamicOffset(idx));
}
/// Assert the size `idx` is dynamic and return its value.
Value getDynamicSize(unsigned idx) {
return getOperand(getIndexOfDynamicSize(idx));
}
/// Assert the stride `idx` is dynamic and return its value.
Value getDynamicStride(unsigned idx) {
return getOperand(getIndexOfDynamicStride(idx));
}
static StringRef getStaticOffsetsAttrName() {
return "static_offsets";
}
static StringRef getStaticSizesAttrName() {
return "static_sizes";
}
static StringRef getStaticStridesAttrName() {
return "static_strides";
}
static ArrayRef<StringRef> getSpecialAttrNames() {
static SmallVector<StringRef, 4> names{
getStaticOffsetsAttrName(),
getStaticSizesAttrName(),
getStaticStridesAttrName(),
getOperandSegmentSizeAttr()};
return names;
}
}];
}
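For orientation, a hedged sketch (op and value names are hypothetical) of the operand layout assumed by the `getIndexOfDynamic*` helpers above: operand 0 is the source, followed by the dynamic offsets, then the dynamic sizes, then the dynamic strides.

```
// %v = subview %src[%o0, 1][2, %s1][1, %st1] : memref<?x?xf32> to ...
// Operand order: %src (0), %o0 (1), %s1 (2), %st1 (3), so
// getIndexOfDynamicOffset(0) == 1, getIndexOfDynamicSize(1) == 2,
// getIndexOfDynamicStride(1) == 3.
```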
def SubViewOp : BaseOpWithOffsetSizesAndStrides<
"subview", [DeclareOpInterfaceMethods<ViewLikeOpInterface>] > {
let summary = "memref subview operation";
let description = [{
The "subview" operation converts a memref type to another memref type
@ -2726,8 +2929,11 @@ def SubViewOp : Std_Op<"subview", [
* Sizes: memref-rank number of dynamic sizes or static integer attributes
which specify the sizes of the result "view" memref type.
* Strides: memref-rank number of dynamic strides or static integer
attributes multiplicatively to the base memref strides in each
dimension.
attributes that compose multiplicatively with the base memref
strides in each dimension.
A subview operation may additionally reduce the rank of the resulting view
by removing dimensions that are statically known to be of size 1.
Example 1:
@ -2817,6 +3023,15 @@ def SubViewOp : Std_Op<"subview", [
// memref is "inbounds" w.r.t. the base memref. It is up to the client
// to ensure that the subview is accessed in a manner that is
// in-bounds.
Example 5:
```
// Rank-reducing subview.
%1 = subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
memref<8x16x4xf32> to memref<16x4xf32>
%3 = subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
```
}
}];
@ -2859,137 +3074,97 @@ def SubViewOp : Std_Op<"subview", [
"ArrayRef<NamedAttribute> attrs = {}">
];
let extraClassDeclaration = [{
let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Returns the type of the base memref operand.
MemRefType getBaseMemRefType() {
MemRefType getSourceMemRefType() {
return source().getType().cast<MemRefType>();
}
/// The result of a subview is always a memref.
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
/// Returns as integer value the number of offset operands.
int64_t getNumOffsets() { return llvm::size(offsets()); }
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
}];
/// Returns as integer value the number of size operands.
int64_t getNumSizes() { return llvm::size(sizes()); }
let hasCanonicalizer = 1;
}
/// Returns as integer value the number of stride operands.
int64_t getNumStrides() { return llvm::size(strides()); }
//===----------------------------------------------------------------------===//
// SubTensorOp
//===----------------------------------------------------------------------===//
/// Returns the dynamic sizes for this subview operation if specified.
operand_range getDynamicSizes() { return sizes(); }
def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
let summary = "subtensor operation";
let description = [{
The "subtensor" operation extract a tensor from another tensor as
specified by the operation's offsets, sizes and strides arguments.
/// Returns in `staticStrides` the static value of the stride
/// operands. Returns failure() if the static value of the stride
/// operands could not be retrieved.
LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
The subtensor operation supports the following arguments:
/// Auxiliary range data structure and helper function that unpacks the
/// offset, size and stride operands of the SubViewOp into a list of triples.
/// Such a list of triple is sometimes more convenient to manipulate.
struct Range {
Value offset, size, stride;
};
/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
/// Range entry contains either the dynamic value or a ConstantIndexOp
/// constructed with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
* tensor: the "base" tensor from which to extract a subtensor.
* offsets: tensor-rank number of dynamic offsets or static integer
attributes into the "base" tensor from which to extract the
subtensor.
* sizes: tensor-rank number of dynamic sizes or static integer attributes
which specify the sizes of the result tensor type.
* strides: tensor-rank number of dynamic strides or static integer
attributes specifying subsampling in each dimension.
/// Return the offsets as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc);
After buffer-allocation, the "subtensor" op is expected to lower into a
"subview" op.
/// Return the sizes as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc);
A subtensor operation may additionally reduce the rank of the resulting
tensor by removing dimensions that are statically known to be of size 1.
/// Return the strides as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed with
/// `b` at location `loc`
SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc);
Example:
```
// Rank-reducing subtensor.
%1 = subtensor %0[0, 0, 0][1, 16, 4][1, 1, 1] :
tensor<8x16x4xf32> to tensor<16x4xf32>
%3 = subtensor %2[3, 4, 2][1, 6, 3][1, 1, 1] :
tensor<8x16x4xf32> to tensor<6x3xf32>
```
}];
let arguments = (ins
AnyRankedTensor:$source,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Returns the type of the base tensor operand.
RankedTensorType getSourceRankedTensorType() {
return source().getType().cast<RankedTensorType>();
}
/// The result of a subtensor is always a tensor.
RankedTensorType getType() {
return getResult().getType().cast<RankedTensorType>();
}
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
static Type inferSubViewResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
/// Return the rank of the result MemRefType.
unsigned getRank() { return getType().getRank(); }
/// Return true if the offset `idx` is a static constant.
bool isDynamicOffset(unsigned idx);
/// Return true if the size `idx` is a static constant.
bool isDynamicSize(unsigned idx);
/// Return true if the stride `idx` is a static constant.
bool isDynamicStride(unsigned idx);
/// Assert the offset `idx` is a static constant and return its value.
int64_t getStaticOffset(unsigned idx) {
assert(!isDynamicOffset(idx) && "expected static offset");
return
static_offsets().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
}
/// Assert the size `idx` is a static constant and return its value.
int64_t getStaticSize(unsigned idx) {
assert(!isDynamicSize(idx) && "expected static size");
return static_sizes().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
}
/// Assert the stride `idx` is a static constant and return its value.
int64_t getStaticStride(unsigned idx) {
assert(!isDynamicStride(idx) && "expected static stride");
return
static_strides().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
}
/// Assert the offset `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicOffset(unsigned idx);
/// Assert the size `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicSize(unsigned idx);
/// Assert the stride `idx` is dynamic and return the position of the
/// corresponding operand.
unsigned getIndexOfDynamicStride(unsigned idx);
/// Assert the offset `idx` is dynamic and return its value.
Value getDynamicOffset(unsigned idx) {
return getOperand(getIndexOfDynamicOffset(idx));
}
/// Assert the size `idx` is dynamic and return its value.
Value getDynamicSize(unsigned idx) {
return getOperand(getIndexOfDynamicSize(idx));
}
/// Assert the stride `idx` is dynamic and return its value.
Value getDynamicStride(unsigned idx) {
return getOperand(getIndexOfDynamicStride(idx));
}
static StringRef getStaticOffsetsAttrName() {
return "static_offsets";
}
static StringRef getStaticSizesAttrName() {
return "static_sizes";
}
static StringRef getStaticStridesAttrName() {
return "static_strides";
}
static ArrayRef<StringRef> getSpecialAttrNames() {
static SmallVector<StringRef, 4> names{
getStaticOffsetsAttrName(),
getStaticSizesAttrName(),
getStaticStridesAttrName(),
getOperandSegmentSizeAttr()};
return names;
}
static Type inferResultType(RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
}];
let hasCanonicalizer = 1;
// let hasCanonicalizer = 1;
}
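A hedged illustration of the mixed static/dynamic form (value names hypothetical): as with `subview`, static entries print as integer literals and dynamic entries as SSA values, and the result shape is inferred from the sizes.

```
%0 = subtensor %t[0, %o, 0][4, %sz, 4][1, 1, 1]
  : tensor<8x16x4xf32> to tensor<4x?x4xf32>
```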
//===----------------------------------------------------------------------===//

View File

@ -60,7 +60,7 @@ using llvm::dbgs;
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
ArrayRef<SubViewOp::Range> loopRanges) {
ArrayRef<Range> loopRanges) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = op.indexing_maps();
SmallVector<Value, 8> clonedViews;
@ -73,7 +73,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
auto map = maps[idx].cast<AffineMapAttr>().getValue();
LLVM_DEBUG(dbgs() << "map: " << map << "\n");
Value view = en.value();
SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
SmallVector<Range, 4> viewRanges(map.getNumResults());
for (auto en2 : llvm::enumerate(map.getResults())) {
unsigned d = en2.index();
// loopToOperandRangesMaps are permutations-only.
@ -182,7 +182,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
// Iterate over dimensions identified by the producer map for `producerIdx`.
// This defines a subset of the loop ranges that we need to complete later.
@ -202,9 +202,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto viewDim = getViewDefiningLoopRange(producer, i);
loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
std_dim(viewDim.view, viewDim.dimension),
folded_std_constant_index(folder, 1)};
loopRanges[i] = Range{folded_std_constant_index(folder, 0),
std_dim(viewDim.view, viewDim.dimension),
folded_std_constant_index(folder, 1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
@ -300,8 +300,6 @@ static bool isSameSubView(Value a, Value b) {
return false;
if (sva.getType() != svb.getType())
return false;
if (sva.getRank() != svb.getRank())
return false;
if (sva.getNumOperands() != svb.getNumOperands())
return false;
if (sva.static_offsets() != svb.static_offsets())

View File

@ -65,22 +65,21 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
/// It expects a non-inverted, concatenated map and last values in
/// allViewSizes will be applied to the symbols in the map if it contains any.
static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
Location loc,
AffineMap map,
ValueRange viewSizes) {
static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
ValueRange viewSizes) {
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
unsigned numSym = map.getNumSymbols();
assert(viewSizes.size() == numRes + numSym &&
"viewSizes must contain sizes of all views and values for symbols");
SmallVector<SubViewOp::Range, 4> res(numDims);
SmallVector<Range, 4> res(numDims);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
if (auto d = result.dyn_cast<AffineDimExpr>()) {
if (res[d.getPosition()].offset)
continue;
res[d.getPosition()] = SubViewOp::Range{
std_constant_index(0), viewSizes[idx], std_constant_index(1)};
res[d.getPosition()] =
Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
}
// If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
@ -124,7 +123,7 @@ static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
// Construction of the lower bound (s floordiv 2).
Value from = applyMapToValues(b, loc, fromMap, values).front();
Value to = applyMapToValues(b, loc, toMap, values).front();
res[mPos] = SubViewOp::Range{from, to, std_constant_index(1)};
res[mPos] = Range{from, to, std_constant_index(1)};
}
}
return res;

View File

@ -54,7 +54,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// are tiled and for which new loops will be created. Also the function returns
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value> allViewSizes,
ArrayRef<Value> allTileSizes) {
@ -76,10 +76,9 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
}
// Create a new range with the applied tile sizes.
SmallVector<SubViewOp::Range, 4> res;
SmallVector<Range, 4> res;
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
res.push_back(SubViewOp::Range{std_constant_index(0), viewSizes[idx],
tileSizes[idx]});
res.push_back(Range{std_constant_index(0), viewSizes[idx], tileSizes[idx]});
return std::make_tuple(res, loopIndexToRangeIndex);
}
@ -346,7 +345,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
if (!viewSizesToLoopsMap)
return llvm::None;
SmallVector<SubViewOp::Range, 4> loopRanges;
SmallVector<Range, 4> loopRanges;
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);

View File

@ -133,11 +133,10 @@ template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(ArrayRef<SubViewOp::Range> ranges,
SmallVectorImpl<Value> &lbs,
static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
SmallVectorImpl<Value> &ubs,
SmallVectorImpl<Value> &steps) {
for (SubViewOp::Range range : ranges) {
for (Range range : ranges) {
lbs.emplace_back(range.offset);
ubs.emplace_back(range.size);
steps.emplace_back(range.stride);
@ -194,7 +193,7 @@ getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions>) {
@ -206,7 +205,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions>) {
@ -364,7 +363,7 @@ generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions> distributionOptions) {
@ -391,7 +390,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
Location loc = edsc::ScopedContext::getLocation();
distributionMethod.assign(distributionOptions->distributionMethod.begin(),
distributionOptions->distributionMethod.end());
SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
SmallVector<Range, 2> parallelLoopRanges;
for (auto iteratorType : enumerate(iteratorTypes)) {
if (isParallelIteratorType(iteratorType.value()))
parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);

View File

@ -2587,10 +2587,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank &&
@ -2638,7 +2638,8 @@ Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
static void print(OpAsmPrinter &p, SubViewOp op) {
template <typename OpType>
static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.getOperand(0);
@ -2649,16 +2650,22 @@ static void print(OpAsmPrinter &p, SubViewOp op) {
printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
ShapedType::isDynamicStrideOrOffset);
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
/*elidedAttrs=*/{OpType::getSpecialAttrNames()});
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
static void print(OpAsmPrinter &p, SubViewOp op) {
return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
}
/// Parse SubViewOp of the form:
/// ```
/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
template <typename OpType>
static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcInfo;
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
@ -2666,13 +2673,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
if (parser.parseOperand(srcInfo))
return failure();
if (parseListOfOperandsOrIntegers(
parser, result, SubViewOp::getStaticOffsetsAttrName(),
parser, result, OpType::getStaticOffsetsAttrName(),
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
parseListOfOperandsOrIntegers(parser, result,
SubViewOp::getStaticSizesAttrName(),
OpType::getStaticSizesAttrName(),
ShapedType::kDynamicSize, sizesInfo) ||
parseListOfOperandsOrIntegers(
parser, result, SubViewOp::getStaticStridesAttrName(),
parser, result, OpType::getStaticStridesAttrName(),
ShapedType::kDynamicStrideOrOffset, stridesInfo))
return failure();
@ -2680,7 +2687,7 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
static_cast<int>(sizesInfo.size()),
static_cast<int>(stridesInfo.size())};
result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
result.addAttribute(OpType::getOperandSegmentSizeAttr(),
b.getI32VectorAttr(segmentSizes));
return failure(
@ -2694,6 +2701,10 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(dstType, result.types));
}
static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
return parseOpWithOffsetsSizesAndStrides<SubViewOp>(parser, result);
}
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
@ -2701,8 +2712,8 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
auto sourceMemRefType = source.getType().cast<MemRefType>();
auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
auto resultType = inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
build(b, result, resultType, source, offsets, sizes, strides,
b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
b.getI64ArrayAttr(staticStrides));
@ -2760,15 +2771,18 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
staticStridesVector, offsets, sizes, strides, attrs);
}
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }
/// Verify that a particular offset/size/stride static attribute is well-formed.
static LogicalResult
verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
ValueRange values) {
template <typename OpType>
static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
/// Check static and dynamic offsets/sizes/strides breakdown.
size_t inputRank = op.source().getType().cast<MemRefType>().getRank();
if (attr.size() != inputRank)
return op.emitError("expected ") << inputRank << " " << name << " values";
if (attr.size() != op.getSourceRank())
return op.emitError("expected ")
<< op.getSourceRank() << " " << name << " values";
unsigned expectedNumDynamicEntries =
llvm::count_if(attr.getValue(), [&](Attribute attr) {
return isDynamic(attr.cast<IntegerAttr>().getInt());
@ -2787,17 +2801,26 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
/// Checks if `original` MemRef type can be rank reduced to `reduced` type.
/// Checks if `original` type can be rank reduced to `reduced` type.
/// This function is a slight variant of the `is subsequence` algorithm where
/// non-matching dimensions must be 1.
static bool isRankReducedType(Type originalType, Type reducedType) {
if (originalType == reducedType)
return true;
if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
return true;
if (originalType.isa<RankedTensorType>() &&
!reducedType.isa<RankedTensorType>())
return true;
if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
return true;
MemRefType original = originalType.cast<MemRefType>();
MemRefType reduced = reducedType.cast<MemRefType>();
ArrayRef<int64_t> originalShape = original.getShape();
ArrayRef<int64_t> reducedShape = reduced.getShape();
ShapedType originalShapedType = originalType.cast<ShapedType>();
ShapedType reducedShapedType = reducedType.cast<ShapedType>();
// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
reducedRank = reducedShape.size();
if (reducedRank > originalRank)
@ -2819,6 +2842,13 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
if (reducedIdx != reducedRank)
return false;
// We are done for the tensor case.
if (originalType.isa<RankedTensorType>())
return true;
// Strided layout logic is relevant for MemRefType only.
MemRefType original = originalType.cast<MemRefType>();
MemRefType reduced = reducedType.cast<MemRefType>();
MLIRContext *c = original.getContext();
int64_t originalOffset, symCounter = 0, dimCounter = 0;
SmallVector<int64_t, 4> originalStrides;
@ -2843,10 +2873,29 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
reducedMap == reduced.getAffineMaps().front());
}
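A hedged sketch of what this check accepts, using hypothetical types (`original` is the inferred type, `reduced` the declared result type):

```
// tensor<1x16x4xf32> -> tensor<16x4xf32> : accepted, the dropped dim is 1.
// tensor<8x16x4xf32> -> tensor<16x4xf32> : rejected, the dropped dim is not 1.
// For memrefs the same shape rule applies and the strided layouts of the two
// types are additionally compared.
```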
template <typename OpType>
static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
// Verify static attributes offsets/sizes/strides.
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
ShapedType::isDynamicStrideOrOffset, op.offsets())))
return failure();
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
ShapedType::isDynamic, op.sizes())))
return failure();
if (failed(verifyOpWithOffsetSizesAndStridesPart(
op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
ShapedType::isDynamicStrideOrOffset, op.strides())))
return failure();
return success();
}
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
auto baseType = op.getBaseMemRefType().cast<MemRefType>();
auto subViewType = op.getType();
MemRefType baseType = op.getSourceMemRefType();
MemRefType subViewType = op.getType();
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@ -2858,24 +2907,12 @@ static LogicalResult verify(SubViewOp op) {
if (!isStrided(baseType))
return op.emitError("base type ") << baseType << " is not strided";
// Verify static attributes offsets/sizes/strides.
if (failed(verifySubViewOpPart(
op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
ShapedType::isDynamicStrideOrOffset, op.offsets())))
return failure();
if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
op.static_sizes(), ShapedType::isDynamic,
op.sizes())))
return failure();
if (failed(verifySubViewOpPart(
op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
ShapedType::isDynamicStrideOrOffset, op.strides())))
if (failed(verifyOpWithOffsetSizesAndStrides(op)))
return failure();
// Verify result type against inferred type.
auto expectedType = SubViewOp::inferSubViewResultType(
op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
auto expectedType = SubViewOp::inferResultType(
baseType, extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
if (!isRankReducedType(expectedType, subViewType))
@ -2885,123 +2922,41 @@ static LogicalResult verify(SubViewOp op) {
return success();
}
raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
return os << "range " << range.offset << ":" << range.size << ":"
<< range.stride;
}
static unsigned getNumDynamicEntriesUpToIdx(
ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
[&](Attribute attr) {
return isDynamic(attr.cast<IntegerAttr>().getInt());
});
}
bool SubViewOp::isDynamicOffset(unsigned idx) {
return ShapedType::isDynamicStrideOrOffset(
extractFromI64ArrayAttr(static_offsets())[idx]);
}
bool SubViewOp::isDynamicSize(unsigned idx) {
return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
}
bool SubViewOp::isDynamicStride(unsigned idx) {
return ShapedType::isDynamicStrideOrOffset(
extractFromI64ArrayAttr(static_strides())[idx]);
}
unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
assert(isDynamicOffset(idx) && "expected static offset");
auto numDynamic =
getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
ShapedType::isDynamicStrideOrOffset, idx);
return 1 + numDynamic;
}
unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
assert(isDynamicSize(idx) && "expected static size");
auto numDynamic = getNumDynamicEntriesUpToIdx(
static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
return 1 + offsets().size() + numDynamic;
}
unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
assert(isDynamicStride(idx) && "expected static stride");
auto numDynamic =
getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
ShapedType::isDynamicStrideOrOffset, idx);
return 1 + offsets().size() + sizes().size() + numDynamic;
}
/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
Location loc) {
template <typename OpType>
static SmallVector<Range, 8> getOrCreateRangesImpl(OpType op, OpBuilder &b,
Location loc) {
SmallVector<Range, 8> res;
unsigned rank = getType().getRank();
unsigned rank = op.getSourceRank();
res.reserve(rank);
for (unsigned idx = 0; idx < rank; ++idx) {
auto offset = isDynamicOffset(idx)
? getDynamicOffset(idx)
: b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
auto size = isDynamicSize(idx)
? getDynamicSize(idx)
: b.create<ConstantIndexOp>(loc, getStaticSize(idx));
auto stride = isDynamicStride(idx)
? getDynamicStride(idx)
: b.create<ConstantIndexOp>(loc, getStaticStride(idx));
Value offset =
op.isDynamicOffset(idx)
? op.getDynamicOffset(idx)
: b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
Value size = op.isDynamicSize(idx)
? op.getDynamicSize(idx)
: b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
Value stride =
op.isDynamicStride(idx)
? op.getDynamicStride(idx)
: b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
res.emplace_back(Range{offset, size, stride});
}
return res;
}
SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
Location loc) {
unsigned dynamicIdx = 1;
return llvm::to_vector<4>(llvm::map_range(
static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticOffset = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamicStrideOrOffset(staticOffset))
return getOperand(dynamicIdx++);
else
return b.create<ConstantIndexOp>(loc, staticOffset);
}));
SmallVector<Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) {
return ::getOrCreateRangesImpl(*this, b, loc);
}
SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
unsigned dynamicIdx = 1 + offsets().size();
return llvm::to_vector<4>(llvm::map_range(
static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticSize = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamic(staticSize))
return getOperand(dynamicIdx++);
else
return b.create<ConstantIndexOp>(loc, staticSize);
}));
}
SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
Location loc) {
unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
return llvm::to_vector<4>(llvm::map_range(
static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
int64_t staticStride = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamicStrideOrOffset(staticStride))
return getOperand(dynamicIdx++);
else
return b.create<ConstantIndexOp>(loc, staticStride);
}));
}
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
if (!strides().empty())
return failure();
staticStrides = extractFromI64ArrayAttr(static_strides());
return success();
}
Value SubViewOp::getViewSource() { return source(); }
namespace {
/// Take a list of `values` with potential new constant to extract and a list
@ -3053,20 +3008,20 @@ public:
SmallVector<Value, 8> newOffsets(subViewOp.offsets());
SmallVector<int64_t, 8> newStaticOffsets =
extractFromI64ArrayAttr(subViewOp.static_offsets());
assert(newStaticOffsets.size() == subViewOp.getRank());
assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newOffsets, newStaticOffsets,
ShapedType::isDynamicStrideOrOffset);
SmallVector<Value, 8> newSizes(subViewOp.sizes());
SmallVector<int64_t, 8> newStaticSizes =
extractFromI64ArrayAttr(subViewOp.static_sizes());
assert(newStaticOffsets.size() == subViewOp.getRank());
assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
SmallVector<Value, 8> newStrides(subViewOp.strides());
SmallVector<int64_t, 8> newStaticStrides =
extractFromI64ArrayAttr(subViewOp.static_strides());
assert(newStaticOffsets.size() == subViewOp.getRank());
assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newStrides, newStaticStrides,
ShapedType::isDynamicStrideOrOffset);
@ -3210,7 +3165,7 @@ public:
/// Deduce the resultType of the SubViewOp using `inferResultType` on
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
Type resultType = SubViewOp::inferSubViewResultType(
Type resultType = SubViewOp::inferResultType(
castOp.source().getType().cast<MemRefType>(),
extractFromI64ArrayAttr(subViewOp.static_offsets()),
extractFromI64ArrayAttr(subViewOp.static_sizes()),
@ -3232,6 +3187,94 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
//===----------------------------------------------------------------------===//
// SubTensorOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, SubTensorOp op) {
return printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
}
static ParseResult parseSubTensorOp(OpAsmParser &parser,
OperationState &result) {
return parseOpWithOffsetsSizesAndStrides<SubTensorOp>(parser, result);
}
/// A subtensor result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceRankedTensorType.getRank();
(void)rank;
assert(staticOffsets.size() == rank &&
"unexpected staticOffsets size mismatch");
assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
assert(staticStrides.size() == rank &&
"unexpected staticStrides size mismatch");
return RankedTensorType::get(staticSizes,
sourceRankedTensorType.getElementType());
}
void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
Value source, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides,
ValueRange offsets, ValueRange sizes,
ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
auto resultType = inferResultType(sourceRankedTensorType, staticOffsets,
staticSizes, staticStrides);
build(b, result, resultType, source, offsets, sizes, strides,
b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
b.getI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
/// Build a SubTensorOp with all dynamic entries: `staticOffsets`, `staticSizes`
/// and `staticStrides` are automatically filled with sentinel values that
/// encode dynamic entries.
void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
Value source, ValueRange offsets,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
unsigned rank = sourceRankedTensorType.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamicStrideOrOffset);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamicStrideOrOffset);
build(b, result, source, staticOffsetsVector, staticSizesVector,
staticStridesVector, offsets, sizes, strides, attrs);
}
SmallVector<Range, 8> SubTensorOp::getOrCreateRanges(OpBuilder &b,
Location loc) {
return ::getOrCreateRangesImpl(*this, b, loc);
}
/// Verifier for SubTensorOp.
static LogicalResult verify(SubTensorOp op) {
if (failed(verifyOpWithOffsetSizesAndStrides(op)))
return failure();
// Verify result type against inferred type.
auto expectedType = SubTensorOp::inferResultType(
op.getSourceRankedTensorType(),
extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
if (!isRankReducedType(expectedType, op.getType()))
return op.emitError("expected result type to be ")
<< expectedType << " or a rank-reduced version.";
return success();
}
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//

View File

@ -900,3 +900,27 @@ func @assume_alignment(%0: memref<4x4xf16>) {
assume_alignment %0, 16 : memref<4x4xf16>
return
}
// CHECK-LABEL: func @subtensor({{.*}}) {
func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
// CHECK: subtensor
// CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
%1 = subtensor %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>
// CHECK: subtensor
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
%2 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>
// CHECK: subtensor
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
%3 = subtensor %t[0, 2, 0][4, 1, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4xf32>
return
}

View File

@ -1255,3 +1255,23 @@ func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
std.re %cplx : complex<f32>
return
}
// -----
func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}}
%0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<?x4x4xf32>
return
}
// -----
func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>'}}
%0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>
return
}

View File

@ -301,8 +301,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc,
ArrayRef<SubViewOp::Range> parallelLoopRanges) {
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
Type indexType = b.getIndexType();
SmallVector<ProcInfo, 2> procInfo(2);
procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),