forked from OSchip/llvm-project
[VectorOps] Add legality rules to broadcast
PiperOrigin-RevId: 283360101
This commit is contained in:
parent
b41162b3af
commit
3126004a5a
|
@ -171,7 +171,24 @@ def Vector_BroadcastOp :
|
|||
let summary = "broadcast operation";
|
||||
let description = [{
|
||||
Broadcasts the scalar or k-D vector value in the source operand
|
||||
to a n-D result vector such that the broadcast makes sense.
|
||||
to a n-D result vector such that the broadcast makes sense, i.e.,
|
||||
the source operand is duplicated to match the given rank and sizes
|
||||
in the result vector. The legality rules are:
|
||||
* the source operand must have the same element type as the result type
|
||||
* a k-D vector <s_1 x .. x s_k x type> can be broadcast to
|
||||
a n-D vector <t_1 x .. x t_n x type> if
|
||||
* k <= n, and
|
||||
* the sizes in the trailing dimensions n-k < i <= n with j=i+k-n
|
||||
match exactly as s_j = t_i or s_j = 1:
|
||||
```
|
||||
t_1 x .. t_n-k x t_n-k+1 x .. x t_i x .. x t_n
|
||||
s_1 x .. x s_j x .. x s_k
|
||||
<duplication> <potential stretch>
|
||||
```
|
||||
The source operand is duplicated over all the missing leading dimensions
|
||||
and streched over the trailing dimensions where the source has a non-equal
|
||||
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
|
||||
shaped vector with the same element type is always legal.
|
||||
|
||||
Examples:
|
||||
```
|
||||
|
|
|
@ -386,10 +386,17 @@ static LogicalResult verify(BroadcastOp op) {
|
|||
if (srcVectorType) {
|
||||
const int64_t srcRank = srcVectorType.getRank();
|
||||
const int64_t dstRank = dstVectorType.getRank();
|
||||
// TODO(ajcbik): implement proper rank testing for broadcast;
|
||||
// this is just a temporary placeholder check.
|
||||
if (srcRank > dstRank) {
|
||||
if (srcRank > dstRank)
|
||||
return op.emitOpError("source rank higher than destination rank");
|
||||
// Source has an exact match or singleton value for all trailing dimensions
|
||||
// (all leading dimensions are simply duplicated).
|
||||
const int64_t lead = dstRank - srcRank;
|
||||
for (int64_t i = 0; i < srcRank; i++) {
|
||||
const int64_t srcDim = srcVectorType.getDimSize(i);
|
||||
const int64_t dstDim = dstVectorType.getDimSize(lead + i);
|
||||
if (srcDim != 1 && srcDim != dstDim)
|
||||
return op.emitOpError("dimension mismatch (")
|
||||
<< srcDim << " vs. " << dstDim << ")";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
|
|
|
@ -9,6 +9,20 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
|
||||
// expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}}
|
||||
%1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
|
||||
// expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}}
|
||||
%1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_vector_type(%arg0: index) {
|
||||
// expected-error@+1 {{expected vector type}}
|
||||
%1 = vector.extractelement %arg0[] : index
|
||||
|
|
|
@ -23,12 +23,16 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @vector_broadcast
|
||||
func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> {
|
||||
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
|
||||
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
|
||||
%0 = vector.broadcast %a : f32 to vector<16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
|
||||
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
|
||||
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
|
||||
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
|
||||
return %3 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extractelement
|
||||
|
|
Loading…
Reference in New Issue