From c644cb807341096e8131d173a25440df1b35c030 Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 16 Dec 2022 12:24:05 +0430 Subject: [PATCH] develop AdamWeightDecay op on cpu lite --- .../api_python/mindspore.communication.rst | 1 - .../device/cpu/kernel/nnacl/fp32/adam_fp32.c | 10 +- .../cpu/kernel/nnacl/fp32/adam_fp32_simd.h.in | 2 +- .../nnacl/infer/adam_weight_decay_infer.c | 54 +++++++ .../nnacl/infer/adam_weight_decay_infer.h | 32 ++++ .../cpu/kernel/nnacl/infer/infer_register.c | 2 + .../cpu/kernel/nnacl/infer/scatter_nd_infer.c | 2 +- .../plugin/device/cpu/kernel/nnacl/op_base.h | 3 +- mindspore/lite/schema/ops.fbs | 5 + mindspore/lite/src/common/ops/ops_def.cc | 5 + .../lite/src/common/ops/ops_func_declare.h | 2 + mindspore/lite/src/common/ops/ops_utils.cc | 1 + .../src/litert/kernel/cpu/fp32_grad/adam.cc | 3 - .../kernel/cpu/fp32_grad/adam_weight_decay.cc | 150 ++++++++++++++++++ .../kernel/cpu/fp32_grad/adam_weight_decay.h | 50 ++++++ mindspore/lite/src/train/optimizer_kernel.h | 4 + .../src/train/train_populate_parameter.cc | 2 + .../infer/adam_weight_decay_infer_test.cc | 57 +++++++ .../mindspore/communication/management.py | 6 +- .../train/callback/_on_request_exit.py | 4 +- tests/ut/cpp/CMakeLists.txt | 2 - .../kernel/cpu/adam_delta_cpu_kernel_test.cc | 93 ----------- 22 files changed, 377 insertions(+), 113 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.c create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.h create mode 100644 mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc create mode 100644 mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h create mode 100644 mindspore/lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc delete mode 100644 tests/ut/cpp/kernel/cpu/adam_delta_cpu_kernel_test.cc diff --git a/docs/api/api_python/mindspore.communication.rst b/docs/api/api_python/mindspore.communication.rst index b11fdf16890..a19413f6d46 100644 --- a/docs/api/api_python/mindspore.communication.rst +++ b/docs/api/api_python/mindspore.communication.rst @@ -8,7 +8,6 @@ mindspore.communication 针对GPU设备,用户需要准备host文件和mpi,详见 `GPU指导文档 `_ 。 -目前尚不支持CPU。 .. py:class:: mindspore.communication.GlobalComm diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c index 8e84d111db5..512a321f9bc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c @@ -13,13 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #include -#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/errorcode.h" #include "nnacl/fp32/adam_fp32.h" -#include "nnacl/intrinsics/ms_simd_instructions.h" -#ifdef ENABLE_AVX512 -#include "nnacl/avx512/adam_fp32_avx512.h" -#endif +#include "nnacl/adam_fp32_simd.h" int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, size_t start, size_t end, bool use_nesterov) { @@ -159,7 +157,7 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, const float *gradient, size_t start, size_t end) { size_t c1 = start; - SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end); + SIMD_RUN_NO_SCALAR(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end); // remaining const float beta1_minus = 1 - beta1; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32_simd.h.in index d1ea83ac684..c716ff16987 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32_simd.h.in +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32_simd.h.in @@ -23,7 +23,6 @@ extern "C" { #endif @SIMD_INSTRUCTION_BEGIN@ -#ifdef MS_SIMD_AVX512 static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, const float *gradient, size_t end) { SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); @@ -57,6 +56,7 @@ extern "C" { return index; } +#ifdef MS_SIMD_AVX512 static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t end) { SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.c new file mode 100644 index 00000000000..34bbf9aaddb --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/infer/adam_weight_decay_infer.h" +#include "nnacl/infer/infer_register.h" + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const size_t expected_inputs_size = 9; + const int var_idx = 0; + const int m_idx = 1; + const int v_idx = 2; + const int lr_idx = 3; + const int beta1_idx = 4; + const int beta2_idx = 5; + const int epsilon = 6; + const int decay_idx = 7; + const int grad_idx = 8; + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (GetElementNum(inputs[var_idx]) != GetElementNum(inputs[m_idx]) || + GetElementNum(inputs[var_idx]) != GetElementNum(inputs[v_idx]) || + GetElementNum(inputs[var_idx]) != GetElementNum(inputs[grad_idx]) || GetElementNum(inputs[lr_idx]) != 1 || + GetElementNum(inputs[beta1_idx]) != 1 || GetElementNum(inputs[beta2_idx]) != 1 || + GetElementNum(inputs[epsilon]) != 1 || GetElementNum(inputs[decay_idx]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.h new file mode 100644 index 00000000000..30803dc5b41 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/adam_weight_decay_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H +#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/infer_register.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/infer_register.c index f9ddf940dd2..38ef523e348 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/infer_register.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/infer_register.c @@ -18,6 +18,7 @@ #ifdef _MSC_VER #include "nnacl/infer/activation_grad_infer.h" #include "nnacl/infer/adam_infer.h" +#include "nnacl/infer/adam_weight_decay_infer.h" #include "nnacl/infer/add_sub_grad_infer.h" #include "nnacl/infer/addn_infer.h" #include "nnacl/infer/affine_infer.h" @@ -168,6 +169,7 @@ void RegAllInferFunc1() { g_infer_func[PrimType_Activation] = CommonInferShape; g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape; g_infer_func[PrimType_Adam] = AdamInferShape; + g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape; g_infer_func[PrimType_AdderFusion] = Conv2dInferShape; g_infer_func[PrimType_AddFusion] = ArithmeticInferShape; g_infer_func[PrimType_AddGrad] = AddSubGradInferShape; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/scatter_nd_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/scatter_nd_infer.c index 06fc56a991d..51f535be216 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/scatter_nd_infer.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/scatter_nd_infer.c @@ -28,7 +28,7 @@ int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor if (shape->data_ == NULL) { return NNACL_INFER_INVALID; } - const TensorC *update = inputs[FIRST_INPUT]; + const TensorC *update = inputs[SECOND_INPUT]; TensorC *output = outputs[0]; SetDataTypeFormat(output, update); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h index ea5d8317c80..fea00d0b81f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h @@ -538,8 +538,9 @@ enum PrimType { PrimType_ScatterElements = 217, PrimType_Triu = 218, PrimType_Tril = 219, + PrimType_AdamWeightDecay = 220, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_Tril + 1, + PrimType_MAX = PrimType_AdamWeightDecay + 1, // inner operators. 
PrimType_Inner_ToFormat = 10000, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 5c832f202a3..1ca7b0ab8f4 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -237,6 +237,7 @@ union PrimitiveType { ScatterElements, Triu, Tril, + AdamWeightDecay, } table Abs { @@ -1322,3 +1323,7 @@ table Triu { table Tril { } + +table AdamWeightDecay { + use_locking: bool; +} diff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc index da3e1c6579d..0f91290f3f5 100644 --- a/mindspore/lite/src/common/ops/ops_def.cc +++ b/mindspore/lite/src/common/ops/ops_def.cc @@ -237,6 +237,7 @@ OP_TYPE(SparseSegmentSum) OP_TYPE(ScatterElements) OP_TYPE(Triu) OP_TYPE(Tril) +OP_TYPE(AdamWeightDecay) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1326,3 +1327,7 @@ OP_SCHEMA_DEF_END(Triu) OP_SCHEMA_DEF(Tril) OP_SCHEMA_DEF_END(Tril) + +OP_SCHEMA_DEF(AdamWeightDecay) +OP_ATTR(use_locking, bool) +OP_SCHEMA_DEF_END(AdamWeightDecay) diff --git a/mindspore/lite/src/common/ops/ops_func_declare.h b/mindspore/lite/src/common/ops/ops_func_declare.h index 83bb89922cf..2292d032d87 100644 --- a/mindspore/lite/src/common/ops/ops_func_declare.h +++ b/mindspore/lite/src/common/ops/ops_func_declare.h @@ -266,6 +266,7 @@ #include "ops/scatter_elements.h" #include "ops/triu.h" #include "ops/tril.h" +#include "ops/adam_weight_decay.h" namespace mindspore::lite::ops { #define FUNC_MSOP2SCHEMAOP_DECLARE(OP) std::unique_ptr MSOp2SchemaOp(const mindspore::ops::OP *op); @@ -500,6 +501,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(SparseSegmentSum) FUNC_MSOP2SCHEMAOP_DECLARE(ScatterElements) FUNC_MSOP2SCHEMAOP_DECLARE(Triu) FUNC_MSOP2SCHEMAOP_DECLARE(Tril) +FUNC_MSOP2SCHEMAOP_DECLARE(AdamWeightDecay) #endif } // namespace mindspore::lite::ops #else diff --git a/mindspore/lite/src/common/ops/ops_utils.cc b/mindspore/lite/src/common/ops/ops_utils.cc index d2191d5fa60..6c1802968ca 100644 --- a/mindspore/lite/src/common/ops/ops_utils.cc +++ b/mindspore/lite/src/common/ops/ops_utils.cc @@ -278,6 +278,7 @@ REG_MINDSPORE_OPERATOR(Tril) REG_MINDSPORE_OPERATOR(SparseFillEmptyRows) REG_MINDSPORE_OPERATOR(SparseReshape) REG_MINDSPORE_OPERATOR(SparseSegmentSum) +REG_MINDSPORE_OPERATOR(AdamWeightDecay) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam.cc b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam.cc index 8695b43c33a..7f96e40eb44 100644 --- a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam.cc +++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam.cc @@ -30,9 +30,6 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Adam; namespace mindspore::kernel { -constexpr static int kWeightIdx = 0; -constexpr static int kMomentVector1stIdx = 1; -constexpr static int kMomentVector2stIdx = 2; constexpr static int kBeta1PowerIdx = 3; constexpr static int kBeta2PowerIdx = 4; constexpr static int kBeta1Idx = 6; diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc new file mode 100644 index 00000000000..bb80d1feac1 --- /dev/null +++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h"
+#include
+#include
+#include "schema/model_generated.h"
+#include "src/litert/kernel_registry.h"
+#include "include/errorcode.h"
+#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h"
+
+using mindspore::kernel::KERNEL_ARCH;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_AdamWeightDecay;
+
+namespace mindspore::kernel {
+namespace {
+constexpr static int kBeta1Idx = 4;
+constexpr static int kBeta2Idx = 5;
+constexpr static int kEpsilonIdx = 6;
+constexpr static int kDecayIdx = 7;
+}  // namespace
+
+int AdamWeightDecayCPUKernel::ReSize() { return RET_OK; }
+
+int AdamWeightDecayCPUKernel::DoExecute(int task_id) {
+  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D);
+  auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
+  auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
+  auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
+  auto learning_rate = lr_;
+  auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
+  auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
+  auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
+  auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
+  auto gradient = reinterpret_cast<float *>(in_tensors_.at(kGradientIdx)->MutableData());
+  int length = in_tensors_.at(kWeightIdx)->ElementsNum();
+  CHECK_NULL_RETURN(weight);
+  CHECK_NULL_RETURN(m);
+  CHECK_NULL_RETURN(v);
+  CHECK_NULL_RETURN(gradient);
+
+  int stride = UP_DIV(length, thread_count_);
+  int count = MSMIN(stride, length - stride * task_id);
+  int start = stride * task_id;
+  int end = start + count;
+
+  return AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, gradient, start, end);
+}
+
+int AdamWeightDecayRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  auto adam_weight_decay_kernel = reinterpret_cast<AdamWeightDecayCPUKernel *>(cdata);
+  CHECK_NULL_RETURN(adam_weight_decay_kernel);
+  auto error_code = RET_OK;
+  if (adam_weight_decay_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
+    error_code = adam_weight_decay_kernel->ExecuteVirtualBatch(task_id);
+  } else if (adam_weight_decay_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
+    error_code = adam_weight_decay_kernel->ExecuteVirtualBatch(task_id);
+  } else {
+    error_code = adam_weight_decay_kernel->DoExecute(task_id);
+  }
+
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "AdamWeightDecay run error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int AdamWeightDecayCPUKernel::Run() {
+  int error_code = ParallelLaunch(this->ms_context_, AdamWeightDecayRun, this, thread_count_);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "AdamWeightDecay function error error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int AdamWeightDecayCPUKernel::Prepare() {
+  auto ret = OptimizerKernel::Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Failed to initialize AdamWeightDecay Kernel";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+std::vector<int> AdamWeightDecayCPUKernel::GetOptimizerParamsIdxs() const {
+  std::vector<int> indices = {4, 5, 6, 7};
+  return indices;
+}
+
+int AdamWeightDecayCPUKernel::OptimizerStep() {
+  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D - 1);
+  auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
+  auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
+  auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
+  auto learning_rate = lr_;
+  auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
+  auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
+  auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
+  auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
+  int length = in_tensors_.at(kWeightIdx)->ElementsNum();
+  CHECK_NULL_RETURN(weight);
+  CHECK_NULL_RETURN(m);
+  CHECK_NULL_RETURN(v);
+
+  int ret = RET_OK;
+  if (grad_sum_ != nullptr && valid_grad_sum_) {
+    size_t start = 0;
+    size_t end = length;
+    ret = AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, grad_sum_, start, end);
+    std::fill(grad_sum_, grad_sum_ + length, 0);
+    OptimizerKernel::OptimizerStep();
+  }
+  return ret;
+}
+
+kernel::LiteKernel *CpuAdamWeightDecayFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                                        const std::vector<lite::Tensor *> &outputs,
+                                                        OpParameter *opParameter, const lite::InnerContext *ctx,
+                                                        const kernel::KernelKey &desc) {
+  MS_CHECK_TRUE_MSG(opParameter != nullptr, nullptr, "Op parameter is nullptr.");
+  MS_ASSERT(desc.type == schema::PrimitiveType_AdamWeightDecay);
+  auto *kernel = new (std::nothrow) AdamWeightDecayCPUKernel(opParameter, inputs, outputs, ctx);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new AdamWeightDecayCPUKernel fail!";
+    free(opParameter);
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AdamWeightDecay, CpuAdamWeightDecayFp32KernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h
new file mode 100644
index 00000000000..7da29665941
--- /dev/null
+++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
+
+#include
+#include "src/train/optimizer_kernel.h"
+
+namespace mindspore::kernel {
+constexpr static int kLrIdx = 3;
+constexpr static int kGradientIdx = 8;
+
+class AdamWeightDecayCPUKernel : public OptimizerKernel {
+ public:
+  explicit AdamWeightDecayCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : OptimizerKernel(parameter, inputs, outputs, ctx, kLrIdx, kGradientIdx), thread_count_(ctx->thread_num_) {}
+  ~AdamWeightDecayCPUKernel() override {
+    if (grad_sum_ != nullptr) {
+      ms_context_->allocator->Free(grad_sum_);
+      grad_sum_ = nullptr;
+    }
+  }
+  int Prepare() override;
+  int ReSize() override;
+  int Run() override;
+  int DoExecute(int task_id);
+  int OptimizerStep() override;
+  std::vector<int> GetOptimizerParamsIdxs() const override;
+
+ private:
+  int thread_count_;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
diff --git a/mindspore/lite/src/train/optimizer_kernel.h b/mindspore/lite/src/train/optimizer_kernel.h
index 8a9370304c6..99edef7be5f 100644
--- a/mindspore/lite/src/train/optimizer_kernel.h
+++ b/mindspore/lite/src/train/optimizer_kernel.h
@@ -29,6 +29,10 @@ using mindspore::lite::RET_OK;
 using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;
 
 namespace mindspore::kernel {
+constexpr static int kWeightIdx = 0;
+constexpr static int kMomentVector1stIdx = 1;
+constexpr static int kMomentVector2stIdx = 2;
+
 enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
 
 class OptimizerKernel : public LiteKernel {
diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
index 272b13b0506..8ed194a7f0e 100644
--- a/mindspore/lite/src/train/train_populate_parameter.cc
+++ b/mindspore/lite/src/train/train_populate_parameter.cc
@@ -621,6 +621,8 @@ void PopulateTrainParameters() {
   Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
   Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
   Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
+  Registry AdamWeightDecayParameterRegistry(schema::PrimitiveType_AdamWeightDecay, lite::DefaultPopulateParameter,
+                                            lite::SCHEMA_CUR);
   Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
   Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
                                       lite::SCHEMA_CUR);
diff --git a/mindspore/lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc b/mindspore/lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc
new file mode 100644
index 00000000000..52dae7fbe34
--- /dev/null
+++ b/mindspore/lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/common_test.h"
+#include "nnacl/infer/adam_weight_decay_infer.h"
+
+namespace mindspore {
+class AdamWeightDecayInfer : public mindspore::CommonTest {
+ public:
+  AdamWeightDecayInfer() {}
+};
+
+void AdamWeightDecayInferInitArgs(std::vector<TensorC *> *inputs, std::vector<TensorC *> *outputs) {
+  const size_t inputs_size = 9;
+  for (size_t i = 0; i < inputs_size; i++) {
+    auto *input_x = new TensorC;
+    input_x->shape_size_ = 1;
+    input_x->shape_[0] = 1;
+    inputs->push_back(input_x);
+  }
+  auto *output = new TensorC;
+  outputs->push_back(output);
+}
+
+void AdamWeightDecayInferReleaseResources(OpParameter *param, std::vector<TensorC *> inputs,
+                                          std::vector<TensorC *> outputs) {
+  delete param;
+  for (auto t : inputs) delete t;
+  for (auto t : outputs) delete t;
+}
+
+TEST_F(AdamWeightDecayInfer, OneDim) {
+  std::vector<TensorC *> inputs;
+  std::vector<TensorC *> outputs;
+  AdamWeightDecayInferInitArgs(&inputs, &outputs);
+  auto *param = new OpParameter;
+  int ret = AdamWeightDecayInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(),
+                                      reinterpret_cast<OpParameter *>(param));
+  ASSERT_EQ(ret, NNACL_OK);
+  ASSERT_EQ(outputs[0]->shape_size_, 1);
+  ASSERT_EQ(outputs[0]->shape_[0], 1);
+  AdamWeightDecayInferReleaseResources(param, inputs, outputs);
+}
+}  // namespace mindspore
diff --git a/mindspore/python/mindspore/communication/management.py b/mindspore/python/mindspore/communication/management.py
index ca60a7fe3c2..144621391a8 100755
--- a/mindspore/python/mindspore/communication/management.py
+++ b/mindspore/python/mindspore/communication/management.py
@@ -109,7 +109,7 @@ def init(backend_name=None):
             have not been exported when backend is HCCL.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         .. note::
@@ -225,7 +225,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
         RuntimeError: If HCCL/NCCL is not available.
 
    Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         .. note::
@@ -322,7 +322,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
         RuntimeError: If HCCL/NCCL is not available.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         .. note::
diff --git a/mindspore/python/mindspore/train/callback/_on_request_exit.py b/mindspore/python/mindspore/train/callback/_on_request_exit.py
index f24121a97ee..453f827ba43 100644
--- a/mindspore/python/mindspore/train/callback/_on_request_exit.py
+++ b/mindspore/python/mindspore/train/callback/_on_request_exit.py
@@ -95,8 +95,8 @@ class OnRequestExit(Callback):
         self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
         self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
         self.sig = Validator.check_isinstance('sig', sig, int)
-        if self.sig == signal.SIGKILL or self.sig == signal.SIGINT:
-            raise ValueError("Not support send exit request by signal SIGKILL or SIGINT.")
+        if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
+            raise ValueError("Not support send exit request by signal SIGKILL.")
         self.exit = False
 
     def on_train_begin(self, run_context):
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index a4cbebb40db..fd019917df6 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -202,7 +202,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_adagrad_cpu_kernel.cc"
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_cpu_kernel.cc"
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_with_pad_cpu_kernel.cc"
-    "../../../mindspore/ccsrc/plugin/device/cpu/kernel/adam_delta_cpu_kernel.cc"
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.cc"
     "../../../mindspore/ccsrc/kernel/akg/*.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/*.cc"
@@ -243,7 +242,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "../../../mindspore/ccsrc/distributed/embedding_cache/*.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/hal/profiler/*.cc"
     "../../../mindspore/ccsrc/profiler/device/profiling.cc"
-    "../../../mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c"
     "../../../mindspore/ccsrc/kernel/kernel.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/akg_kernel_metadata.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/kernel/ascend_kernel_mod.cc"
diff --git a/tests/ut/cpp/kernel/cpu/adam_delta_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/adam_delta_cpu_kernel_test.cc
deleted file mode 100644
index c4435b131ce..00000000000
--- a/tests/ut/cpp/kernel/cpu/adam_delta_cpu_kernel_test.cc
+++ /dev/null
@@ -1,93 +0,0 @@
-/**
- * Copyright 2020-2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include "common/common_test.h"
-#define private public
-#define protected public
-#include "plugin/device/cpu/kernel/adam_delta_cpu_kernel.h"
-#undef private
-#undef protected
-
-namespace mindspore {
-namespace kernel {
-class AdamDeltaCpuKernelTest : public UT::Common {
- public:
-  AdamDeltaCpuKernelTest() : adam_delta_(std::make_shared<AdamDeltaCpuKernelMod>()) {}
-
-  void SetUp() override {
-    delta_.clear();
-    m_.clear();
-    v_.clear();
-    grad_.clear();
-    inputs_.clear();
-    workspace_.clear();
-    outputs_.clear();
-  }
-
-  AddressPtr CreateKernelAddress(void *addr, size_t elem_num) {
-    auto kernel_addr = std::make_shared<Address>();
-    kernel_addr->addr = addr;
-    kernel_addr->size = elem_num * 4;
-    return kernel_addr;
-  }
-
-  void CreateAddress() {
-    inputs_.push_back(CreateKernelAddress(m_.data(), elem_num_));
-    inputs_.push_back(CreateKernelAddress(v_.data(), elem_num_));
-    inputs_.push_back(CreateKernelAddress(&beta1_power_, 1));
-    inputs_.push_back(CreateKernelAddress(&beta2_power_, 1));
-    inputs_.push_back(CreateKernelAddress(&lr_, 1));
-    inputs_.push_back(CreateKernelAddress(&beta1_, 1));
-    inputs_.push_back(CreateKernelAddress(&beta2_, 1));
-    inputs_.push_back(CreateKernelAddress(&epsilon_, 1));
-    inputs_.push_back(CreateKernelAddress(grad_.data(), elem_num_));
-    outputs_.push_back(CreateKernelAddress(delta_.data(), elem_num_));
-  }
-
-  std::vector<float> delta_;
-  std::vector<float> m_;
-  std::vector<float> v_;
-  std::vector<float> grad_;
-  std::vector<AddressPtr> inputs_;
-  std::vector<AddressPtr> workspace_;
-  std::vector<AddressPtr> outputs_;
-  std::shared_ptr<AdamDeltaCpuKernelMod> adam_delta_;
-  float beta1_power_ = 0.9;
-  float beta2_power_ = 0.999;
-  float lr_ = 0.001;
-  float beta1_ = 0.9;
-  float beta2_ = 0.999;
-  float epsilon_ = 1e-8;
-  size_t elem_num_ = 27;
-};
-
-TEST_F(AdamDeltaCpuKernelTest, compute_test) {
-  for (size_t i = 0; i < elem_num_; ++i) {
-    delta_.push_back(1.0);
-    m_.push_back(1.0);
-    v_.push_back(1.0);
-    grad_.push_back(1.0);
-  }
-  adam_delta_->elem_num_ = elem_num_;
-  CreateAddress();
-  adam_delta_->Launch(inputs_, workspace_, outputs_);
-  for (size_t i = 0; i < elem_num_; ++i) {
-    EXPECT_TRUE(std::fabs(delta_[i] + 0.000316) < 1e-6);
-  }
-}
-}  // namespace kernel
-}  // namespace mindspore
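
Note: the per-element update applied by AdamWeightDecayFp32 in this patch (a SIMD body selected through SIMD_RUN_NO_SCALAR, followed by a scalar remainder loop) is the standard decoupled-weight-decay Adam (AdamW) rule. The scalar sketch below is for reference only and is not part of the patch; the helper name AdamWeightDecayScalar is hypothetical. In the committed kernel, DoExecute() first carves [0, length) into per-thread [start, end) slices with stride = UP_DIV(length, thread_count_) and runs the equivalent of this loop on each slice under ParallelLaunch.

/* Scalar reference sketch of the AdamWeightDecay (AdamW) update; illustrative only. */
#include <math.h>
#include <stddef.h>

static void AdamWeightDecayScalar(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                                  float decay, const float *gradient, size_t start, size_t end) {
  const float beta1_minus = 1.0f - beta1;
  const float beta2_minus = 1.0f - beta2;
  for (size_t i = start; i < end; ++i) {
    m[i] += (gradient[i] - m[i]) * beta1_minus;               /* m = beta1 * m + (1 - beta1) * g   */
    v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus; /* v = beta2 * v + (1 - beta2) * g^2 */
    var[i] -= lr * (m[i] / (sqrtf(v[i]) + epsilon) + decay * var[i]); /* decay term is decoupled from m/v */
  }
}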