!49278 support ConvBN fusion NCHW
Merge pull request !49278 from KXiong/dev
This commit is contained in:
commit
f4412779ad
|
@ -93,6 +93,27 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this function should replace GenerateNewWeightConv2DTranspose after all fusions support NCHW
|
||||||
|
void GenerateNewWeightConv2DTranspose_NCHW(float *dst_weight, const float *scale_weight,
|
||||||
|
const tensor::TensorPtr &weight_tensor, int64_t group, int kernel_num) {
|
||||||
|
MS_ASSERT(dst_weight != nullptr && scale_weight != nullptr && weight_tensor != nullptr);
|
||||||
|
if (group <= 0 || kernel_num <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto cin_group = weight_tensor->shape()[0] / group;
|
||||||
|
MS_ASSERT(weight_tensor->data_c() != nullptr);
|
||||||
|
auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
|
||||||
|
int64_t area_size = weight_tensor->shape()[kNHWC_H] * weight_tensor->shape()[kNHWC_W];
|
||||||
|
for (int64_t k = 0; k < cin_group; ++k) {
|
||||||
|
for (int64_t i = 0; i < kernel_num; ++i) { // output channel num -> C
|
||||||
|
for (int64_t j = 0; j < area_size; j++) { // HW
|
||||||
|
dst_weight[i * area_size + j + k * area_size * kernel_num] =
|
||||||
|
weight_data[i * area_size + j + k * area_size * kernel_num] * scale_weight[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
@ -333,7 +354,11 @@ int ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const ten
|
||||||
auto conv2d_prim_c = conv2d_prim->GetPrim();
|
auto conv2d_prim_c = conv2d_prim->GetPrim();
|
||||||
MS_ASSERT(conv2d_prim_c != nullptr);
|
MS_ASSERT(conv2d_prim_c != nullptr);
|
||||||
auto group = conv2d_prim_c->GetAttr(ops::kGroup) == nullptr ? 1 : conv2d_prim->get_group();
|
auto group = conv2d_prim_c->GetAttr(ops::kGroup) == nullptr ? 1 : conv2d_prim->get_group();
|
||||||
|
if (!nchw_format_) {
|
||||||
GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
GenerateNewWeightConv2DTranspose(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
||||||
|
} else {
|
||||||
|
GenerateNewWeightConv2DTranspose_NCHW(tmp_weight_data, trans_scale, weight_tensor, group, kernel_num);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
|
auto ret = memcpy_s(weight_data, weight_tensor->Size(), tmp_weight_data, data_size);
|
||||||
delete[] tmp_weight_data;
|
delete[] tmp_weight_data;
|
||||||
|
|
|
@ -43,6 +43,7 @@ class ConvTransformFusion : public LitePatternProcessPass {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FmkType fmk_type_ = converter::kFmkTypeTf;
|
FmkType fmk_type_ = converter::kFmkTypeTf;
|
||||||
|
bool nchw_format_ = false;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::opt
|
} // namespace mindspore::opt
|
||||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||||
|
|
Loading…
Reference in New Issue