fushion mem check fixed

This commit is contained in:
kai00 2020-08-24 17:37:32 +08:00
parent 33a562de3d
commit 2ae4214fa2
2 changed files with 17 additions and 16 deletions

View File

@ -90,23 +90,29 @@ STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string
return RET_OK;
}
auto bnPath = matchedPath.at(bnOpName);
status = GetTransParam(graph, bnPath);
if (status != RET_OK) {
MS_LOG(ERROR) << "GetTransParam failed: " << status;
return status;
}
status = GenNewScaleTensor(graph, bnPath);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
delete[] transScale;
delete[] transBias;
transScale = nullptr;
transBias = nullptr;
return status;
}
status = ConvertBNToScale(graph, bnPath);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
delete[] transScale;
delete[] transBias;
transScale = nullptr;
transBias = nullptr;
return status;
}
delete[] transScale;
delete[] transBias;
transScale = nullptr;
transBias = nullptr;
return RET_OK;
}
STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
@ -245,6 +251,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
MS_LOG(ERROR) << "memcpy_s transScale error";
delete[] transScale;
delete[] transBias;
transScale = nullptr;
transBias = nullptr;
return RET_ERROR;
}
// 1/sqrt(variance + eps)
@ -370,14 +380,5 @@ STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) {
}
return RET_OK;
}
BatchNormConvertScalePass::~BatchNormConvertScalePass() {
if (this->transScale != nullptr) {
delete (this->transScale);
}
if (this->transBias != nullptr) {
delete (this->transBias);
}
}
} // namespace lite
} // namespace mindspore

View File

@ -36,7 +36,7 @@ class BatchNormConvertScalePass : public FusionPass {
public:
BatchNormConvertScalePass() = default;
~BatchNormConvertScalePass() override;
~BatchNormConvertScalePass() = default;
STATUS DefinePattern() override;