forked from mindspore-Ecosystem/mindspore
!4144 fix anf transform bug
Merge pull request !4144 from zhengjun10/master
This commit is contained in:
commit
7f53253b56
|
@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// graph = anfTransform->Transform(graph);
|
||||
graph = anfTransform->Transform(graph);
|
||||
|
||||
CreateQuantizer(graph, flag);
|
||||
if (mQuantizer != nullptr) {
|
||||
|
|
|
@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
// }
|
||||
|
||||
// fusion
|
||||
{
|
||||
Optimizer fusionOptimizer;
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
status = fusionOptimizer.Run(graphDefT);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
// {
|
||||
// Optimizer fusionOptimizer;
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
// status = fusionOptimizer.Run(graphDefT);
|
||||
// if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
// MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
|
||||
// return status;
|
||||
// }
|
||||
// }
|
||||
|
||||
// weight format trans
|
||||
if (ctx.formatTrans) {
|
||||
|
|
|
@ -89,10 +89,10 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c
|
|||
auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
|
||||
auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param);
|
||||
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr());
|
||||
|
||||
if (add_weight_tensor->tensor_shape().empty()) {
|
||||
if (EOK != memset_s(add_bias_data, kernel_nums * sizeof(float), *add_weight_data, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed";
|
||||
auto add_weight_shape = add_weight_tensor->tensor_shape();
|
||||
if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] ==1)) {
|
||||
for (size_t i = 0; i < kernel_nums; i++) {
|
||||
add_bias_data[i] = *add_weight_data;
|
||||
}
|
||||
} else {
|
||||
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
|
||||
|
|
|
@ -145,8 +145,8 @@ const {
|
|||
// conv has bias,bias_flag true
|
||||
bool bias_flag = false;
|
||||
if (conv_bias_node != nullptr) {
|
||||
auto bias_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
|
||||
auto bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(bias_weight_param);
|
||||
auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
|
||||
auto bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param);
|
||||
bias_data = reinterpret_cast<float *>(bias_tensor->tensor_addr());
|
||||
bias_flag = true;
|
||||
} else {
|
||||
|
@ -187,7 +187,7 @@ const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_nu
|
|||
MS_ASSERT(bias_data != nullptr);
|
||||
if (bias_flag) {
|
||||
auto tmp_bias_data = new(std::nothrow) float[kernel_num];
|
||||
if (EOK != memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
|
||||
if (EOK != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset bias data failed";
|
||||
}
|
||||
for (size_t i = 0; i < kernel_num; i++) {
|
||||
|
|
Loading…
Reference in New Issue