!40831 [MS][LITE] fix BCE populate bug

Merge pull request !40831 from jianghui58/train_dev
This commit is contained in:
i-robot 2022-09-16 06:48:12 +00:00 committed by Gitee
commit 2a81ddbd04
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 35 additions and 109 deletions

View File

@ -1,45 +0,0 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) {
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropyGrad();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
return nullptr;
}
auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR);
} // namespace lite
} // namespace mindspore

View File

@ -1,45 +0,0 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"
using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropy();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}
auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyParameter));
param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR);
} // namespace lite
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,13 +14,11 @@
* limitations under the License.
*/
#include "src/train/train_populate_parameter.h"
#include <algorithm>
#include "src/common/ops/populate/populate_register.h"
#include "src/common/ops/populate/default_populate.h"
#include "nnacl/strided_slice_parameter.h"
#include "nnacl/arithmetic.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/lstm_parameter.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/power_parameter.h"
#include "nnacl/activation_parameter.h"
@ -31,6 +29,8 @@
#include "nnacl/fp32_grad/smooth_l1_loss.h"
#include "nnacl/fp32_grad/resize_grad_parameter.h"
#include "nnacl/fp32_grad/lstm_grad_fp32.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
using mindspore::lite::Registry;
@ -88,29 +88,45 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) {
}
OpParameter *PopulateBCEParameter(const void *prim) {
int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
if (reduction == nullptr) {
MS_LOG(ERROR) << "malloc reduction failed.";
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropy();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}
auto primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_BinaryCrossEntropy();
MS_ASSERT(value != nullptr);
*reduction = value->reduction();
return reinterpret_cast<OpParameter *>(reduction);
auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyParameter));
param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateBCEGradParameter(const void *prim) {
int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
if (reduction == nullptr) {
MS_LOG(ERROR) << "malloc reduction failed.";
auto *primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_BinaryCrossEntropyGrad();
if (value == nullptr) {
MS_LOG(ERROR) << "param is nullptr";
return nullptr;
}
auto primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_BinaryCrossEntropyGrad();
MS_ASSERT(value != nullptr);
*reduction = value->reduction();
return reinterpret_cast<OpParameter *>(reduction);
auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
param->op_parameter_.type_ = primitive->value_type();
param->reduction = value->reduction();
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateAdamParameter(const void *prim) {