From fddb6e323ab1fe4c92e3f9e7942f5ddabe20fc56 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Wed, 24 Aug 2022 17:20:02 +0800 Subject: [PATCH] fix BCE populate bug && add OnesLike infer registry --- .../binary_cross_entropy_grad_populate.cc | 45 ---------------- .../populate/binary_cross_entropy_populate.cc | 45 ---------------- .../src/train/train_populate_parameter.cc | 54 ++++++++++++------- 3 files changed, 35 insertions(+), 109 deletions(-) delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc delete mode 100644 mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc deleted file mode 100644 index 5da193bc6e8..00000000000 --- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_grad_populate.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" -using mindspore::schema::PrimitiveType_BinaryCrossEntropyGrad; - -namespace mindspore { -namespace lite { -OpParameter *PopulateBinaryCrossEntropyGradParameter(const void *prim) { - auto *primitive = static_cast(prim); - MS_ASSERT(primitive != nullptr); - auto value = primitive->value_as_BinaryCrossEntropyGrad(); - if (value == nullptr) { - MS_LOG(ERROR) << "param is nullptr"; - return nullptr; - } - - auto *param = reinterpret_cast(malloc(sizeof(BinaryCrossEntropyGradParameter))); - if (param == nullptr) { - MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; - return nullptr; - } - memset(param, 0, sizeof(BinaryCrossEntropyGradParameter)); - - param->op_parameter_.type_ = primitive->value_type(); - param->reduction = value->reduction(); - return reinterpret_cast(param); -} - -REG_POPULATE(PrimitiveType_BinaryCrossEntropyGrad, PopulateBinaryCrossEntropyGradParameter, SCHEMA_CUR); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc b/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc deleted file mode 100644 index 10060d3fd7c..00000000000 --- a/mindspore/lite/src/common/ops/populate/binary_cross_entropy_populate.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32_grad/binary_cross_entropy.h" -using mindspore::schema::PrimitiveType_BinaryCrossEntropy; - -namespace mindspore { -namespace lite { -OpParameter *PopulateBinaryCrossEntropyParameter(const void *prim) { - auto primitive = static_cast(prim); - MS_ASSERT(primitive != nullptr); - auto value = primitive->value_as_BinaryCrossEntropy(); - if (value == nullptr) { - MS_LOG(ERROR) << "value is nullptr"; - return nullptr; - } - - auto *param = reinterpret_cast(malloc(sizeof(BinaryCrossEntropyParameter))); - if (param == nullptr) { - MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; - return nullptr; - } - memset(param, 0, sizeof(BinaryCrossEntropyParameter)); - - param->op_parameter_.type_ = primitive->value_type(); - param->reduction = value->reduction(); - return reinterpret_cast(param); -} - -REG_POPULATE(PrimitiveType_BinaryCrossEntropy, PopulateBinaryCrossEntropyParameter, SCHEMA_CUR); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index 85fdb0e470a..272b13b0506 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,11 @@ * limitations under the License. */ #include "src/train/train_populate_parameter.h" -#include #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/default_populate.h" #include "nnacl/strided_slice_parameter.h" #include "nnacl/arithmetic.h" #include "nnacl/conv_parameter.h" -#include "nnacl/lstm_parameter.h" #include "nnacl/pooling_parameter.h" #include "nnacl/power_parameter.h" #include "nnacl/activation_parameter.h" @@ -31,6 +29,8 @@ #include "nnacl/fp32_grad/smooth_l1_loss.h" #include "nnacl/fp32_grad/resize_grad_parameter.h" #include "nnacl/fp32_grad/lstm_grad_fp32.h" +#include "nnacl/fp32_grad/binary_cross_entropy.h" +#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" using mindspore::lite::Registry; @@ -88,29 +88,45 @@ OpParameter *PopulateApplyMomentumParameter(const void *prim) { } OpParameter *PopulateBCEParameter(const void *prim) { - int32_t *reduction = reinterpret_cast(malloc(sizeof(int32_t))); - if (reduction == nullptr) { - MS_LOG(ERROR) << "malloc reduction failed."; + auto primitive = static_cast(prim); + MS_ASSERT(primitive != nullptr); + auto value = primitive->value_as_BinaryCrossEntropy(); + if (value == nullptr) { + MS_LOG(ERROR) << "value is nullptr"; return nullptr; } - auto primitive = static_cast(prim); - auto value = primitive->value_as_BinaryCrossEntropy(); - MS_ASSERT(value != nullptr); - *reduction = value->reduction(); - return reinterpret_cast(reduction); + + auto *param = reinterpret_cast(malloc(sizeof(BinaryCrossEntropyParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed."; + return nullptr; + } + memset(param, 0, sizeof(BinaryCrossEntropyParameter)); + + param->op_parameter_.type_ = primitive->value_type(); + param->reduction = value->reduction(); + return reinterpret_cast(param); } OpParameter *PopulateBCEGradParameter(const void *prim) { - int32_t *reduction = reinterpret_cast(malloc(sizeof(int32_t))); - if (reduction == nullptr) { - MS_LOG(ERROR) << "malloc reduction failed."; + auto *primitive = static_cast(prim); + MS_ASSERT(primitive != nullptr); + auto value = primitive->value_as_BinaryCrossEntropyGrad(); + if (value == nullptr) { + MS_LOG(ERROR) << "param is nullptr"; return nullptr; } - auto primitive = static_cast(prim); - auto value = primitive->value_as_BinaryCrossEntropyGrad(); - MS_ASSERT(value != nullptr); - *reduction = value->reduction(); - return reinterpret_cast(reduction); + + auto *param = reinterpret_cast(malloc(sizeof(BinaryCrossEntropyGradParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed."; + return nullptr; + } + memset(param, 0, sizeof(BinaryCrossEntropyGradParameter)); + + param->op_parameter_.type_ = primitive->value_type(); + param->reduction = value->reduction(); + return reinterpret_cast(param); } OpParameter *PopulateAdamParameter(const void *prim) {