forked from mindspore-Ecosystem/mindspore
!7166 gpu no broad cast kernel dims exceed
Merge pull request !7166 from chenweifeng/broadcast-grad-dims-exceed
This commit is contained in:
commit
b619af917f
|
@ -83,17 +83,23 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
||||||
rhs_shape_.resize(MAX_DIMS, 1);
|
rhs_shape_.resize(MAX_DIMS, 1);
|
||||||
output_shape_.resize(MAX_DIMS, 1);
|
output_shape_.resize(MAX_DIMS, 1);
|
||||||
for (size_t i = 0; i < shape3.size(); i++) {
|
for (size_t i = 0; i < shape3.size(); i++) {
|
||||||
output_shape_[i] = shape3[i];
|
if (need_broadcast_) {
|
||||||
|
output_shape_[i] = shape3[i];
|
||||||
|
}
|
||||||
output_num_ *= shape3[i];
|
output_num_ *= shape3[i];
|
||||||
}
|
}
|
||||||
int lhs_offset = shape3.size() - shape1.size();
|
int lhs_offset = shape3.size() - shape1.size();
|
||||||
for (size_t j = 0; j < shape1.size(); j++) {
|
for (size_t j = 0; j < shape1.size(); j++) {
|
||||||
lhs_shape_[j + lhs_offset] = shape1[j];
|
if (need_broadcast_) {
|
||||||
|
lhs_shape_[j + lhs_offset] = shape1[j];
|
||||||
|
}
|
||||||
input1_num_ *= shape1[j];
|
input1_num_ *= shape1[j];
|
||||||
}
|
}
|
||||||
int rhs_offset = shape3.size() - shape2.size();
|
int rhs_offset = shape3.size() - shape2.size();
|
||||||
for (size_t k = 0; k < shape2.size(); k++) {
|
for (size_t k = 0; k < shape2.size(); k++) {
|
||||||
rhs_shape_[k + rhs_offset] = shape2[k];
|
if (need_broadcast_) {
|
||||||
|
rhs_shape_[k + rhs_offset] = shape2[k];
|
||||||
|
}
|
||||||
input2_num_ *= shape2[k];
|
input2_num_ *= shape2[k];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,17 +78,23 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < shape3.size(); i++) {
|
for (size_t i = 0; i < shape3.size(); i++) {
|
||||||
dy_shape_[i] = shape3[i];
|
if (need_broadcast_) {
|
||||||
|
dy_shape_[i] = shape3[i];
|
||||||
|
}
|
||||||
output_num_ *= shape3[i];
|
output_num_ *= shape3[i];
|
||||||
}
|
}
|
||||||
int x1_offset = shape3.size() - shape1.size();
|
int x1_offset = shape3.size() - shape1.size();
|
||||||
for (size_t i = 0; i < shape1.size(); i++) {
|
for (size_t i = 0; i < shape1.size(); i++) {
|
||||||
x1_shape_[i + x1_offset] = shape1[i];
|
if (need_broadcast_) {
|
||||||
|
x1_shape_[i + x1_offset] = shape1[i];
|
||||||
|
}
|
||||||
input1_num_ *= shape1[i];
|
input1_num_ *= shape1[i];
|
||||||
}
|
}
|
||||||
int x2_offset = shape3.size() - shape2.size();
|
int x2_offset = shape3.size() - shape2.size();
|
||||||
for (size_t i = 0; i < shape2.size(); i++) {
|
for (size_t i = 0; i < shape2.size(); i++) {
|
||||||
x2_shape_[i + x2_offset] = shape2[i];
|
if (need_broadcast_) {
|
||||||
|
x2_shape_[i + x2_offset] = shape2[i];
|
||||||
|
}
|
||||||
input2_num_ *= shape2[i];
|
input2_num_ *= shape2[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue