!49278 support ConvBN fusion NCHW

Merge pull request !49278 from KXiong/dev
i-robot 2023-02-25 02:24:29 +00:00 committed by Gitee
commit f4412779ad
2 changed files with 27 additions and 1 deletion

@@ -93,6 +93,27 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weight,
     }
   }
 }
+// this function should replace GenerateNewWeightConv2DTranspose after all fusions support NCHW
+void GenerateNewWeightConv2DTranspose_NCHW(float *dst_weight, const float *scale_weight,
+                                           const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
+  MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
+  if (group <= 0 || kernel_num <= 0) {
+    return;
+  }
+  auto cin_group = weight_tensor->shape()[0] / group;
+  MS_ASSERT(weight_tensor->data_c() != nullptr);
+  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
+  int64_t area_size = weight_tensor->shape()[kNHWC_H] * weight_tensor->shape()[kNHWC_W];
+  for (int64_t k = 0; k < cin_group; ++k) {
+    for (int64_t i = 0; i < kernel_num; ++i) {  // output channel num -> C
+      for (int64_t j = 0; j < area_size; j++) {  // HW
+        dst_weight[i * area_size + j + k * area_size * kernel_num] =
+          weight_data[i * area_size + j + k * area_size * kernel_num] * scale_weight[i];
+      }
+    }
+  }
+}
 }  // namespace
 const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
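
For background (not introduced by this change): in a Conv+BatchNorm fusion, the per-output-channel scale that ends up multiplying the convolution weights is derived from the BatchNorm parameters. A minimal standalone sketch of that fold, using hypothetical names and plain std::vector instead of MindSpore tensors:

#include <cmath>
#include <vector>

// Fold BatchNorm y = gamma * (x - mean) / sqrt(var + eps) + beta into y = scale * x + offset.
// The per-channel scale is what a ConvBN fusion later multiplies into the conv weights.
void FoldBatchNorm(const std::vector<float> &gamma, const std::vector<float> &beta,
                   const std::vector<float> &mean, const std::vector<float> &variance, float eps,
                   std::vector<float> *scale, std::vector<float> *offset) {
  size_t channels = gamma.size();
  scale->resize(channels);
  offset->resize(channels);
  for (size_t c = 0; c < channels; ++c) {
    (*scale)[c] = gamma[c] / std::sqrt(variance[c] + eps);
    (*offset)[c] = beta[c] - mean[c] * (*scale)[c];
  }
}
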
@@ -333,7 +354,11 @@ int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const ten
     auto conv2d_prim_c = conv2d_prim->GetPrim();
     MS_ASSERT(conv2d_prim_c != nullptr);
     auto group = conv2d_prim_c->GetAttr(ops::kGroup) == nullptr ? 1 : conv2d_prim->get_group();
-    GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
+    if (!nchw_format_) {
+      GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
+    } else {
+      GenerateNewWeightConv2DTranspose_NCHW(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
+    }
   }
   auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
   delete[] tmp_weight_data;
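
As a quick sanity check of the NCHW branch above, the following standalone sketch applies the same per-output-channel scaling and the same flat indexing as GenerateNewWeightConv2DTranspose_NCHW to a toy buffer (the sizes and buffers are made up for illustration):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t cin_group = 2, kernel_num = 3, area_size = 4;  // toy sizes: Cin/group, Cout, H*W
  std::vector<float> weight(cin_group * kernel_num * area_size, 1.0f);
  std::vector<float> scale = {2.0f, 3.0f, 4.0f};  // one scale per output channel
  std::vector<float> dst(weight.size(), 0.0f);
  for (int64_t k = 0; k < cin_group; ++k) {
    for (int64_t i = 0; i < kernel_num; ++i) {
      for (int64_t j = 0; j < area_size; ++j) {
        dst[i * area_size + j + k * area_size * kernel_num] =
          weight[i * area_size + j + k * area_size * kernel_num] * scale[i];
      }
    }
  }
  // every element belonging to output channel i should now equal scale[i]
  for (int64_t k = 0; k < cin_group; ++k) {
    for (int64_t i = 0; i < kernel_num; ++i) {
      for (int64_t j = 0; j < area_size; ++j) {
        assert(dst[i * area_size + j + k * area_size * kernel_num] == scale[i]);
      }
    }
  }
  return 0;
}
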

@@ -43,6 +43,7 @@ class ConvTransformFusion : public LitePatternProcessPass {
  protected:
   FmkType fmk_type_ = converter::kFmkTypeTf;
+  bool nchw_format_ = false;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
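
With the protected nchw_format_ member added, a layout-specific pass could opt into the NCHW weight handling from its constructor. The sketch below is purely illustrative; the derived class name and its default constructor are assumptions, not part of this pull request:

// Hypothetical derived pass that runs the fusion on NCHW-format graphs.
// Assumes the base class is default-constructible here; a real pass would forward its ctor args.
class ConvTransformFusionNCHW : public ConvTransformFusion {
 public:
  ConvTransformFusionNCHW() { nchw_format_ = true; }  // take the NCHW branch in CalNewWeightTensor
};
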