forked from OSchip/llvm-project
Introduce tensor.insert op to Tensor dialect.
Add `tensor.insert` op to make `tensor.extract`/`tensor.insert` work in pairs for `scalar` domain. Like `subtensor`/`subtensor_insert` work in pairs in `tensor` domain, and `vector.transfer_read`/`vector.transfer_write` work in pairs in `vector` domain. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D104139
This commit is contained in:
parent
899fdf548e
commit
b4baccc2a7
|
@ -183,6 +183,57 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_InsertOp : Tensor_Op<"insert",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"result type matches type of dest",
|
||||
"dest", "result",
|
||||
"$_self.cast<ShapedType>()">,
|
||||
TypesMatchWith<"scalar type matches element type of dest",
|
||||
"dest", "scalar",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
let summary = "element insertion operation";
|
||||
let description = [{
|
||||
The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
|
||||
the operation's indices.
|
||||
|
||||
It returns a copy of `dest` with the proper subtensor updated with the value
|
||||
of `scalar`.
|
||||
|
||||
The arity of indices must match the rank of the tensor `dest` (i.e., if a
|
||||
tensor is of rank 3, then 3 indices are required for the extract. The
|
||||
indices should all be of `index` type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
|
||||
%5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
|
||||
%6 = tensor.insert %ut into %dest[%1, %2] : tensor<*xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyType:$scalar,
|
||||
AnyTensor:$dest,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs AnyTensor:$result);
|
||||
let assemblyFormat = [{
|
||||
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$scalar, "Value":$dest,
|
||||
CArg<"ValueRange", "{}">:$indices), [{
|
||||
auto resType = dest.getType();
|
||||
build($_builder, $_state, resType, scalar, dest, indices);
|
||||
}]>];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -286,6 +286,28 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
results.add<ExtractElementFromTensorFromElements>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(InsertOp op) {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
|
||||
if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
|
||||
return op.emitOpError("incorrect number of indices");
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
|
||||
Attribute scalar = operands[0];
|
||||
Attribute dest = operands[1];
|
||||
if (scalar && dest)
|
||||
if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
|
||||
if (scalar == splatDest.getSplatValue())
|
||||
return dest;
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenerateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -96,6 +96,19 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_insert
|
||||
func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
|
||||
// Fold an insert into a splat.
|
||||
// CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
|
||||
%0 = constant dense<4.0> : tensor<4xf32>
|
||||
%1 = constant 4.0 : f32
|
||||
%ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
|
||||
// CHECK-NEXT: return %[[C4]]
|
||||
return %ins_1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_from_tensor.cast
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
|
||||
func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
|
||||
|
|
|
@ -16,6 +16,14 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
|
||||
// expected-error@+1 {{incorrect number of indices}}
|
||||
%0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor.from_elements_wrong_result_type() {
|
||||
// expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
|
||||
%c0 = constant 0 : i32
|
||||
|
|
|
@ -22,6 +22,19 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @insert(
|
||||
// CHECK-SAME: %[[SCALAR:.*]]: f32
|
||||
// CHECK-SAME: %[[INDEX:.*]]: index
|
||||
// CHECK-SAME: %[[DEST1:.*]]: tensor<?x?x?xf32>
|
||||
// CHECK-SAME: %[[DEST2:.*]]: tensor<*xf32>
|
||||
func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
|
||||
// CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
|
||||
%0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
|
||||
// CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
|
||||
%1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor.from_elements() {
|
||||
func @tensor.from_elements() {
|
||||
%c0 = "std.constant"() {value = 0: index} : () -> index
|
||||
|
|
Loading…
Reference in New Issue