!40831 [MS][LITE] fix BCE populate bug
Merge pull request !40831 from jianghui58/train_dev
This commit is contained in:
commit
2a81ddbd04
|
@ -1,45 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
|
||||
using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) {
|
||||
auto *primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto value = primitive->value_as_BinaryCrossEntropyGrad();
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,45 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32_grad/binary_cross_entropy.h"
|
||||
using mindspore::schema::PrimitiveType_BinaryCrossEntropy;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) {
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto value = primitive->value_as_BinaryCrossEntropy();
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(BinaryCrossEntropyParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,13 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/train/train_populate_parameter.h"
|
||||
#include <algorithm>
|
||||
#include "src/common/ops/populate/populate_register.h"
|
||||
#include "src/common/ops/populate/default_populate.h"
|
||||
#include "nnacl/strided_slice_parameter.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/lstm_parameter.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
#include "nnacl/activation_parameter.h"
|
||||
|
@ -31,6 +29,8 @@
|
|||
#include "nnacl/fp32_grad/smooth_l1_loss.h"
|
||||
#include "nnacl/fp32_grad/resize_grad_parameter.h"
|
||||
#include "nnacl/fp32_grad/lstm_grad_fp32.h"
|
||||
#include "nnacl/fp32_grad/binary_cross_entropy.h"
|
||||
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
|
||||
|
||||
using mindspore::lite::Registry;
|
||||
|
||||
|
@ -88,29 +88,45 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) {
|
|||
}
|
||||
|
||||
OpParameter *PopulateBCEParameter(const void *prim) {
|
||||
int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
|
||||
if (reduction == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc reduction failed.";
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto value = primitive->value_as_BinaryCrossEntropy();
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
auto value = primitive->value_as_BinaryCrossEntropy();
|
||||
MS_ASSERT(value != nullptr);
|
||||
*reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(reduction);
|
||||
|
||||
auto *param = reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(BinaryCrossEntropyParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
OpParameter *PopulateBCEGradParameter(const void *prim) {
|
||||
int32_t *reduction = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
|
||||
if (reduction == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc reduction failed.";
|
||||
auto *primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto value = primitive->value_as_BinaryCrossEntropyGrad();
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "param is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
auto value = primitive->value_as_BinaryCrossEntropyGrad();
|
||||
MS_ASSERT(value != nullptr);
|
||||
*reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(reduction);
|
||||
|
||||
auto *param = reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(BinaryCrossEntropyGradParameter));
|
||||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->reduction = value->reduction();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
OpParameter *PopulateAdamParameter(const void *prim) {
|
||||
|
|
Loading…
Reference in New Issue