[mlir][tensor] Add gather/scatter op definitions to the tensor dialect.

Gather/Scatter are examined from first principles in light of our recent progress on tensor-based codegen
and in-place bufferization.

In the future, lowering of these abstractions to operate **inplace** on buffers
will likely require a more powerful buffer representation than strided memref.

General context: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707
Relevant TL;DR parts of the proposal:
- gather: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-gatherop-and-friends-10
- need for more expressive types: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-bufferization-copy-view-and-the-need-for-more-expressive-types-12
- jagged buffer discussion: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-first-class-jagged-buffer-13
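
A minimal round-trip through the new ops (hypothetical shapes, for illustration
only; full semantics are in the op documentation below):
```mlir
// Hypothetical shapes, for illustration.
%g = tensor.gather %t[%idx] gather_dims([1]) :
    (tensor<4x5xf32>, tensor<3x 1xindex>) -> tensor<3x 4x1xf32>
%s = tensor.scatter %g into %t[%idx] scatter_dims([1]) unique :
    (tensor<3x 4x1xf32>, tensor<4x5xf32>, tensor<3x 1xindex>) -> tensor<4x5xf32>
```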

Differential Revision: https://reviews.llvm.org/D130348
Nicolas Vasilache 2022-07-21 04:38:46 -07:00
parent 5de4d97a00
commit d2613d5bb5
4 changed files with 563 additions and 6 deletions


@@ -232,7 +232,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
Example:
```
```mlir
// Rank-reducing extract_slice.
%1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
tensor<8x16x4xf32> to tensor<16x4xf32>
@@ -372,8 +372,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
"$_self.cast<ShapedType>().getNumElements(), "
"$_self.cast<ShapedType>().getElementType())">
]> {
string summary = "tensor from elements operation.";
string description = [{
let summary = "tensor from elements operation.";
let description = [{
Create an N-D tensor from a range of same-type arguments. The number of
provided `elements` should equal the number of elements in the result type.
The `elements` correspond to a flattened tensor.
@@ -406,6 +406,144 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
def Tensor_GatherOp : Tensor_Op<"gather", [
NoSideEffect
]> {
let summary = "gather a subset of a tensor at specified indices";
let description = [{
The `gather` operation extracts a subset of the elements from a `source`
tensor at the given indices.
In its most general form, the tensor of indices specifies all the coordinates
of every element to extract (i.e. COO format, without the payload).
The indices are expected to be confined to coordinate values within the
bounds of the `source` tensor; otherwise the behavior is undefined.
The leading dimensions of the index tensor give the result tensor its leading
dimensions. The trailing dimensions of the result tensor are obtained from
the source tensor by omitting the dimensions specified in `gather_dims`
(rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
(see examples).
The trailing dimension of the index tensor contains the coordinates and is
expected to have its size equal to the number of dimensions being gathered.
This convention allows an idiomatic specification and lowering of "gathering
multiple N-D slices from the source tensor".
Note: in the examples below, the indexing part of the tensor type is
separated out with whitespace for readability.
Example:
```mlir
// For each of the 1x2 coordinate triples in %indices, extract the
// element (i.e. 0-D subset) at that coordinate triple in %source.
//
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
(tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
// Note: result type may be further rank-reduced to tensor<1x2x f32>.
```
A slice variant is provided to allow specifying whole slices of the source
tensor.
Example:
```mlir
// For each of the 6x7 coordinate singletons in %indices, extract the 2-D
// slice %source[*, %indices[...]:%indices[...] + 1, *] along the
// dimension specified by the `gather_dims` attribute.
//
%out = tensor.gather %source[%indices] gather_dims([1]) :
(tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
// Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
```
The dimensions specified in the gather_dims attribute are ones for which the
result tensor has size `1`.
I.e. if the source type is `axbxcxd` and `gather_dims` is [1, 3], then
the shape suffix is `ax1xcx1`.
Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be
further simplified to `axc`.
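For instance, a hypothetical instantiation of this convention with source
type `2x3x4x5` and gather_dims([1, 3]):
```mlir
// Hypothetical shapes, for illustration.
// Rank-preserving: the gathered dims are kept with size 1.
%a = tensor.gather %source[%indices] gather_dims([1, 3]) :
    (tensor<2x3x4x5xf32>, tensor<6x 2xindex>) -> tensor<6x 2x1x4x1xf32>
// Rank-reducing: the same gather with the size-1 gathered dims dropped.
%b = tensor.gather %source[%indices] gather_dims([1, 3]) :
    (tensor<2x3x4x5xf32>, tensor<6x 2xindex>) -> tensor<6x 2x4xf32>
```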
The indices tensor may have any signless integer or `index` element type.
In the absence of target-specific or problem-specific information, `index`
is the default type one should use.
This operation does not support unranked tensors.
An optional `unique` unit attribute may be specified to indicate that the
coordinates in `indices` are guaranteed to be unique at runtime.
Incorrectly setting the `unique` attribute when the coordinates are not truly
unique is undefined behavior.
Only full slices are meant to be supported by this op; if one desires
partial slices (e.g. strided windows), one should compose this op with other
tensor ops (e.g. tensor.extract_slice), as sketched below. This is to avoid
a slippery slope of complexity that would make the op unusable in practice.
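For instance, a strided window could be obtained with a hypothetical
composition along these lines (shapes illustrative):
```mlir
// Hypothetical shapes, for illustration.
// Gather full rows, then extract a stride-2 window out of each row.
%rows = tensor.gather %source[%indices] gather_dims([0]) :
    (tensor<8x16xf32>, tensor<4x 1xindex>) -> tensor<4x 1x16xf32>
%window = tensor.extract_slice %rows[0, 0, 0][4, 1, 8][1, 1, 2] :
    tensor<4x1x16xf32> to tensor<4x1x8xf32>
```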
At the tensor level, the index tensor is specified in AoS form (i.e. the
coordinate tuple is the most minor dimension). It is the responsibility of
further lowerings and bufferization to implement various concrete layouts.
Note: As currently specified, the operation must lower to an abstraction that
performs copies to the output tensor. This is because the buffer type system
is currently not rich enough to allow multiple non-contiguous views in the
same type. This is more clearly visible in a notional buffer version of the
op:
```mlir
// memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
// Gathering from random source slices must copy to the contiguous output.
%out = memref.gather %source[%indices] gather_dims([1]) :
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
// Nested buffer support would allow gather to directly index into the
// source buffer (i.e. represent a jagged view into the source).
%out = memref.gather %source[%indices] gather_dims([1]) :
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
```
}];
let arguments = (ins AnyRankedTensor:$source,
RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
DenseI64ArrayAttr:$gather_dims,
UnitAttr:$unique);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source `[` $indices `]`
`gather_dims` `(` $gather_dims `)`
(`unique` $unique^)?
attr-dict
`:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
// TODO: InferTypeOpInterface once enough confidence is built with
// tensor<tensor> and its lowering to memref<memref>.
static RankedTensorType inferResultType(RankedTensorType sourceType,
RankedTensorType indicesType,
ArrayRef<int64_t> gatherDims,
bool rankReduced);
RankedTensorType getIndicesType() {
return getIndices().getType().cast<RankedTensorType>();
}
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
RankedTensorType getResultType() {
return getResult().getType().cast<RankedTensorType>();
}
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
@@ -414,8 +552,8 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
[RecursiveSideEffects,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
string summary = "Creates a dynamically sized tensor from elements";
string description = [{
let summary = "Creates a dynamically sized tensor from elements";
let description = [{
This operation creates a dynamically sized tensor with elements of any type.
It expects one index operand per dynamic extent of the result tensor.
@@ -560,7 +698,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
Example:
```
```mlir
// Rank-altering insert_slice.
%1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
tensor<16x4xf32> into tensor<8x16x4xf32>
@@ -1210,6 +1348,147 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
def Tensor_ScatterOp : Tensor_Op<"scatter", [
NoSideEffect
]> {
let summary =
"scatter a tensor into a destination tensor at specified indices";
let description = [{
The `scatter` operation inserts a `source` tensor into a `dest` tensor at
the given indices.
In its most general form, the tensor of indices specifies all the coordinates
of every element to insert (i.e. COO format, without the payload).
The indices are expected to be confined to coordinate values within the
bounds of the `dest` tensor; otherwise the behavior is undefined.
The leading dimensions of the index tensor must match the leading dimensions
of the source tensor. The trailing dimensions of the source tensor must match
those of the dest tensor, with the dimensions specified in scatter_dims
either omitted (rank-reducing semantics) or set to `1` (rank-preserving
semantics) (see examples).
This convention allows an idiomatic specification and lowering of
"scattering multiple N-D slices into the dest tensor".
The result type must match the type of the dest tensor.
Note: in the examples below, the indexing part of the tensor type is
separated out with whitespace for readability.
Example:
```mlir
// For each of the 1x2 coordinate triples in %indices, insert the
// element (i.e. 0-D subset) at that coordinate triple in %dest.
//
%out = tensor.scatter %source into %dest[%indices]
scatter_dims([0, 1, 2]) unique :
(tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
-> tensor<4x4x4xf32>
// Note: source type may be further rank-reduced to tensor<1x2x f32>.
```
A slice variant is provided to allow specifying insertion of whole tensor
slices into the `dest` tensor.
Example:
```mlir
// For each of the 3 coordinate singletons in %indices, insert the 2-D
// slice into %dest[*, %indices[...]:%indices[...] + 1, *] along the
// dimension specified by the `scatter_dims` attribute.
//
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
(tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
-> tensor<4x5x6xf32>
```
The dimensions specified in the scatter_dims attribute are ones for which the
source tensor has size `1`.
I.e. if the dest type is `axbxcxd` and `scatter_dims` is [1, 3], then
the source type suffix is `ax1xcx1`.
Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can be
further simplified to `axc`.
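For instance, mirroring the gather example above with dest type `2x3x4x5`
and scatter_dims([1, 3]) (hypothetical shapes):
```mlir
// Hypothetical shapes, for illustration.
// Rank-preserving: the source keeps size-1 dims for the scattered dims.
%r0 = tensor.scatter %src into %dest[%indices] scatter_dims([1, 3]) unique :
    (tensor<6x 2x1x4x1xf32>, tensor<2x3x4x5xf32>, tensor<6x 2xindex>)
    -> tensor<2x3x4x5xf32>
// Rank-reducing: the same scatter with the size-1 dims dropped from the source.
%r1 = tensor.scatter %src2 into %dest[%indices] scatter_dims([1, 3]) unique :
    (tensor<6x 2x4xf32>, tensor<2x3x4x5xf32>, tensor<6x 2xindex>)
    -> tensor<2x3x4x5xf32>
```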
The indices tensor may have any signless integer or `index` element type.
In the absence of target-specific or problem-specific information, `index`
is the default type one should use.
This operation does not support unranked tensors.
A `unique` unit attribute must be specified to indicate that the coordinates
are guaranteed to be unique at runtime. If the coordinates are not truly
unique at runtime, the behavior is undefined.
Only full slices are meant to be supported by this op; if one desires
partial slices (e.g. strided windows), one should compose this op with other
tensor ops (e.g. tensor.insert_slice), as sketched below. This is to avoid
a slippery slope of complexity that would make the op unusable in practice.
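For instance, a strided update could be expressed with a hypothetical
composition along these lines (shapes illustrative; %window and %full are
assumed to be produced by earlier ops):
```mlir
// Hypothetical shapes, for illustration.
// Update a strided window of a full slice, then scatter the full slice back.
%slice = tensor.insert_slice %window into %full[0, 0, 0][4, 1, 8][1, 1, 2] :
    tensor<4x1x8xf32> into tensor<4x1x16xf32>
%out = tensor.scatter %slice into %dest[%indices] scatter_dims([0]) unique :
    (tensor<4x 1x16xf32>, tensor<8x16xf32>, tensor<4x 1xindex>) -> tensor<8x16xf32>
```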
At the tensor level, the index tensor is specified in AoS form (i.e. the
coordinate tuple is the most minor dimension). It is the responsibility of
further lowerings and bufferization to implement various concrete layouts.
Note: As currently specified, the operation must lower to an abstraction that
performs copies to the output tensor. This is because the buffer type system
is currently not rich enough to allow multiple non-contiguous views in the
same type. This is more clearly visible in a notional buffer version of the
op:
```mlir
// memref<?x 4xf32> is a contiguous buffer of ?x4 elements; a scatter into
// random dest slices must copy from this contiguous buffer.
//
some_side_effecting_op_writing_into %source, ...: memref<?x 4xf32>
memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
    (memref<?x 4xf32>, memref<4x4xf32>, memref<?x 1xindex>)
// Nested buffer support in the producing op would allow writing directly
// into the dest buffer.
%v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
memref<? x memref<4xf32>>
some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
DenseI64ArrayAttr:$scatter_dims,
UnitAttr:$unique);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source `into` $dest `[` $indices `]`
`scatter_dims` `(` $scatter_dims `)`
(`unique` $unique^)?
attr-dict
`:` functional-type(operands, results)
}];
let extraClassDeclaration = [{
RankedTensorType getDestType() {
return getDest().getType().cast<RankedTensorType>();
}
RankedTensorType getIndicesType() {
return getIndices().getType().cast<RankedTensorType>();
}
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>();
}
RankedTensorType getResultType() {
return getResult().getType().cast<RankedTensorType>();
}
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//


@@ -17,8 +17,11 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include <algorithm>
using namespace mlir;
using namespace mlir::tensor;
@@ -543,6 +546,89 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractElementFromIndexCast>(context);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
/// Return the inferred result type for a GatherOp where:
///   - sourceType is the type of the source tensor gathered from
///   - indicesType is the type of the indices used to gather
///   - gatherDims are the dims along which the gather occurs.
/// Return a full-rank or rank-reduced variant of the type depending on
/// the value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions.
/// The trailing dimensions of the result tensor are obtained from the source
/// tensor by setting the dimensions specified in gather_dims to `1` (if
/// rankReduced is false), or skipping them (otherwise).
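/// For example (hypothetical shapes): sourceType = tensor<2x3x4x5xf32>,
/// indicesType = tensor<6x2xindex> and gatherDims = [1, 3] infer
/// tensor<6x2x1x4x1xf32> when rankReduced is false, and tensor<6x2x4xf32>
/// otherwise.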
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
RankedTensorType indicesType,
ArrayRef<int64_t> gatherDims,
bool rankReduced) {
SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
resultShape.reserve(resultShape.size() + sourceType.getRank());
for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
if (!rankReduced)
resultShape.push_back(1);
continue;
}
resultShape.push_back(sourceType.getDimSize(idx));
}
return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
StringRef gatherOrScatter, StringRef sourceOrDest) {
if (dims.empty())
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
int64_t numGatherDims = dims.size();
if (numGatherDims > rank)
return op->emitOpError(gatherOrScatter)
<< "_dims overflow " << sourceOrDest << " rank";
for (int64_t val : dims) {
if (val < 0)
return op->emitOpError(gatherOrScatter)
<< "_dims value must be non-negative";
if (val >= rank)
return op->emitOpError(gatherOrScatter)
<< "_dims value must be smaller than " << sourceOrDest << " rank";
}
for (int64_t i = 1; i < numGatherDims; ++i) {
if (dims[i - 1] >= dims[i])
return op->emitOpError(gatherOrScatter)
<< "_dims values must be strictly increasing";
}
return success();
}
LogicalResult GatherOp::verify() {
int64_t sourceRank = getSourceType().getRank();
ArrayRef<int64_t> gatherDims = getGatherDims();
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
"gather", "source")))
return failure();
RankedTensorType expectedResultType = GatherOp::inferResultType(
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
if (getResultType() != expectedResultType &&
getResultType() != expectedRankReducedResultType) {
return emitOpError("result type "
"mismatch: "
"expected ")
<< expectedResultType << " or its rank-reduced variant "
<< expectedRankReducedResultType << " (got: " << getResultType()
<< ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -2306,6 +2392,42 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verify() {
int64_t destRank = getDestType().getRank();
ArrayRef<int64_t> scatterDims = getScatterDims();
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
"scatter", "dest")))
return failure();
if (!getUnique())
return emitOpError("requires 'unique' attribute to be set");
// TODO: we could also check statically that there are fewer leading index
// tensor dims than the dest dims. If this is not the case, the unique
// attribute cannot be true.
// Use GatherOp::inferResultType on the `dest` type and verify that the
// expected type matches the source type.
RankedTensorType expectedSourceType = GatherOp::inferResultType(
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
if (getSourceType() != expectedSourceType &&
getSourceType() != expectedRankReducedSourceType) {
return emitOpError("source type "
"mismatch: "
"expected ")
<< expectedSourceType << " or its rank-reduced variant "
<< expectedRankReducedSourceType << " (got: " << getSourceType()
<< ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//


@@ -377,3 +377,140 @@ func.func @invalid_splat(%v : vector<8xf32>) {
%w = tensor.splat %v : tensor<8xvector<8xf32>>
return
}
// -----
func.func @gather_empty_dims(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{gather_dims must be non-empty}}
%out = tensor.gather %source[%indices] gather_dims([]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
return
}
// -----
func.func @gather_coordinate_rank_overflow(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{gather_dims overflow source rank}}
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2, 3]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
return
}
// -----
func.func @gather_coordinate_negative(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{gather_dims value must be non-negative}}
%out = tensor.gather %source[%indices] gather_dims([-1]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @gather_coordinate_overflow(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{gather_dims value must be smaller than source rank}}
%out = tensor.gather %source[%indices] gather_dims([42]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @gather_coordinate_overflow(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{gather_dims values must be strictly increasing}}
%out = tensor.gather %source[%indices] gather_dims([1, 0]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @gather_wrong_result_type(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
%out = tensor.gather %source[%indices] gather_dims([0, 2]):
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
return
}
// -----
func.func @scatter_empty_dims(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{scatter_dims must be non-empty}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
return
}
// -----
func.func @scatter_coordinate_rank_overflow(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{scatter_dims overflow dest rank}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2, 3]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
return
}
// -----
func.func @scatter_coordinate_negative(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{scatter_dims value must be non-negative}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @scatter_coordinate_overflow(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @scatter_coordinate_overflow(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{scatter_dims values must be strictly increasing}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
return
}
// -----
func.func @scatter_missing_unique(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{requires 'unique' attribute to be set}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
return
}
// -----
func.func @scatter_wrong_result_type(
%source : tensor<f32>,
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
return
}


@@ -260,3 +260,22 @@ func.func @test_splat_op(%s : f32) {
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
return
}
// -----
// CHECK-LABEL: func @gather_scatter
func.func @gather_scatter(
%dest : tensor<4x5x6xf32>, %indices: tensor<1x3x2xindex>, %indices_i32: tensor<1x3x2xi32>) {
%gathered = tensor.gather %dest[%indices_i32] gather_dims([1, 2]) unique:
(tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4x1x1xf32>
%rank_reduced_gathered = tensor.gather %dest[%indices] gather_dims([1, 2]) unique:
(tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<1x3x4xf32>
%scattered = tensor.scatter %gathered into %dest[%indices]
scatter_dims([1, 2]) unique:
(tensor<1x3x4x1x1xf32>, tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<4x5x6xf32>
%rank_reduced_scattered = tensor.scatter %rank_reduced_gathered into %dest[%indices_i32]
scatter_dims([1, 2]) unique:
(tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
return
}