develop AdamWeightDecay op on CPU for MindSpore Lite

buxue 2022-12-16 12:24:05 +04:30
parent 87f8483348
commit c644cb8073
22 changed files with 377 additions and 113 deletions

View File

@@ -8,7 +8,6 @@ mindspore.communication
For GPU devices, users need to prepare the host file and mpi; see the `GPU tutorial <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#准备环节>`_ for details.
CPU is currently not supported.
.. py:class:: mindspore.communication.GlobalComm

View File

@@ -13,13 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/adam_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_AVX512
#include "nnacl/avx512/adam_fp32_avx512.h"
#endif
#include "nnacl/adam_fp32_simd.h"
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
size_t start, size_t end, bool use_nesterov) {
@@ -159,7 +157,7 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end) {
size_t c1 = start;
SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end);
SIMD_RUN_NO_SCALAR(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end);
// remaining
const float beta1_minus = 1 - beta1;
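For reference, a minimal standalone sketch of the step this function computes, which both the SIMD pass and the scalar remainder loop above implement (illustrative only; the shipped code is the nnacl version in this file):

#include <math.h>
#include <stddef.h>
/* Scalar AdamWeightDecay step over [start, end): decoupled weight decay,
 * no bias correction, matching the remainder loop of AdamWeightDecayFp32. */
static void AdamWeightDecayScalar(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                                  float decay, const float *gradient, size_t start, size_t end) {
  const float beta1_minus = 1.0f - beta1;
  const float beta2_minus = 1.0f - beta2;
  for (size_t i = start; i < end; ++i) {
    float g = gradient[i];
    m[i] += (g - m[i]) * beta1_minus;     /* m <- beta1 * m + (1 - beta1) * g   */
    v[i] += (g * g - v[i]) * beta2_minus; /* v <- beta2 * v + (1 - beta2) * g^2 */
    var[i] -= lr * (m[i] / (sqrtf(v[i]) + epsilon) + decay * var[i]);
  }
}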

View File

@@ -23,7 +23,6 @@
extern "C" {
#endif
@SIMD_INSTRUCTION_BEGIN@
#ifdef MS_SIMD_AVX512
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
@@ -57,6 +56,7 @@ extern "C" {
return index;
}
#ifdef MS_SIMD_AVX512
static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
float global_norm_reciprocal, size_t end) {
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
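A hedged note on the moved guard: with #ifdef MS_SIMD_AVX512 now opening only before FusedCastAdamFp32Fp16, the generator should stamp AdamWeightDecayFp32 for every enabled instruction set instead of AVX512 alone. The variant names below are assumptions inferred from the @SIMD_INSTRUCTION@ placeholder, not taken from the generator:

/* Before: the template emitted AdamWeightDecayFp32 for AVX512 only.
 * After: it is stamped once per enabled ISA, e.g. (assumed names)
 *   AdamWeightDecayFp32AVX512, AdamWeightDecayFp32AVX,
 *   AdamWeightDecayFp32SSE, AdamWeightDecayFp32NEON,
 * which is what lets adam_fp32.c switch from
 *   SIMD_RUN_AVX512(AdamWeightDecayFp32, ...)
 * to
 *   SIMD_RUN_NO_SCALAR(AdamWeightDecayFp32, ...)
 * while FusedCastAdamFp32Fp16 stays AVX512-only behind the moved guard. */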

View File

@@ -0,0 +1,54 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/adam_weight_decay_infer.h"
#include "nnacl/infer/infer_register.h"
int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
const size_t expected_inputs_size = 9;
const int var_idx = 0;
const int m_idx = 1;
const int v_idx = 2;
const int lr_idx = 3;
const int beta1_idx = 4;
const int beta2_idx = 5;
const int epsilon_idx = 6;
const int decay_idx = 7;
const int grad_idx = 8;
int check_ret =
CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size);
if (check_ret != NNACL_OK) {
return check_ret;
}
if (GetElementNum(inputs[var_idx]) != GetElementNum(inputs[m_idx]) ||
GetElementNum(inputs[var_idx]) != GetElementNum(inputs[v_idx]) ||
GetElementNum(inputs[var_idx]) != GetElementNum(inputs[grad_idx]) || GetElementNum(inputs[lr_idx]) != 1 ||
GetElementNum(inputs[beta1_idx]) != 1 || GetElementNum(inputs[beta2_idx]) != 1 ||
GetElementNum(inputs[epsilon_idx]) != 1 || GetElementNum(inputs[decay_idx]) != 1) {
return NNACL_ERR;
}
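/* The optimizer updates var/m/v in place; the single declared output is
 * effectively a placeholder, so its shape is pinned to {1}. */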
if (outputs_size != 0) {
TensorC *out = outputs[0];
SetDataTypeFormat(out, inputs[0]);
out->shape_size_ = 1;
out->shape_[0] = 1;
}
return NNACL_OK;
}
REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape)

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H
#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H
#include "nnacl/infer/common_infer.h"
#ifdef __cplusplus
extern "C" {
#endif
int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H

View File

@@ -18,6 +18,7 @@
#ifdef _MSC_VER
#include "nnacl/infer/activation_grad_infer.h"
#include "nnacl/infer/adam_infer.h"
#include "nnacl/infer/adam_weight_decay_infer.h"
#include "nnacl/infer/add_sub_grad_infer.h"
#include "nnacl/infer/addn_infer.h"
#include "nnacl/infer/affine_infer.h"
@@ -168,6 +169,7 @@ void RegAllInferFunc1() {
g_infer_func[PrimType_Activation] = CommonInferShape;
g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape;
g_infer_func[PrimType_Adam] = AdamInferShape;
g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape;
g_infer_func[PrimType_AdderFusion] = Conv2dInferShape;
g_infer_func[PrimType_AddFusion] = ArithmeticInferShape;
g_infer_func[PrimType_AddGrad] = AddSubGradInferShape;

View File

@@ -28,7 +28,7 @@ int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
if (shape->data_ == NULL) {
return NNACL_INFER_INVALID;
}
const TensorC *update = inputs[FIRST_INPUT];
const TensorC *update = inputs[SECOND_INPUT];
TensorC *output = outputs[0];
SetDataTypeFormat(output, update);

View File

@@ -538,8 +538,9 @@ enum PrimType {
PrimType_ScatterElements = 217,
PrimType_Triu = 218,
PrimType_Tril = 219,
PrimType_AdamWeightDecay = 220,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_Tril + 1,
PrimType_MAX = PrimType_AdamWeightDecay + 1,
// inner operators.
PrimType_Inner_ToFormat = 10000,

View File

@@ -237,6 +237,7 @@ union PrimitiveType {
ScatterElements,
Triu,
Tril,
AdamWeightDecay,
}
table Abs {
@@ -1322,3 +1323,7 @@ table Triu {
table Tril {
}
table AdamWeightDecay {
use_locking: bool;
}
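For context, a hedged sketch of the core-side operator declaration this schema table pairs with; the shape is modeled on sibling ops such as ops/adam.h, and the actual ops/adam_weight_decay.h (included in a later file of this commit) may differ:

// Hypothetical declaration, modeled on sibling ops (e.g. ops/adam.h):
#include "ops/base_operator.h"
namespace mindspore::ops {
constexpr auto kNameAdamWeightDecay = "AdamWeightDecay";
class MIND_API AdamWeightDecay : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(AdamWeightDecay);
  AdamWeightDecay() : BaseOperator(kNameAdamWeightDecay) {}
  void set_use_locking(bool use_locking);  // mirrors the use_locking schema attr
  bool get_use_locking() const;
};
}  // namespace mindspore::ops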

View File

@@ -237,6 +237,7 @@ OP_TYPE(SparseSegmentSum)
OP_TYPE(ScatterElements)
OP_TYPE(Triu)
OP_TYPE(Tril)
OP_TYPE(AdamWeightDecay)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@@ -1326,3 +1327,7 @@ OP_SCHEMA_DEF_END(Triu)
OP_SCHEMA_DEF(Tril)
OP_SCHEMA_DEF_END(Tril)
OP_SCHEMA_DEF(AdamWeightDecay)
OP_ATTR(use_locking, bool)
OP_SCHEMA_DEF_END(AdamWeightDecay)

View File

@@ -266,6 +266,7 @@
#include "ops/scatter_elements.h"
#include "ops/triu.h"
#include "ops/tril.h"
#include "ops/adam_weight_decay.h"
namespace mindspore::lite::ops {
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) std::unique_ptr<schema::PrimitiveT> MSOp2SchemaOp(const mindspore::ops::OP *op);
@@ -500,6 +501,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(SparseSegmentSum)
FUNC_MSOP2SCHEMAOP_DECLARE(ScatterElements)
FUNC_MSOP2SCHEMAOP_DECLARE(Triu)
FUNC_MSOP2SCHEMAOP_DECLARE(Tril)
FUNC_MSOP2SCHEMAOP_DECLARE(AdamWeightDecay)
#endif
} // namespace mindspore::lite::ops
#else

View File

@@ -278,6 +278,7 @@ REG_MINDSPORE_OPERATOR(Tril)
REG_MINDSPORE_OPERATOR(SparseFillEmptyRows)
REG_MINDSPORE_OPERATOR(SparseReshape)
REG_MINDSPORE_OPERATOR(SparseSegmentSum)
REG_MINDSPORE_OPERATOR(AdamWeightDecay)
} // namespace lite
} // namespace mindspore

View File

@@ -30,9 +30,6 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Adam;
namespace mindspore::kernel {
constexpr static int kWeightIdx = 0;
constexpr static int kMomentVector1stIdx = 1;
constexpr static int kMomentVector2stIdx = 2;
constexpr static int kBeta1PowerIdx = 3;
constexpr static int kBeta2PowerIdx = 4;
constexpr static int kBeta1Idx = 6;

View File

@@ -0,0 +1,150 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h"
#include <cmath>
#include <string>
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_AdamWeightDecay;
namespace mindspore::kernel {
namespace {
constexpr static int kBeta1Idx = 4;
constexpr static int kBeta2Idx = 5;
constexpr static int kEpsilonIdx = 6;
constexpr static int kDecayIdx = 7;
} // namespace
int AdamWeightDecayCPUKernel::ReSize() { return RET_OK; }
int AdamWeightDecayCPUKernel::DoExecute(int task_id) {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D);
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
auto learning_rate = lr_;
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_.at(kGradientIdx)->MutableData());
int length = in_tensors_.at(kWeightIdx)->ElementsNum();
CHECK_NULL_RETURN(weight);
CHECK_NULL_RETURN(m);
CHECK_NULL_RETURN(v);
CHECK_NULL_RETURN(gradient);
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
int start = stride * task_id;
int end = start + count;
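/* Worked example: length = 10 with thread_count_ = 4 gives stride = UP_DIV(10, 4) = 3,
 * so tasks 0..3 cover [0, 3), [3, 6), [6, 9) and [9, 10); MSMIN trims the last shard. */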
return AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, gradient, start, end);
}
int AdamWeightDecayRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto adam_weight_decay_kernel = reinterpret_cast<AdamWeightDecayCPUKernel *>(cdata);
CHECK_NULL_RETURN(adam_weight_decay_kernel);
auto error_code = RET_OK;
auto mode = adam_weight_decay_kernel->get_optimizer_mode();
if (mode == WeightUpdateMode::VIRTUAL_BATCH || mode == WeightUpdateMode::ACCUMULATE_GRADS) {
error_code = adam_weight_decay_kernel->ExecuteVirtualBatch(task_id);
} else {
error_code = adam_weight_decay_kernel->DoExecute(task_id);
}
if (error_code != RET_OK) {
MS_LOG(ERROR) << "AdamWeightDecay run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int AdamWeightDecayCPUKernel::Run() {
int error_code = ParallelLaunch(this->ms_context_, AdamWeightDecayRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "AdamWeightDecay function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int AdamWeightDecayCPUKernel::Prepare() {
auto ret = OptimizerKernel::Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed to initialize AdamWeightDecay Kernel";
return RET_ERROR;
}
return RET_OK;
}
std::vector<int> AdamWeightDecayCPUKernel::GetOptimizerParamsIdxs() const {
std::vector<int> indices = {kBeta1Idx, kBeta2Idx, kEpsilonIdx, kDecayIdx};
return indices;
}
int AdamWeightDecayCPUKernel::OptimizerStep() {
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D - 1);
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
auto learning_rate = lr_;
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
int length = in_tensors_.at(kWeightIdx)->ElementsNum();
CHECK_NULL_RETURN(weight);
CHECK_NULL_RETURN(m);
CHECK_NULL_RETURN(v);
int ret = RET_OK;
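/* Accumulated-gradients path: apply the summed gradients once, then zero
 * grad_sum_ for the next accumulation window. */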
if (grad_sum_ != nullptr && valid_grad_sum_) {
size_t start = 0;
size_t end = length;
ret = AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, grad_sum_, start, end);
std::fill(grad_sum_, grad_sum_ + length, 0);
OptimizerKernel::OptimizerStep();
}
return ret;
}
kernel::LiteKernel *CpuAdamWeightDecayFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc) {
MS_CHECK_TRUE_MSG(opParameter != nullptr, nullptr, "Op parameter is nullptr.");
MS_ASSERT(desc.type == schema::PrimitiveType_AdamWeightDecay);
auto *kernel = new (std::nothrow) AdamWeightDecayCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new AdamWeightDecayCPUKernel fail!";
free(opParameter);
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AdamWeightDecay, CpuAdamWeightDecayFp32KernelCreator)
} // namespace mindspore::kernel

View File

@@ -0,0 +1,50 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
#include <vector>
#include "src/train/optimizer_kernel.h"
namespace mindspore::kernel {
constexpr static int kLrIdx = 3;
constexpr static int kGradientIdx = 8;
class AdamWeightDecayCPUKernel : public OptimizerKernel {
public:
explicit AdamWeightDecayCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: OptimizerKernel(parameter, inputs, outputs, ctx, kLrIdx, kGradientIdx), thread_count_(ctx->thread_num_) {}
~AdamWeightDecayCPUKernel() override {
if (grad_sum_ != nullptr) {
ms_context_->allocator->Free(grad_sum_);
grad_sum_ = nullptr;
}
}
int Prepare() override;
int ReSize() override;
int Run() override;
int DoExecute(int task_id);
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
private:
int thread_count_;
};
} // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_

View File

@@ -29,6 +29,10 @@ using mindspore::lite::RET_OK;
using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;
namespace mindspore::kernel {
constexpr static int kWeightIdx = 0;
constexpr static int kMomentVector1stIdx = 1;
constexpr static int kMomentVector2stIdx = 2;
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
class OptimizerKernel : public LiteKernel {

View File

@@ -621,6 +621,8 @@ void PopulateTrainParameters() {
Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
Registry AdamWeightDecayParameterRegistry(schema::PrimitiveType_AdamWeightDecay, lite::DefaultPopulateParameter,
lite::SCHEMA_CUR);
Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
lite::SCHEMA_CUR);

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "nnacl/infer/adam_weight_decay_infer.h"
namespace mindspore {
class AdamWeightDecayInfer : public mindspore::CommonTest {
public:
AdamWeightDecayInfer() {}
};
void AdamWeightDecayInferInitArgs(std::vector<TensorC *> *inputs, std::vector<TensorC *> *outputs) {
const size_t inputs_size = 9;
for (size_t i = 0; i < inputs_size; i++) {
auto *input_x = new TensorC();
input_x->shape_size_ = 1;
input_x->shape_[0] = 1;
inputs->push_back(input_x);
}
auto *output = new TensorC();
outputs->push_back(output);
}
void AdamWeightDecayInferReleaseResources(OpParameter *param, std::vector<TensorC *> inputs,
std::vector<TensorC *> outputs) {
delete param;
for (auto t : inputs) delete t;
for (auto t : outputs) delete t;
}
TEST_F(AdamWeightDecayInfer, OneDim) {
std::vector<TensorC *> inputs;
std::vector<TensorC *> outputs;
AdamWeightDecayInferInitArgs(&inputs, &outputs);
auto *param = new OpParameter();
int ret = AdamWeightDecayInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(),
                                    param);
ASSERT_EQ(ret, NNACL_OK);
ASSERT_EQ(outputs[0]->shape_size_, 1);
ASSERT_EQ(outputs[0]->shape_[0], 1);
AdamWeightDecayInferReleaseResources(param, inputs, outputs);
}
} // namespace mindspore

View File

@@ -109,7 +109,7 @@ def init(backend_name=None):
have not been exported when backend is HCCL.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
.. note::
@@ -225,7 +225,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
RuntimeError: If HCCL/NCCL is not available.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
.. note::
@@ -322,7 +322,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
RuntimeError: If HCCL/NCCL is not available.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
.. note::

View File

@@ -95,8 +95,8 @@ class OnRequestExit(Callback):
self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
self.sig = Validator.check_isinstance('sig', sig, int)
if self.sig == signal.SIGKILL or self.sig == signal.SIGINT:
raise ValueError("Not support send exit request by signal SIGKILL or SIGINT.")
if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
raise ValueError("Not support send exit request by signal SIGKILL.")
self.exit = False
def on_train_begin(self, run_context):

View File

@@ -202,7 +202,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_adagrad_cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_with_pad_cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/adam_delta_cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.cc"
"../../../mindspore/ccsrc/kernel/akg/*.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/*.cc"
@@ -243,7 +242,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/distributed/embedding_cache/*.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/profiler/*.cc"
"../../../mindspore/ccsrc/profiler/device/profiling.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c"
"../../../mindspore/ccsrc/kernel/kernel.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/akg_kernel_metadata.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/ascend_kernel_mod.cc"

View File

@@ -1,93 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#define private public
#define protected public
#include "plugin/device/cpu/kernel/adam_delta_cpu_kernel.h"
#undef private
#undef protected
namespace mindspore {
namespace kernel {
class AdamDeltaCpuKernelTest : public UT::Common {
public:
AdamDeltaCpuKernelTest() : adam_delta_(std::make_shared<AdamDeltaCpuKernelMod>()) {}
void SetUp() override {
delta_.clear();
m_.clear();
v_.clear();
grad_.clear();
inputs_.clear();
workspace_.clear();
outputs_.clear();
}
AddressPtr CreateKernelAddress(void *addr, size_t elem_num) {
auto kernel_addr = std::make_shared<Address>();
kernel_addr->addr = addr;
kernel_addr->size = elem_num * 4;
return kernel_addr;
}
void CreateAddress() {
inputs_.push_back(CreateKernelAddress(m_.data(), elem_num_));
inputs_.push_back(CreateKernelAddress(v_.data(), elem_num_));
inputs_.push_back(CreateKernelAddress(&beta1_power_, 1));
inputs_.push_back(CreateKernelAddress(&beta2_power_, 1));
inputs_.push_back(CreateKernelAddress(&lr_, 1));
inputs_.push_back(CreateKernelAddress(&beta1_, 1));
inputs_.push_back(CreateKernelAddress(&beta2_, 1));
inputs_.push_back(CreateKernelAddress(&epsilon_, 1));
inputs_.push_back(CreateKernelAddress(grad_.data(), elem_num_));
outputs_.push_back(CreateKernelAddress(delta_.data(), elem_num_));
}
std::vector<float> delta_;
std::vector<float> m_;
std::vector<float> v_;
std::vector<float> grad_;
std::vector<AddressPtr> inputs_;
std::vector<AddressPtr> workspace_;
std::vector<AddressPtr> outputs_;
std::shared_ptr<AdamDeltaCpuKernelMod> adam_delta_;
float beta1_power_ = 0.9;
float beta2_power_ = 0.999;
float lr_ = 0.001;
float beta1_ = 0.9;
float beta2_ = 0.999;
float epsilon_ = 1e-8;
size_t elem_num_ = 27;
};
TEST_F(AdamDeltaCpuKernelTest, compute_test) {
for (size_t i = 0; i < elem_num_; ++i) {
delta_.push_back(1.0);
m_.push_back(1.0);
v_.push_back(1.0);
grad_.push_back(1.0);
}
adam_delta_->elem_num_ = elem_num_;
CreateAddress();
adam_delta_->Launch(inputs_, workspace_, outputs_);
for (size_t i = 0; i < elem_num_; ++i) {
EXPECT_TRUE(std::fabs(delta_[i] + 0.000316) < 1e-6);
}
}
} // namespace kernel
} // namespace mindspore