forked from mindspore-Ecosystem/mindspore
add identity pass and adjust log
This commit is contained in:
parent
b8fbabae34
commit
920cbb1e22
|
@ -57,6 +57,11 @@ constexpr int RET_INFER_INVALID = -501; /**< Invalid infer shape before runtime.
|
|||
|
||||
/* User input param error code, range: [-600, 700)*/
|
||||
constexpr int RET_INPUT_PARAM_INVALID = -600; /**< Invalid input param by user. */
|
||||
|
||||
/// \brief Print description of errorcode.
|
||||
///
|
||||
/// \param[in] error_code define return status of procedure.
|
||||
void PrintErrorInfo(STATUS error_code);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
|
||||
)
|
||||
|
||||
if (SUPPORT_GPU)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "include/errorcode.h"
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
void PrintErrorInfo(STATUS status) {
|
||||
std::map<int, std::string> info_map = {{RET_OK, "No error occurs."},
|
||||
{RET_ERROR, "Common error code."},
|
||||
{RET_NULL_PTR, "NULL pointer returned."},
|
||||
{RET_PARAM_INVALID, "Invalid parameter."},
|
||||
{RET_NO_CHANGE, "No change."},
|
||||
{RET_SUCCESS_EXIT, "No error but exit."},
|
||||
{RET_MEMORY_FAILED, "Fail to create memory."},
|
||||
{RET_NOT_SUPPORT, "Fail to support."},
|
||||
{RET_OUT_OF_TENSOR_RANGE, "Failed to check range."},
|
||||
{RET_INPUT_TENSOR_ERROR, "Failed to check input tensor."},
|
||||
{RET_REENTRANT_ERROR, "Exist executor running."},
|
||||
{RET_GRAPH_FILE_ERR, "Failed to verify graph file."},
|
||||
{RET_NOT_FIND_OP, "Failed to find operator."},
|
||||
{RET_INVALID_OP_NAME, "Invalid operator name."},
|
||||
{RET_INVALID_OP_ATTR, "Invalid operator attr."},
|
||||
{RET_OP_EXECUTE_FAILURE, "Failed to execution operator."},
|
||||
{RET_FORMAT_ERR, "Failed to checking tensor format."},
|
||||
{RET_INFER_ERR, "Failed to infer shape."},
|
||||
{RET_INFER_INVALID, "Invalid infer shape before runtime."},
|
||||
{RET_INPUT_PARAM_INVALID, "Invalid input param by user."}};
|
||||
std::cout << info_map[status] << std::endl;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Identity : public PrimitiveC {
|
||||
public:
|
||||
MS_DECLARE_PARENT(Identity, PrimitiveC);
|
||||
Identity() = default;
|
||||
explicit Identity(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_IDENTITY_H_
|
|
@ -137,6 +137,7 @@
|
|||
#include "src/ops/upsample.h"
|
||||
#include "src/ops/layer_norm.h"
|
||||
#include "src/ops/non_max_suppression.h"
|
||||
#include "src/ops/identity.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
@ -729,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new LayerNorm(primitive);
|
||||
case schema::PrimitiveType_NonMaxSuppression:
|
||||
return new NonMaxSuppression(primitive);
|
||||
case schema::PrimitiveType_Identity:
|
||||
return new Identity(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -134,6 +134,7 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/tools/common/storage.cc
|
||||
${LITE_DIR}/tools/benchmark/benchmark.cc
|
||||
${LITE_DIR}/test/st/benchmark_test.cc
|
||||
${LITE_DIR}/src/errorcode.cc
|
||||
)
|
||||
### gpu runtime
|
||||
if (SUPPORT_GPU)
|
||||
|
@ -184,6 +185,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
|
||||
)
|
||||
endif()
|
||||
### train
|
||||
|
|
|
@ -46,6 +46,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
../optimizer/graph/clip_convert_activation_pass.cc
|
||||
../optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
../optimizer/graph/identity_remove_pass.cc
|
||||
)
|
||||
|
||||
add_subdirectory(../anf_importer anf_importer)
|
||||
|
@ -75,6 +76,7 @@ set(LITE_SRC
|
|||
${SRC_DIR}/executor.cc
|
||||
${SRC_DIR}/model.cc
|
||||
${SRC_DIR}/model_common.cc
|
||||
${SRC_DIR}/errorcode.cc
|
||||
)
|
||||
if (SUPPORT_TRAIN)
|
||||
set(LITE_SRC
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
|
||||
#include "tools/optimizer/fusion/layer_norm_fusion.h"
|
||||
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
|
||||
#include "tools/optimizer/graph/identity_remove_pass.h"
|
||||
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
|
||||
#include "tools/optimizer/graph/weight_format_transform_pass.h"
|
||||
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
|
||||
|
@ -53,6 +54,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
// for now - trainning is not supporting fuse operations
|
||||
if (config != nullptr && config->trainModel == false) {
|
||||
// remove quantdtype when awaretraining
|
||||
if (config->fmk == lite::converter::FmkType_ONNX) {
|
||||
auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>();
|
||||
remove_identity_pass->SetFmkType(config->fmk);
|
||||
pm->AddPass(remove_identity_pass);
|
||||
}
|
||||
if (config->quantType == QuantType_AwareTraining) {
|
||||
pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
|
||||
}
|
||||
|
|
|
@ -109,6 +109,7 @@ int RunConverter(int argc, const char **argv) {
|
|||
if (flags == nullptr) {
|
||||
MS_LOG(ERROR) << "new flags error ";
|
||||
std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl;
|
||||
PrintErrorInfo(RET_MEMORY_FAILED);
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto status = flags->Init(argc, argv);
|
||||
|
@ -117,6 +118,7 @@ int RunConverter(int argc, const char **argv) {
|
|||
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
|
||||
std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl;
|
||||
}
|
||||
PrintErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
// Load graph
|
||||
|
@ -148,6 +150,7 @@ int RunConverter(int argc, const char **argv) {
|
|||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
|
||||
std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl;
|
||||
PrintErrorInfo(RET_INPUT_PARAM_INVALID);
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
@ -156,6 +159,7 @@ int RunConverter(int argc, const char **argv) {
|
|||
if (fb_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert model return nullptr";
|
||||
std::cout << "CONVERT RESULT FAILED:" << status << std::endl;
|
||||
PrintErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
|
||||
|
@ -163,15 +167,17 @@ int RunConverter(int argc, const char **argv) {
|
|||
Storage storage;
|
||||
fb_graph->version = Version();
|
||||
status = storage.Save(*fb_graph, flags->outputFile);
|
||||
if (status != 0) {
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Save graph to file failed";
|
||||
std::cout << "SAVE GRAPH FAILED:" << status << std::endl;
|
||||
PrintErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
|
||||
delete fb_graph;
|
||||
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
|
||||
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
|
||||
PrintErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -270,7 +270,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
interrupt = true;
|
||||
return RET_NOT_FIND_OP;
|
||||
int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
|
||||
int status = ParseSubgraph(dst_op, onnx_node, quantType, dst_graph);
|
||||
if (status != RET_OK || interrupt) {
|
||||
interrupt = true;
|
||||
return status;
|
||||
|
@ -496,7 +496,7 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
|
|||
}
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
|
||||
const QuantType &quantType, schema::MetaGraphT *dst_graph) {
|
||||
MS_LOG(DEBUG) << "onnx LoopParser";
|
||||
if (dst_op == nullptr) {
|
||||
|
|
|
@ -88,7 +88,7 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
|
||||
|
||||
STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
|
||||
STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
|
||||
schema::MetaGraphT *dst_graph);
|
||||
|
||||
private:
|
||||
|
|
|
@ -61,7 +61,7 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
}
|
||||
}
|
||||
if (input_shape.int64_data_size() == 0) {
|
||||
MS_LOG(WARNING) << "shape maybe from another op other than const initializer";
|
||||
MS_LOG(INFO) << "shape maybe from another op other than const initializer";
|
||||
} else {
|
||||
for (int i = 0; i < input_shape.int64_data_size(); ++i) {
|
||||
shape.push_back(input_shape.int64_data(i));
|
||||
|
|
|
@ -70,13 +70,13 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|||
}
|
||||
auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node));
|
||||
if (quantParamCalcer == nullptr) {
|
||||
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
MS_LOG(INFO) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
|
||||
} else {
|
||||
auto status = quantParamCalcer->Calc(graph, *node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
MS_LOG(INFO) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
|
@ -103,7 +103,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
GetCNodeTType(*node) == schema::PrimitiveType_MatMul) {
|
||||
auto inputIndexes = node->inputIndex;
|
||||
if (inputIndexes.size() < 2) {
|
||||
MS_LOG(WARNING) << node->name.c_str() << " node input has invalid inputs tensor count";
|
||||
MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// quant weight
|
||||
|
@ -111,7 +111,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
|
||||
status = QuantConvWeight(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "QuantConvWeight failed!";
|
||||
MS_LOG(ERROR) << "QuantConvWeight failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
|
||||
status = QuantConvBias(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "QuantConvBias failed!";
|
||||
MS_LOG(ERROR) << "QuantConvBias failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
|
||||
status = QuantDetectionPostProcessConstTensor(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "QuantDetectionPostProcessConstTensor failed!";
|
||||
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add ||
|
||||
|
@ -137,7 +137,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
GetCNodeTType(*node) == schema::PrimitiveType_Mul) {
|
||||
status = QuantArithmeticConstTensor(graph, node.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "QuantArithmeticConstTensor failed!";
|
||||
MS_LOG(ERROR) << "QuantArithmeticConstTensor failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -168,7 +168,7 @@ STATUS AwareQuantizer::QuantArithmeticConstTensor(const schema::MetaGraphT *grap
|
|||
}
|
||||
if (inTensor->dataType != TypeId::kNumberTypeFloat32 && inTensor->dataType != TypeId::kNumberTypeFloat &&
|
||||
inTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(WARNING) << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
MS_LOG(ERROR) << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -303,7 +303,7 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem
|
|||
}
|
||||
if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && weightTensor->dataType != TypeId::kNumberTypeFloat &&
|
||||
weightTensor->dataType != TypeId::kNumberTypeUInt8) {
|
||||
MS_LOG(WARNING) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t wShapeSize = GetShapeSize(*(weightTensor.get()));
|
||||
|
|
|
@ -33,7 +33,7 @@ STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, Q
|
|||
return RET_OK;
|
||||
}
|
||||
if (tensor.dataType != TypeId::kNumberTypeFloat) {
|
||||
MS_LOG(WARNING) << "Const Tensor without quantParam should has float dataType, in fact: " << tensor.dataType;
|
||||
MS_LOG(ERROR) << "Const Tensor without quantParam should has float dataType, in fact: " << tensor.dataType;
|
||||
return RET_ERROR;
|
||||
}
|
||||
const auto *constData = reinterpret_cast<const float *>(tensor.data.data());
|
||||
|
@ -83,7 +83,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
|
|||
if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
|
||||
auto status = ComputeConstQuantParam((*tensor), quantParam.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status;
|
||||
MS_LOG(INFO) << "ComputeConstQuantParam failed: " << status;
|
||||
return status;
|
||||
}
|
||||
tensor->quantParams.front() = std::move(quantParam);
|
||||
|
@ -112,15 +112,15 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
|
|||
int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
|
||||
auto status = QuantParamCalcer::Calc(subGraph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
return status;
|
||||
}
|
||||
if (inputParamDone != node.inputIndex.size()) {
|
||||
MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputParamDone != node.outputIndex.size()) {
|
||||
MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -129,7 +129,7 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
|
|||
int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
|
||||
auto status = QuantParamCalcer::Calc(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
return status;
|
||||
}
|
||||
if (inputParamDone != node.inputIndex.size()) {
|
||||
|
@ -139,7 +139,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
|
|||
auto outputQuantParam = GetTensorQuantParam(outTensor);
|
||||
MS_ASSERT(outputQuantParam != nullptr);
|
||||
if (outputQuantParam == nullptr || !outputQuantParam->inited) {
|
||||
MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (unsigned int i : node.inputIndex) {
|
||||
|
@ -159,7 +159,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
|
|||
MS_ASSERT(inTensor != nullptr);
|
||||
auto inQuantParam = GetTensorQuantParam(inTensor);
|
||||
if (inQuantParam == nullptr || !inQuantParam->inited) {
|
||||
MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < node.outputIndex.size(); i++) {
|
||||
|
@ -188,12 +188,12 @@ class CalcConcat : public QuantParamCalcer {
|
|||
MS_ASSERT(node.outputIndex.size() == 1);
|
||||
auto status = QuantParamCalcer::Calc(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
if (inputParamDone != node.inputIndex.size()) {
|
||||
MS_LOG(WARNING) << "Can not determine concat inputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine concat inputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -233,7 +233,7 @@ class CalcConcat : public QuantParamCalcer {
|
|||
|
||||
status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
|
||||
MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
outTensor->quantParams.emplace_back(std::move(outQuantParam));
|
||||
|
@ -253,12 +253,12 @@ class CalcAdd : public QuantParamCalcer {
|
|||
MS_ASSERT(node.outputIndex.size() == 1);
|
||||
auto status = QuantParamCalcer::Calc(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
if (inputParamDone != 2) {
|
||||
MS_LOG(WARNING) << "Can not determine add inputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine add inputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputParamDone != 1) {
|
||||
|
@ -283,7 +283,7 @@ class CalcAdd : public QuantParamCalcer {
|
|||
biasTensor = &tensor1;
|
||||
paramTensor = &tensor0;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Can not determine add outputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine add outputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto quantParam = GetTensorQuantParam(*paramTensor);
|
||||
|
@ -298,7 +298,7 @@ class CalcAdd : public QuantParamCalcer {
|
|||
auto *bias = static_cast<float *>(oriTensorData);
|
||||
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
|
||||
MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
|
||||
|
@ -307,11 +307,11 @@ class CalcAdd : public QuantParamCalcer {
|
|||
auto *bias = static_cast<uint8_t *>(oriTensorData);
|
||||
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
|
||||
MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unsupported tensor dataType: " << (*biasTensor)->dataType;
|
||||
MS_LOG(ERROR) << "Unsupported tensor dataType: " << (*biasTensor)->dataType;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -330,12 +330,12 @@ class CalcRealDiv : public QuantParamCalcer {
|
|||
MS_ASSERT(node.outputIndex.size() == 1);
|
||||
auto status = QuantParamCalcer::Calc(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
if (inputParamDone != 2) {
|
||||
MS_LOG(WARNING) << "Can not determine realdiv inputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine realdiv inputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputParamDone != 1) {
|
||||
|
@ -361,7 +361,7 @@ class CalcRealDiv : public QuantParamCalcer {
|
|||
MS_ASSERT(*div != 0);
|
||||
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
|
||||
MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (tensor1->dataType == TypeId::kNumberTypeUInt8) {
|
||||
|
@ -370,17 +370,17 @@ class CalcRealDiv : public QuantParamCalcer {
|
|||
auto *div = static_cast<uint8_t *>(oriTensorData);
|
||||
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max + (*div));
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
|
||||
MS_LOG(ERROR) << "in aware quantization run CalQuantizationParams failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType;
|
||||
MS_LOG(ERROR) << "Unsupported tensor dataType: " << tensor1->dataType;
|
||||
return RET_ERROR;
|
||||
}
|
||||
outTensor->quantParams.front() = std::move(outQuantParam);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine realDiv outputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -397,19 +397,19 @@ class CalcToSet : public QuantParamCalcer {
|
|||
MS_ASSERT(node.outputIndex.size() == 1);
|
||||
auto status = QuantParamCalcer::Calc(graph, node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: %d" << status;
|
||||
MS_LOG(ERROR) << "Call QuantParamCalcer::Calc failed: %d" << status;
|
||||
return status;
|
||||
}
|
||||
// input
|
||||
if (inputParamDone != node.inputIndex.size()) {
|
||||
MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name;
|
||||
MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// output
|
||||
if (outputParamDone != node.outputIndex.size()) {
|
||||
std::unique_ptr<QuantParamT> quantParam = std::make_unique<QuantParamT>();
|
||||
if (quantParam == nullptr) {
|
||||
MS_LOG(WARNING) << "new QuantParamT failed";
|
||||
MS_LOG(ERROR) << "new QuantParamT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
quantParam->scale = (max - min) / 256;
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/graph/identity_remove_pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != lite::converter::FmkType_ONNX) {
|
||||
MS_LOG(INFO) << "The framework type of model should be onnx.";
|
||||
return RET_OK;
|
||||
}
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto type = opt::GetCNodeType(node);
|
||||
if (type != schema::PrimitiveType_Identity) {
|
||||
continue;
|
||||
}
|
||||
auto identity_cnode = node->cast<CNodePtr>();
|
||||
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
|
||||
MS_LOG(ERROR) << "The `node input is a single tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
manager->Replace(node, identity_cnode->input(1));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
class RemoveIdentityOpPass : public Pass {
|
||||
public:
|
||||
RemoveIdentityOpPass() : Pass("remove_identity_pass") {}
|
||||
~RemoveIdentityOpPass() override = default;
|
||||
void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = lite::converter::FmkType_ONNX;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
Loading…
Reference in New Issue