Fix aware quantizer memory problem: replace shared_ptr graph member with a non-owning raw pointer

This commit is contained in:
cjh9368 2020-08-13 10:49:38 +08:00
parent 18c6ac9988
commit 2ae2c3ceca
2 changed files with 8 additions and 8 deletions

View File

@ -293,13 +293,13 @@ STATUS AwareQuantizer::GenerateQuantParam() {
MS_ASSERT(graph->inputIndex.size() == 1); MS_ASSERT(graph->inputIndex.size() == 1);
// set graphInputNode input // set graphInputNode input
for (auto graphInputIndex : graph->inputIndex) { for (auto graphInputIndex : graph->inputIndex) {
auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex); auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "SetInputArrayQP failed"; MS_LOG(ERROR) << "SetInputArrayQP failed";
return status; return status;
} }
} }
auto status = GenerateDefaultQuantParam(graph.get()); auto status = GenerateDefaultQuantParam(graph);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; MS_LOG(ERROR) << "GenerateDefaultQuantParam failed";
return status; return status;
@ -319,7 +319,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else { } else {
status = quantParamCalcer->Calc(graph.get(), *node); status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE; node->quantType = schema::QuantType_QUANT_NONE;
@ -349,27 +349,27 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_ERROR; return RET_ERROR;
} }
// quant weight // quant weight
status = QuantConvWeight(graph.get(), node.get()); status = QuantConvWeight(graph, node.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvWeight failed!"; MS_LOG(ERROR) << "QuantConvWeight failed!";
return RET_ERROR; return RET_ERROR;
} }
// quant bias // quant bias
if (inputIndexes.size() == 3) { if (inputIndexes.size() == 3) {
status = QuantConvBias(graph.get(), node.get()); status = QuantConvBias(graph, node.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvBias failed!"; MS_LOG(ERROR) << "QuantConvBias failed!";
return RET_ERROR; return RET_ERROR;
} }
} }
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph.get(), node.get()); status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
return RET_ERROR; return RET_ERROR;
} }
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) { } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
status = QuantAddConstTensor(graph.get(), node.get()); status = QuantAddConstTensor(graph, node.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantAddConstTensor failed!"; MS_LOG(ERROR) << "QuantAddConstTensor failed!";
return RET_ERROR; return RET_ERROR;

View File

@ -73,7 +73,7 @@ class FbQuantizer {
virtual STATUS DoQuantize() = 0; virtual STATUS DoQuantize() = 0;
protected: protected:
std::shared_ptr<schema::MetaGraphT> graph = nullptr; schema::MetaGraphT *graph = nullptr;
}; };
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant