[VectorOps] Add legality rules to broadcast

PiperOrigin-RevId: 283360101
This commit is contained in:
Aart Bik 2019-12-02 09:56:58 -08:00 committed by A. Unique TensorFlower
parent b41162b3af
commit 3126004a5a
4 changed files with 48 additions and 6 deletions

View File

@ -171,7 +171,24 @@ def Vector_BroadcastOp :
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source operand
to a n-D result vector such that the broadcast makes sense.
to a n-D result vector such that the broadcast makes sense, i.e.,
the source operand is duplicated to match the given rank and sizes
in the result vector. The legality rules are:
* the source operand must have the same element type as the result type
* a k-D vector <s_1 x .. x s_k x type> can be broadcast to
a n-D vector <t_1 x .. x t_n x type> if
* k <= n, and
* the sizes in the trailing dimensions n-k < i <= n with j=i+k-n
match exactly as s_j = t_i or s_j = 1:
```
t_1 x .. t_n-k x t_n-k+1 x .. x t_i x .. x t_n
s_1 x .. x s_j x .. x s_k
<duplication> <potential stretch>
```
The source operand is duplicated over all the missing leading dimensions
and streched over the trailing dimensions where the source has a non-equal
dimension of 1. These rules imply that any scalar broadcast (k=0) to any
shaped vector with the same element type is always legal.
Examples:
```

View File

@ -386,10 +386,17 @@ static LogicalResult verify(BroadcastOp op) {
if (srcVectorType) {
const int64_t srcRank = srcVectorType.getRank();
const int64_t dstRank = dstVectorType.getRank();
// TODO(ajcbik): implement proper rank testing for broadcast;
// this is just a temporary placeholder check.
if (srcRank > dstRank) {
if (srcRank > dstRank)
return op.emitOpError("source rank higher than destination rank");
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
const int64_t lead = dstRank - srcRank;
for (int64_t i = 0; i < srcRank; i++) {
const int64_t srcDim = srcVectorType.getDimSize(i);
const int64_t dstDim = dstVectorType.getDimSize(lead + i);
if (srcDim != 1 && srcDim != dstDim)
return op.emitOpError("dimension mismatch (")
<< srcDim << " vs. " << dstDim << ")";
}
}
return success();

View File

@ -9,6 +9,20 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
// -----
func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
// expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}}
%1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
}
// -----
func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
// expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 1)}}
%1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32>
}
// -----
func @extract_element_vector_type(%arg0: index) {
// expected-error@+1 {{expected vector type}}
%1 = vector.extractelement %arg0[] : index

View File

@ -23,12 +23,16 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
}
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> {
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
%0 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
return %3 : vector<8x16xf32>
}
// CHECK-LABEL: @extractelement