support arthmetic op 2d with 5d broadcast
This commit is contained in:
parent
0180f27026
commit
035f0384b2
|
@ -128,11 +128,11 @@ int ArithmeticOpenCLKernel::InitWeights() {
|
|||
int ArithmeticOpenCLKernel::SetConstArgs() {
|
||||
int arg_idx = CLARGSINDEX3;
|
||||
if (!element_flag_) {
|
||||
cl_int4 in0_shape = {static_cast<int>(in0_shape_->N), static_cast<int>(in0_shape_->H),
|
||||
cl_int4 in0_shape = {static_cast<int>(in0_shape_->N), static_cast<int>(in0_shape_->D * in0_shape_->H),
|
||||
static_cast<int>(in0_shape_->W), static_cast<int>(in0_shape_->Slice)};
|
||||
cl_int4 in1_shape = {static_cast<int>(in1_shape_->N), static_cast<int>(in1_shape_->H),
|
||||
cl_int4 in1_shape = {static_cast<int>(in1_shape_->N), static_cast<int>(in1_shape_->D * in1_shape_->H),
|
||||
static_cast<int>(in1_shape_->W), static_cast<int>(in1_shape_->Slice)};
|
||||
cl_int4 out_shape = {static_cast<int>(out_shape_->N), static_cast<int>(out_shape_->H),
|
||||
cl_int4 out_shape = {static_cast<int>(out_shape_->N), static_cast<int>(out_shape_->D * out_shape_->H),
|
||||
static_cast<int>(out_shape_->W), static_cast<int>(out_shape_->Slice)};
|
||||
int broadcastC_flag = 0; // do not need broadcast in C4
|
||||
if (in0_shape_->C == 1 && in1_shape_->C != 1) {
|
||||
|
@ -199,6 +199,12 @@ int ArithmeticOpenCLKernel::InitGpuTensorInfoShape() {
|
|||
in1_shape_switch_flag_ = true;
|
||||
}
|
||||
}
|
||||
if (shape0.size() == DIMENSION_5D && shape1.size() == DIMENSION_2D) {
|
||||
if (shape0.at(kNDHWC_W) == shape1.at(kNHWC_N) && shape0.at(kNDHWC_C) == shape1.at(kNHWC_H)) {
|
||||
SwitchGpuTensorInfoNWDim(in1_shape_.get());
|
||||
in1_shape_switch_flag_ = true;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue