forked from OSchip/llvm-project
[MLIR][Shape] Allow `shape.any` to operate on extent tensors
Differential Revision: https://reviews.llvm.org/D84433
This commit is contained in:
parent
274db1d21a
commit
7f600da828
|
@ -509,11 +509,14 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: Move the code below and witnesses to a different file.
|
||||
def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
|
||||
def Shape_AnyOp : Shape_Op<"any", [Commutative,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Return any combination of the input shapes";
|
||||
let description = [{
|
||||
This operation takes multiple input shapes and returns some combination of
|
||||
their dimensions. This can be best seen with examples below.
|
||||
This operation takes multiple input shapes or extent tensors and returns
|
||||
some combination of their dimensions. This can be best seen with examples
|
||||
below.
|
||||
|
||||
The result is undefined, but still side-effect free, in cases where the
|
||||
inputs have differing ranks or differ in extents of shared dimensions.
|
||||
|
@ -525,11 +528,10 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
|
|||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
|
||||
let results = (outs Shape_ShapeType:$result);
|
||||
|
||||
let assemblyFormat = "$inputs attr-dict";
|
||||
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
|
||||
let results = (outs Shape_ShapeOrExtentTensorType:$result);
|
||||
|
||||
let assemblyFormat = "$inputs `:` type($result) attr-dict";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -165,11 +165,12 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
|
|||
// Lower `any` to its first operand.
|
||||
// CHECK-LABEL: @any_of_three
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||
func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
|
||||
-> !shape.shape {
|
||||
func @any_of_three(%a : tensor<?xindex>,
|
||||
%b : tensor<?xindex>,
|
||||
%c : tensor<?xindex>) -> tensor<?xindex> {
|
||||
// CHECK: return %[[A]] : tensor<?xindex>
|
||||
%result = shape.any %a, %b, %c
|
||||
return %result : !shape.shape
|
||||
%result = shape.any %a, %b, %c : tensor<?xindex>
|
||||
return %result : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -177,9 +178,9 @@ func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
|
|||
// Lower `any` to its first operand.
|
||||
// CHECK-LABEL: @any_of_one
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||
func @any_of_one(%a : !shape.shape) -> !shape.shape {
|
||||
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
|
||||
// CHECK: return %[[A]] : tensor<?xindex>
|
||||
%result = shape.any %a
|
||||
return %result : !shape.shape
|
||||
%result = shape.any %a : tensor<?xindex>
|
||||
return %result : tensor<?xindex>
|
||||
}
|
||||
|
||||
|
|
|
@ -364,14 +364,25 @@ func @f() {
|
|||
|
||||
// any can be replaced with a constant input if it has one.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0 : !shape.shape) -> !shape.shape {
|
||||
func @f(%arg : !shape.shape) -> !shape.shape {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
|
||||
// CHECK-NEXT: return %[[CS]]
|
||||
%0 = shape.const_shape [2, 3, 4] : !shape.shape
|
||||
%1 = shape.any %0, %arg0
|
||||
%1 = shape.any %0, %arg : !shape.shape
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// any can be replaced with a constant input if it has one.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
|
||||
// CHECK-NEXT: return %[[CS]] : tensor<?xindex>
|
||||
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
|
||||
%1 = shape.any %0, %arg : tensor<?xindex>
|
||||
return %1 : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -380,7 +391,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
|
|||
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.any
|
||||
// CHECK-NEXT: return %[[CS]]
|
||||
%1 = shape.any %arg0, %arg1
|
||||
%1 = shape.any %arg0, %arg1 : !shape.shape
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
|
||||
// Verify the printed output can be parsed.
|
||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||
// Verify the generic form can be parsed.
|
||||
|
@ -99,7 +98,7 @@ func @test_constraints() {
|
|||
%w3 = shape.const_witness false
|
||||
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
|
||||
shape.assuming %w4 -> !shape.shape {
|
||||
%2 = shape.any %0, %1
|
||||
%2 = shape.any %0, %1 : !shape.shape
|
||||
shape.assuming_yield %2 : !shape.shape
|
||||
}
|
||||
return
|
||||
|
@ -173,3 +172,14 @@ func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
|
|||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
func @any() {
|
||||
%0 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%1 = shape.const_shape [4, 5, 6] : !shape.shape
|
||||
%2 = shape.any %0, %1 : !shape.shape
|
||||
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%5 = shape.any %3, %4 : tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue