!49278 support ConvBN fusion NCHW
Merge pull request !49278 from KXiong/dev
This commit is contained in:
commit
f4412779ad
|
@ -93,6 +93,27 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// this function should replace GenerateNewWeightConv2DTranspose after all fusions support NCHW
|
||||
void GenerateNewWeightConv2DTranspose_NCHW(float *dst_weight, const float *scale_weight,
|
||||
const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
|
||||
MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
|
||||
if (group <= 0 || kernel_num <= 0) {
|
||||
return;
|
||||
}
|
||||
auto cin_group = weight_tensor->shape()[0] / group;
|
||||
MS_ASSERT(weight_tensor->data_c() != nullptr);
|
||||
auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
|
||||
int64_t area_size = weight_tensor->shape()[kNHWC_H] * weight_tensor->shape()[kNHWC_W];
|
||||
for (int64_t k = 0; k < cin_group; ++k) {
|
||||
for (int64_t i = 0; i < kernel_num; ++i) { // output channel num -> C
|
||||
for (int64_t j = 0; j < area_size; j++) { // HW
|
||||
dst_weight[i * area_size + j + k * area_size * kernel_num] =
|
||||
weight_data[i * area_size + j + k * area_size * kernel_num] * scale_weight[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
@ -333,7 +354,11 @@ int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const ten
|
|||
auto conv2d_prim_c = conv2d_prim->GetPrim();
|
||||
MS_ASSERT(conv2d_prim_c != nullptr);
|
||||
auto group = conv2d_prim_c->GetAttr(ops::kGroup) == nullptr ? 1 : conv2d_prim->get_group();
|
||||
GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
||||
if (!nchw_format_) {
|
||||
GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
||||
} else {
|
||||
GenerateNewWeightConv2DTranspose_NCHW(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
||||
}
|
||||
}
|
||||
auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
|
||||
delete[] tmp_weight_data;
|
||||
|
|
|
@ -43,6 +43,7 @@ class ConvTransformFusion : public LitePatternProcessPass {
|
|||
|
||||
protected:
|
||||
FmkType fmk_type_ = converter::kFmkTypeTf;
|
||||
bool nchw_format_ = false;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||
|
|
Loading…
Reference in New Issue