support arithmetic op 2d with 5d broadcast

This commit is contained in:
greatpanc 2022-02-17 16:34:07 +08:00
parent 0180f27026
commit 035f0384b2
1 changed file with 9 additions and 3 deletions

View File

@ -128,11 +128,11 @@ int ArithmeticOpenCLKernel::InitWeights() {
int ArithmeticOpenCLKernel::SetConstArgs() {
int arg_idx = CLARGSINDEX3;
if (!element_flag_) {
cl_int4 in0_shape = {static_cast<int>(in0_shape_->N), static_cast<int>(in0_shape_->H),
cl_int4 in0_shape = {static_cast<int>(in0_shape_->N), static_cast<int>(in0_shape_->D * in0_shape_->H),
static_cast<int>(in0_shape_->W), static_cast<int>(in0_shape_->Slice)};
cl_int4 in1_shape = {static_cast<int>(in1_shape_->N), static_cast<int>(in1_shape_->H),
cl_int4 in1_shape = {static_cast<int>(in1_shape_->N), static_cast<int>(in1_shape_->D * in1_shape_->H),
static_cast<int>(in1_shape_->W), static_cast<int>(in1_shape_->Slice)};
cl_int4 out_shape = {static_cast<int>(out_shape_->N), static_cast<int>(out_shape_->H),
cl_int4 out_shape = {static_cast<int>(out_shape_->N), static_cast<int>(out_shape_->D * out_shape_->H),
static_cast<int>(out_shape_->W), static_cast<int>(out_shape_->Slice)};
int broadcastC_flag = 0; // do not need broadcast in C4
if (in0_shape_->C == 1 && in1_shape_->C != 1) {
@ -199,6 +199,12 @@ int ArithmeticOpenCLKernel::InitGpuTensorInfoShape() {
in1_shape_switch_flag_ = true;
}
}
if (shape0.size() == DIMENSION_5D && shape1.size() == DIMENSION_2D) {
if (shape0.at(kNDHWC_W) == shape1.at(kNHWC_N) && shape0.at(kNDHWC_C) == shape1.at(kNHWC_H)) {
SwitchGpuTensorInfoNWDim(in1_shape_.get());
in1_shape_switch_flag_ = true;
}
}
return RET_OK;
}