forked from OSchip/llvm-project
[mlir][tensor] Add gather/scatter op definitions to the tensor dialect.
Gather/Scatter are examined from first principles in light of our recent progress on tensor-based codegen and in-place bufferization. In the future, lowering of these abstractions to operate **inplace** on buffers will likely require a more powerful buffer representation than strided memref. General context: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707 Relevant TL;DR parts of the proposal: - gather: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-gatherop-and-friends-10 - need for more expressive types: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-bufferization-copy-view-and-the-need-for-more-expressive-types-12 - jagged buffer discussion: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-first-class-jagged-buffer-13 Differential Revision: https://reviews.llvm.org/D130348
This commit is contained in:
parent
5de4d97a00
commit
d2613d5bb5
|
@ -232,7 +232,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
|||
|
||||
Example:
|
||||
|
||||
```
|
||||
```mlir
|
||||
// Rank-reducing extract_slice.
|
||||
%1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
|
||||
tensor<8x16x4xf32> to tensor<16x4xf32>
|
||||
|
@ -372,8 +372,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
|
|||
"$_self.cast<ShapedType>().getNumElements(), "
|
||||
"$_self.cast<ShapedType>().getElementType())">
|
||||
]> {
|
||||
string summary = "tensor from elements operation.";
|
||||
string description = [{
|
||||
let summary = "tensor from elements operation.";
|
||||
let description = [{
|
||||
Create a N-D tensor from a range of same-type arguments. The number of
|
||||
provided `elements` should equal to the number of the elements in the
|
||||
result type. The `elements` correspond to a flattened tensor.
|
||||
|
@ -406,6 +406,144 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_GatherOp : Tensor_Op<"gather", [
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary = "gather a subset of a tensor at specified indices";
|
||||
let description = [{
|
||||
The `gather` operation extracts a subset of the elements from a `source`
|
||||
tensor at the given indices.
|
||||
|
||||
In its most general form, the tensor of indices specifies all the coordinates
|
||||
of every element to extract (i.e. COO format, without the payload).
|
||||
The indices are expected to be confined to coordinate values that fit the
|
||||
range of the `source` tensor, otherwise the behavior is undefined.
|
||||
|
||||
The leading dimensions of the index tensor give the result tensor its leading
|
||||
dimensions. The trailing dimensions of the result tensor are obtained from
|
||||
the source tensor by omitting the dimensions specified in `gather_dims`
|
||||
(rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
|
||||
(see examples).
|
||||
The trailing dimension of the index tensor contains the coordinates and is
|
||||
expected to have its size equal to the number of dimensions being gathered.
|
||||
This convention allows an idiomatic specification and lowering of "gathering
|
||||
multiple N-D slices from the source tensor".
|
||||
|
||||
Note: in the examples below, we separate out the indexing part of the tensor
|
||||
type by a whitespace for readability purposes.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// For each 1x2 triple of coordinates in %indices, extract the
|
||||
// element (i.e. 0-D subset) at the coordinates triple in %source.
|
||||
//
|
||||
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
|
||||
(tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
|
||||
|
||||
// Note: result type may be further rank-reduced to tensor<1x2x f32>.
|
||||
```
|
||||
|
||||
A slice variant is provided to allow specifying whole slices of the source
|
||||
tensor.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// For each 5x6 singleton of coordinates in %indices, extract the 2-D
|
||||
// slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
|
||||
// corresponding to the `gather_dims` attribute specified by %indices.
|
||||
//
|
||||
%out = tensor.gather %source[%indices] gather_dims([1]) :
|
||||
(tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
|
||||
|
||||
// Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
|
||||
```
|
||||
|
||||
The dimensions specified in the gather_dims attribute are ones for which the
|
||||
result tensor has size `1`.
|
||||
I.e. if the source type is `axbxcxd` and the coordinates are [1, 3], then
|
||||
the shape suffix is `ax1xcx1`.
|
||||
Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be
|
||||
further simplified to `axc`.
|
||||
|
||||
The elemental type of the indices tensor can be any integer type.
|
||||
In the absence of target-specific or problem specific information the default
|
||||
type one should use is `index`.
|
||||
|
||||
This operation does not support unranked tensors.
|
||||
|
||||
An optional `unique` unit attribute may be specified to indicate that the
|
||||
coordinates in `indices` are statically guaranteed to be unique at runtime.
|
||||
Incorrectly setting the `unique` attribute when the coordinates are not truly
|
||||
unique is undefined behavior.
|
||||
|
||||
Only full slices are meant to be supported by this op, if one desires
|
||||
partial slices (e.g. strided windows) one should compose this op with other
|
||||
tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of
|
||||
complexity that would make the op unusable in practice.
|
||||
|
||||
At the tensor-level, the index tensor is specified in an AoS form (i.e.
|
||||
coordinate tuple is the most minor). It is the responsibility of further
|
||||
lowerings and bufferization to implement various concrete layouts.
|
||||
|
||||
Note: As currently specified, the operation must lower to an abstraction that
|
||||
performs copies to the output tensor. This is because the buffer type system
|
||||
is currently not rich enough to allow multiple non-contiguous views in the
|
||||
same type. This is visible more clearly in a notional buffer version of the
|
||||
op:
|
||||
|
||||
```mlir
|
||||
// memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
|
||||
// gather from random source slices must copy to the contiguous output.
|
||||
%out = memref.gather %source[%indices] gather_dims([1]) :
|
||||
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
|
||||
|
||||
// Nested buffer support would allow gather to directly index into the
|
||||
// source buffer (i.e. represent a jagged view into the source).
|
||||
%out = memref.gather %source[%indices] gather_dims([1]) :
|
||||
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyRankedTensor:$source,
|
||||
RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
|
||||
DenseI64ArrayAttr:$gather_dims,
|
||||
UnitAttr:$unique);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `[` $indices `]`
|
||||
`gather_dims` `(` $gather_dims `)`
|
||||
(`unique` $unique^)?
|
||||
attr-dict
|
||||
`:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TODO: InferTypeOpInterface once enough confidence is built with
|
||||
// tensor<tensor> and its lowering to memref<memref>.
|
||||
static RankedTensorType inferResultType(RankedTensorType sourceType,
|
||||
RankedTensorType indicesType,
|
||||
ArrayRef<int64_t> gatherDims,
|
||||
bool rankReduced);
|
||||
RankedTensorType getIndicesType() {
|
||||
return getIndices().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getSourceType() {
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getResultType() {
|
||||
return getResult().getType().cast<RankedTensorType>();
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenerateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -414,8 +552,8 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
|
|||
[RecursiveSideEffects,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
|
||||
string summary = "Creates a dynamically sized tensor from elements";
|
||||
string description = [{
|
||||
let summary = "Creates a dynamically sized tensor from elements";
|
||||
let description = [{
|
||||
This operation creates a dynamically sized tensor with elements of any type.
|
||||
It expects one index operand per dynamic extent of the result tensor.
|
||||
|
||||
|
@ -560,7 +698,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
|||
|
||||
Example:
|
||||
|
||||
```
|
||||
```mlir
|
||||
// Rank-altering insert_slice.
|
||||
%1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
|
||||
tensor<16x4xf32> into tensor<8x16x4xf32>
|
||||
|
@ -1210,6 +1348,147 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScatterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ScatterOp : Tensor_Op<"scatter", [
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary =
|
||||
"scatter a tensor into a destination tensor at specified indices";
|
||||
let description = [{
|
||||
The `scatter` operation inserts a `source` tensor into a `dest` tensor at
|
||||
the given indices.
|
||||
|
||||
In its most general form, the tensor of indices specifies all the coordinates
|
||||
of every element to insert (i.e. COO format, without the payload).
|
||||
The indices are expected to be confined to coordinate values that fit the
|
||||
range of the `dest` tensor, otherwise the behavior is undefined.
|
||||
|
||||
The leading dimensions of the index tensor must match that of the dest
|
||||
tensor. The trailing dimensions of the dest tensor must match those of the
|
||||
source tensor by omitting the dimensions specified in scatter_dims
|
||||
(rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
|
||||
(see examples).
|
||||
This convention allows an idiomatic specification and lowering of
|
||||
"scattering multiple N-D slices into the dest tensor".
|
||||
The result type must match the type of the dest tensor.
|
||||
|
||||
Note: in the examples below, we separate out the indexing part of the tensor
|
||||
type by a whitespace for readability purposes.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// For each 1x2 triple of coordinates in %indices, insert the
|
||||
// element (i.e. 0-D subset) at the coordinates triple in %dest.
|
||||
//
|
||||
%out = tensor.scatter %source into %dest[%indices]
|
||||
scatter_dims([0, 1, 2]) unique :
|
||||
(tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
|
||||
-> tensor<4x4x4xf32>
|
||||
|
||||
// Note: source type may be further rank-reduced to tensor<1x2x f32>.
|
||||
```
|
||||
|
||||
A slice variant is provided to allow specifying insertion of whole tensor
|
||||
slices into the `dest` tensor.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// For each 3 singleton of coordinates in %indices, insert the 2-D
|
||||
// slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the
|
||||
// indices corresponding to the scatter_dims attribute specified by
|
||||
// %indices.
|
||||
//
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
|
||||
(tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
|
||||
-> tensor<4x5x6xf32>
|
||||
```
|
||||
|
||||
The dimensions specified in the scatter_dims attribute are ones for which the
|
||||
source tensor has size `1`.
|
||||
I.e. if the dest type is `axbxcxd` and the coordinates are [1, 3], then
|
||||
the source type suffix is `ax1xcx1`.
|
||||
Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can be
|
||||
further simplified to `axc`.
|
||||
|
||||
The elemental type of the indices tensor can be any integer type.
|
||||
In the absence of target-specific or problem specific information the default
|
||||
type one should use is `index`.
|
||||
|
||||
This operation does not support unranked tensors.
|
||||
|
||||
A `unique` unit attribute must be specified to indicate that the
|
||||
coordinates are statically guaranteed to be unique at runtime. If coordinates
|
||||
are not truly unique at runtime, the behavior is undefined.
|
||||
|
||||
Only full slices are meant to be supported by this op, if one desires
|
||||
partial slices (e.g. strided windows) one should compose this op with other
|
||||
tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of
|
||||
complexity that would make the op unusable in practice.
|
||||
|
||||
At the tensor-level, the index tensor is specified in an AoS form (i.e.
|
||||
coordinate tuple is the most minor). It is the responsibility of further
|
||||
lowerings and bufferization to implement various concrete layouts.
|
||||
|
||||
Note: As currently specified, the operation must lower to an abstraction that
|
||||
performs copies to the output tensor. This is because the buffer type system
|
||||
is currently not rich enough to allow multiple non-contiguous views in the
|
||||
same type. This is visible more clearly in a notional buffer version of the
|
||||
op:
|
||||
|
||||
```mlir
|
||||
// memref<?x 4xf32> is a contiguous buffer of ?x4 elements, scatter into
|
||||
// random dest slices must copy to the contiguous dest.
|
||||
//
|
||||
some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
|
||||
memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
|
||||
(memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
|
||||
|
||||
// Nested buffer support in the producing op would allow writing directly
|
||||
// into the dest buffer.
|
||||
%v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
|
||||
memref<? x memref<4xf32>>
|
||||
some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyRankedTensor:$source,
|
||||
AnyRankedTensor:$dest,
|
||||
RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
|
||||
DenseI64ArrayAttr:$scatter_dims,
|
||||
UnitAttr:$unique);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `into` $dest `[` $indices `]`
|
||||
`scatter_dims` `(` $scatter_dims `)`
|
||||
(`unique` $unique^)?
|
||||
attr-dict
|
||||
`:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
RankedTensorType getDestType() {
|
||||
return getDest().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getIndicesType() {
|
||||
return getIndices().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getSourceType() {
|
||||
return getSource().getType().cast<RankedTensorType>();
|
||||
}
|
||||
RankedTensorType getResultType() {
|
||||
return getResult().getType().cast<RankedTensorType>();
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tensor;
|
||||
|
@ -543,6 +546,89 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
results.add<ExtractElementFromIndexCast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the inferred result type for a gatherOp where:
|
||||
/// - sourceType is the type of the source tensor gathered from
|
||||
/// - indicesType is the type of the indices used to gather
|
||||
/// - gatherDims are the dims along which the gather occurs.
|
||||
/// Return a full rank or ranked-reduced variant of the type depending on
|
||||
/// the value of rankReduced.
|
||||
///
|
||||
/// The leading dimensions of the index tensor give the result tensor its
|
||||
/// leading dimensions.
|
||||
/// The trailing dimensions of the result tensor are obtained from the source
|
||||
/// tensor by setting the dimensions specified in gather_dims to `1` (if
|
||||
/// rankedReduced is false), or skipping them (otherwise).
|
||||
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
|
||||
RankedTensorType indicesType,
|
||||
ArrayRef<int64_t> gatherDims,
|
||||
bool rankReduced) {
|
||||
SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
|
||||
resultShape.reserve(resultShape.size() + sourceType.getRank());
|
||||
for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
|
||||
if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
|
||||
if (!rankReduced)
|
||||
resultShape.push_back(1);
|
||||
continue;
|
||||
}
|
||||
resultShape.push_back(sourceType.getDimSize(idx));
|
||||
}
|
||||
return RankedTensorType::Builder(sourceType).setShape(resultShape);
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
|
||||
StringRef gatherOrScatter, StringRef sourceOrDest) {
|
||||
if (dims.empty())
|
||||
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
|
||||
|
||||
int64_t numGatherDims = dims.size();
|
||||
if (numGatherDims > rank)
|
||||
return op->emitOpError(gatherOrScatter)
|
||||
<< "_dims overflow " << sourceOrDest << " rank";
|
||||
for (int64_t val : dims) {
|
||||
if (val < 0)
|
||||
return op->emitOpError(gatherOrScatter)
|
||||
<< "_dims value must be non-negative";
|
||||
if (val >= rank)
|
||||
return op->emitOpError(gatherOrScatter)
|
||||
<< "_dims value must be smaller than " << sourceOrDest << " rank";
|
||||
}
|
||||
for (int64_t i = 1; i < numGatherDims; ++i) {
|
||||
if (dims[i - 1] >= dims[i])
|
||||
return op->emitOpError(gatherOrScatter)
|
||||
<< "_dims values must be strictly increasing";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GatherOp::verify() {
|
||||
int64_t sourceRank = getSourceType().getRank();
|
||||
ArrayRef<int64_t> gatherDims = getGatherDims();
|
||||
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
|
||||
"gather", "source")))
|
||||
return failure();
|
||||
|
||||
RankedTensorType expectedResultType = GatherOp::inferResultType(
|
||||
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
|
||||
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
|
||||
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
|
||||
if (getResultType() != expectedResultType &&
|
||||
getResultType() != expectedRankReducedResultType) {
|
||||
return emitOpError("result type "
|
||||
"mismatch: "
|
||||
"expected ")
|
||||
<< expectedResultType << " or its rank-reduced variant "
|
||||
<< expectedRankReducedResultType << " (got: " << getResultType()
|
||||
<< ")";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2306,6 +2392,42 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
|
|||
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScatterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ScatterOp::verify() {
|
||||
int64_t destRank = getDestType().getRank();
|
||||
ArrayRef<int64_t> scatterDims = getScatterDims();
|
||||
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
|
||||
"scatter", "dest")))
|
||||
return failure();
|
||||
|
||||
if (!getUnique())
|
||||
return emitOpError("requires 'unique' attribute to be set");
|
||||
// TODO: we could also check statically that there are fewer leading index
|
||||
// tensor dims than the dest dims. If this is not the case, the unique
|
||||
// attribute cannot be true.
|
||||
|
||||
// Use the GatherOp::inferResultType on the `dest` type and verify the
|
||||
// expected type matches the source type.
|
||||
RankedTensorType expectedSourceType = GatherOp::inferResultType(
|
||||
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
|
||||
RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
|
||||
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
|
||||
if (getSourceType() != expectedSourceType &&
|
||||
getSourceType() != expectedRankReducedSourceType) {
|
||||
return emitOpError("source type "
|
||||
"mismatch: "
|
||||
"expected ")
|
||||
<< expectedSourceType << " or its rank-reduced variant "
|
||||
<< expectedRankReducedSourceType << " (got: " << getSourceType()
|
||||
<< ")";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -377,3 +377,140 @@ func.func @invalid_splat(%v : vector<8xf32>) {
|
|||
%w = tensor.splat %v : tensor<8xvector<8xf32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_empty_dims(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{gather_dims must be non-empty}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_coordinate_rank_overflow(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{gather_dims overflow source rank}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2, 3]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_coordinate_negative(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{gather_dims value must be non-negative}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([-1]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_coordinate_overflow(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{gather_dims value must be smaller than source rank}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([42]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_coordinate_overflow(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{gather_dims values must be strictly increasing}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([1, 0]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_wrong_result_type(
|
||||
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
|
||||
%out = tensor.gather %source[%indices] gather_dims([0, 2]):
|
||||
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_empty_dims(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{scatter_dims must be non-empty}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_coordinate_rank_overflow(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{scatter_dims overflow dest rank}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2, 3]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_coordinate_negative(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{scatter_dims value must be non-negative}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_coordinate_overflow(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_coordinate_overflow(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{scatter_dims values must be strictly increasing}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_missing_unique(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{requires 'unique' attribute to be set}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_wrong_result_type(
|
||||
%source : tensor<f32>,
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
|
||||
// expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
|
||||
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
|
||||
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -260,3 +260,22 @@ func.func @test_splat_op(%s : f32) {
|
|||
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @gather_scatter
|
||||
func.func @gather_scatter(
|
||||
%dest : tensor<4x5x6xf32>, %indices: tensor<1x3x2xindex>, %indices_i32: tensor<1x3x2xi32>) {
|
||||
%gathered = tensor.gather %dest[%indices_i32] gather_dims([1, 2]) unique:
|
||||
(tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4x1x1xf32>
|
||||
%rank_reduced_gathered = tensor.gather %dest[%indices] gather_dims([1, 2]) unique:
|
||||
(tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<1x3x4xf32>
|
||||
|
||||
%scattered = tensor.scatter %gathered into %dest[%indices]
|
||||
scatter_dims([1, 2]) unique:
|
||||
(tensor<1x3x4x1x1xf32>, tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<4x5x6xf32>
|
||||
%rank_reduced_scattered = tensor.scatter %rank_reduced_gathered into %dest[%indices_i32]
|
||||
scatter_dims([1, 2]) unique:
|
||||
(tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue