Introduce tensor.insert op to Tensor dialect.

Add `tensor.insert` op to make `tensor.extract`/`tensor.insert` work in pairs
for `scalar` domain. Like `subtensor`/`subtensor_insert` work in pairs in
`tensor` domain, and `vector.transfer_read`/`vector.transfer_write` work in
pairs in `vector` domain.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D104139
This commit is contained in:
Hanhan Wang 2021-06-13 13:45:33 -07:00
parent 899fdf548e
commit b4baccc2a7
5 changed files with 107 additions and 0 deletions

View File

@ -183,6 +183,57 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
def Tensor_InsertOp : Tensor_Op<"insert",
[NoSideEffect,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self.cast<ShapedType>()">,
TypesMatchWith<"scalar type matches element type of dest",
"dest", "scalar",
"$_self.cast<ShapedType>().getElementType()">]> {
let summary = "element insertion operation";
let description = [{
The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
the operation's indices.
It returns a copy of `dest` with the proper subtensor updated with the value
of `scalar`.
The arity of indices must match the rank of the tensor `dest` (i.e., if a
tensor is of rank 3, then 3 indices are required for the extract. The
indices should all be of `index` type.
Example:
```mlir
%4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
%5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
%6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
```
}];
let arguments = (ins AnyType:$scalar,
AnyTensor:$dest,
Variadic<Index>:$indices);
let results = (outs AnyTensor:$result);
let assemblyFormat = [{
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
}];
let builders = [
OpBuilder<(ins "Value":$scalar, "Value":$dest,
CArg<"ValueRange", "{}">:$indices), [{
auto resType = dest.getType();
build($_builder, $_state, resType, scalar, dest, indices);
}]>];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

View File

@ -286,6 +286,28 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractElementFromTensorFromElements>(context);
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(InsertOp op) {
// Verify the # indices match if we have a ranked type.
if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
return op.emitOpError("incorrect number of indices");
return success();
}
OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
Attribute scalar = operands[0];
Attribute dest = operands[1];
if (scalar && dest)
if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
if (scalar == splatDest.getSplatValue())
return dest;
return {};
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

View File

@ -96,6 +96,19 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
// -----
// CHECK-LABEL: func @fold_insert
func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
// Fold an insert into a splat.
// CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
%0 = constant dense<4.0> : tensor<4xf32>
%1 = constant 4.0 : f32
%ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
// CHECK-NEXT: return %[[C4]]
return %ins_1 : tensor<4xf32>
}
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {

View File

@ -16,6 +16,14 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// -----
func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices}}
%0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
return
}
// -----
func @tensor.from_elements_wrong_result_type() {
// expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = constant 0 : i32

View File

@ -22,6 +22,19 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
return
}
// CHECK-LABEL: func @insert(
// CHECK-SAME: %[[SCALAR:.*]]: f32
// CHECK-SAME: %[[INDEX:.*]]: index
// CHECK-SAME: %[[DEST1:.*]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32>
func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
// CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
%0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
// CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
%1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
return
}
// CHECK-LABEL: func @tensor.from_elements() {
func @tensor.from_elements() {
%c0 = "std.constant"() {value = 0: index} : () -> index