!46920 develop AdamWeightDecay op on cpu lite
Merge pull request !46920 from zhangbuxue/develop_AdamWeightDecay_op_on_cpu_lite
This commit is contained in:
commit
63f565382c
|
@ -8,7 +8,6 @@ mindspore.communication
|
||||||
|
|
||||||
针对GPU设备,用户需要准备host文件和mpi,详见 `GPU指导文档 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#准备环节>`_ 。
|
针对GPU设备,用户需要准备host文件和mpi,详见 `GPU指导文档 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/train_gpu.html#准备环节>`_ 。
|
||||||
|
|
||||||
目前尚不支持CPU。
|
|
||||||
|
|
||||||
.. py:class:: mindspore.communication.GlobalComm
|
.. py:class:: mindspore.communication.GlobalComm
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,11 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include "nnacl/fp32/exp_fp32.h"
|
#include "nnacl/errorcode.h"
|
||||||
#include "nnacl/fp32/adam_fp32.h"
|
#include "nnacl/fp32/adam_fp32.h"
|
||||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
#include "nnacl/adam_fp32_simd.h"
|
||||||
#ifdef ENABLE_AVX512
|
|
||||||
#include "nnacl/avx512/adam_fp32_avx512.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
|
int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient,
|
||||||
size_t start, size_t end, bool use_nesterov) {
|
size_t start, size_t end, bool use_nesterov) {
|
||||||
|
@ -159,7 +157,7 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
|
||||||
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
||||||
const float *gradient, size_t start, size_t end) {
|
const float *gradient, size_t start, size_t end) {
|
||||||
size_t c1 = start;
|
size_t c1 = start;
|
||||||
SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end);
|
SIMD_RUN_NO_SCALAR(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end);
|
||||||
|
|
||||||
// remaining
|
// remaining
|
||||||
const float beta1_minus = 1 - beta1;
|
const float beta1_minus = 1 - beta1;
|
||||||
|
|
|
@ -23,7 +23,6 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
@SIMD_INSTRUCTION_BEGIN@
|
@SIMD_INSTRUCTION_BEGIN@
|
||||||
#ifdef MS_SIMD_AVX512
|
|
||||||
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
||||||
const float *gradient, size_t end) {
|
const float *gradient, size_t end) {
|
||||||
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
|
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
|
||||||
|
@ -57,6 +56,7 @@ extern "C" {
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef MS_SIMD_AVX512
|
||||||
static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
|
||||||
float global_norm_reciprocal, size_t end) {
|
float global_norm_reciprocal, size_t end) {
|
||||||
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
|
SIMD_F32 beta1_r = SIMD_MOV_F32(beta1);
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "nnacl/infer/adam_weight_decay_infer.h"
|
||||||
|
#include "nnacl/infer/infer_register.h"
|
||||||
|
|
||||||
|
int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||||
|
OpParameter *parameter) {
|
||||||
|
const size_t expected_inputs_size = 9;
|
||||||
|
const int var_idx = 0;
|
||||||
|
const int m_idx = 1;
|
||||||
|
const int v_idx = 2;
|
||||||
|
const int lr_idx = 3;
|
||||||
|
const int beta1_idx = 4;
|
||||||
|
const int beta2_idx = 5;
|
||||||
|
const int epsilon = 6;
|
||||||
|
const int decay_idx = 7;
|
||||||
|
const int grad_idx = 8;
|
||||||
|
int check_ret =
|
||||||
|
CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size);
|
||||||
|
if (check_ret != NNACL_OK) {
|
||||||
|
return check_ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (GetElementNum(inputs[var_idx]) != GetElementNum(inputs[m_idx]) ||
|
||||||
|
GetElementNum(inputs[var_idx]) != GetElementNum(inputs[v_idx]) ||
|
||||||
|
GetElementNum(inputs[var_idx]) != GetElementNum(inputs[grad_idx]) || GetElementNum(inputs[lr_idx]) != 1 ||
|
||||||
|
GetElementNum(inputs[beta1_idx]) != 1 || GetElementNum(inputs[beta2_idx]) != 1 ||
|
||||||
|
GetElementNum(inputs[epsilon]) != 1 || GetElementNum(inputs[decay_idx]) != 1) {
|
||||||
|
return NNACL_ERR;
|
||||||
|
}
|
||||||
|
if (outputs_size != 0) {
|
||||||
|
TensorC *out = outputs[0];
|
||||||
|
SetDataTypeFormat(out, inputs[0]);
|
||||||
|
out->shape_size_ = 1;
|
||||||
|
out->shape_[0] = 1;
|
||||||
|
}
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape)
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H
|
||||||
|
#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H
|
||||||
|
|
||||||
|
#include "nnacl/infer/common_infer.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||||
|
OpParameter *parameter);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H
|
|
@ -18,6 +18,7 @@
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#include "nnacl/infer/activation_grad_infer.h"
|
#include "nnacl/infer/activation_grad_infer.h"
|
||||||
#include "nnacl/infer/adam_infer.h"
|
#include "nnacl/infer/adam_infer.h"
|
||||||
|
#include "nnacl/infer/adam_weight_decay_infer.h"
|
||||||
#include "nnacl/infer/add_sub_grad_infer.h"
|
#include "nnacl/infer/add_sub_grad_infer.h"
|
||||||
#include "nnacl/infer/addn_infer.h"
|
#include "nnacl/infer/addn_infer.h"
|
||||||
#include "nnacl/infer/affine_infer.h"
|
#include "nnacl/infer/affine_infer.h"
|
||||||
|
@ -168,6 +169,7 @@ void RegAllInferFunc1() {
|
||||||
g_infer_func[PrimType_Activation] = CommonInferShape;
|
g_infer_func[PrimType_Activation] = CommonInferShape;
|
||||||
g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape;
|
g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape;
|
||||||
g_infer_func[PrimType_Adam] = AdamInferShape;
|
g_infer_func[PrimType_Adam] = AdamInferShape;
|
||||||
|
g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape;
|
||||||
g_infer_func[PrimType_AdderFusion] = Conv2dInferShape;
|
g_infer_func[PrimType_AdderFusion] = Conv2dInferShape;
|
||||||
g_infer_func[PrimType_AddFusion] = ArithmeticInferShape;
|
g_infer_func[PrimType_AddFusion] = ArithmeticInferShape;
|
||||||
g_infer_func[PrimType_AddGrad] = AddSubGradInferShape;
|
g_infer_func[PrimType_AddGrad] = AddSubGradInferShape;
|
||||||
|
|
|
@ -28,7 +28,7 @@ int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
||||||
if (shape->data_ == NULL) {
|
if (shape->data_ == NULL) {
|
||||||
return NNACL_INFER_INVALID;
|
return NNACL_INFER_INVALID;
|
||||||
}
|
}
|
||||||
const TensorC *update = inputs[FIRST_INPUT];
|
const TensorC *update = inputs[SECOND_INPUT];
|
||||||
TensorC *output = outputs[0];
|
TensorC *output = outputs[0];
|
||||||
|
|
||||||
SetDataTypeFormat(output, update);
|
SetDataTypeFormat(output, update);
|
||||||
|
|
|
@ -538,8 +538,9 @@ enum PrimType {
|
||||||
PrimType_ScatterElements = 217,
|
PrimType_ScatterElements = 217,
|
||||||
PrimType_Triu = 218,
|
PrimType_Triu = 218,
|
||||||
PrimType_Tril = 219,
|
PrimType_Tril = 219,
|
||||||
|
PrimType_AdamWeightDecay = 220,
|
||||||
PrimType_MIN = PrimType_NONE,
|
PrimType_MIN = PrimType_NONE,
|
||||||
PrimType_MAX = PrimType_Tril + 1,
|
PrimType_MAX = PrimType_AdamWeightDecay + 1,
|
||||||
|
|
||||||
// inner operators.
|
// inner operators.
|
||||||
PrimType_Inner_ToFormat = 10000,
|
PrimType_Inner_ToFormat = 10000,
|
||||||
|
|
|
@ -237,6 +237,7 @@ union PrimitiveType {
|
||||||
ScatterElements,
|
ScatterElements,
|
||||||
Triu,
|
Triu,
|
||||||
Tril,
|
Tril,
|
||||||
|
AdamWeightDecay,
|
||||||
}
|
}
|
||||||
|
|
||||||
table Abs {
|
table Abs {
|
||||||
|
@ -1322,3 +1323,7 @@ table Triu {
|
||||||
|
|
||||||
table Tril {
|
table Tril {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table AdamWeightDecay {
|
||||||
|
use_locking: bool;
|
||||||
|
}
|
||||||
|
|
|
@ -237,6 +237,7 @@ OP_TYPE(SparseSegmentSum)
|
||||||
OP_TYPE(ScatterElements)
|
OP_TYPE(ScatterElements)
|
||||||
OP_TYPE(Triu)
|
OP_TYPE(Triu)
|
||||||
OP_TYPE(Tril)
|
OP_TYPE(Tril)
|
||||||
|
OP_TYPE(AdamWeightDecay)
|
||||||
OP_TYPE_DEF_END(PrimitiveType)
|
OP_TYPE_DEF_END(PrimitiveType)
|
||||||
|
|
||||||
OP_SCHEMA_DEF(Abs)
|
OP_SCHEMA_DEF(Abs)
|
||||||
|
@ -1326,3 +1327,7 @@ OP_SCHEMA_DEF_END(Triu)
|
||||||
|
|
||||||
OP_SCHEMA_DEF(Tril)
|
OP_SCHEMA_DEF(Tril)
|
||||||
OP_SCHEMA_DEF_END(Tril)
|
OP_SCHEMA_DEF_END(Tril)
|
||||||
|
|
||||||
|
OP_SCHEMA_DEF(AdamWeightDecay)
|
||||||
|
OP_ATTR(use_locking, bool)
|
||||||
|
OP_SCHEMA_DEF_END(AdamWeightDecay)
|
||||||
|
|
|
@ -266,6 +266,7 @@
|
||||||
#include "ops/scatter_elements.h"
|
#include "ops/scatter_elements.h"
|
||||||
#include "ops/triu.h"
|
#include "ops/triu.h"
|
||||||
#include "ops/tril.h"
|
#include "ops/tril.h"
|
||||||
|
#include "ops/adam_weight_decay.h"
|
||||||
|
|
||||||
namespace mindspore::lite::ops {
|
namespace mindspore::lite::ops {
|
||||||
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) std::unique_ptr<schema::PrimitiveT> MSOp2SchemaOp(const mindspore::ops::OP *op);
|
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) std::unique_ptr<schema::PrimitiveT> MSOp2SchemaOp(const mindspore::ops::OP *op);
|
||||||
|
@ -500,6 +501,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(SparseSegmentSum)
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(ScatterElements)
|
FUNC_MSOP2SCHEMAOP_DECLARE(ScatterElements)
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Triu)
|
FUNC_MSOP2SCHEMAOP_DECLARE(Triu)
|
||||||
FUNC_MSOP2SCHEMAOP_DECLARE(Tril)
|
FUNC_MSOP2SCHEMAOP_DECLARE(Tril)
|
||||||
|
FUNC_MSOP2SCHEMAOP_DECLARE(AdamWeightDecay)
|
||||||
#endif
|
#endif
|
||||||
} // namespace mindspore::lite::ops
|
} // namespace mindspore::lite::ops
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -278,6 +278,7 @@ REG_MINDSPORE_OPERATOR(Tril)
|
||||||
REG_MINDSPORE_OPERATOR(SparseFillEmptyRows)
|
REG_MINDSPORE_OPERATOR(SparseFillEmptyRows)
|
||||||
REG_MINDSPORE_OPERATOR(SparseReshape)
|
REG_MINDSPORE_OPERATOR(SparseReshape)
|
||||||
REG_MINDSPORE_OPERATOR(SparseSegmentSum)
|
REG_MINDSPORE_OPERATOR(SparseSegmentSum)
|
||||||
|
REG_MINDSPORE_OPERATOR(AdamWeightDecay)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,6 @@ using mindspore::lite::RET_OK;
|
||||||
using mindspore::schema::PrimitiveType_Adam;
|
using mindspore::schema::PrimitiveType_Adam;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
constexpr static int kWeightIdx = 0;
|
|
||||||
constexpr static int kMomentVector1stIdx = 1;
|
|
||||||
constexpr static int kMomentVector2stIdx = 2;
|
|
||||||
constexpr static int kBeta1PowerIdx = 3;
|
constexpr static int kBeta1PowerIdx = 3;
|
||||||
constexpr static int kBeta2PowerIdx = 4;
|
constexpr static int kBeta2PowerIdx = 4;
|
||||||
constexpr static int kBeta1Idx = 6;
|
constexpr static int kBeta1Idx = 6;
|
||||||
|
|
|
@ -0,0 +1,150 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/litert/kernel/cpu/fp32_grad/adam_weight_decay.h"
|
||||||
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
#include "schema/model_generated.h"
|
||||||
|
#include "src/litert/kernel_registry.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h"
|
||||||
|
|
||||||
|
using mindspore::kernel::KERNEL_ARCH;
|
||||||
|
using mindspore::lite::KernelRegistrar;
|
||||||
|
using mindspore::lite::RET_ERROR;
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
using mindspore::schema::PrimitiveType_AdamWeightDecay;
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
namespace {
|
||||||
|
constexpr static int kBeta1Idx = 4;
|
||||||
|
constexpr static int kBeta2Idx = 5;
|
||||||
|
constexpr static int kEpsilonIdx = 6;
|
||||||
|
constexpr static int kDecayIdx = 7;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int AdamWeightDecayCPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
|
int AdamWeightDecayCPUKernel::DoExecute(int task_id) {
|
||||||
|
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D);
|
||||||
|
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
|
||||||
|
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
|
||||||
|
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
|
||||||
|
auto learning_rate = lr_;
|
||||||
|
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
|
||||||
|
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
|
||||||
|
auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
|
||||||
|
auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
|
||||||
|
auto gradient = reinterpret_cast<float *>(in_tensors_.at(kGradientIdx)->MutableData());
|
||||||
|
int length = in_tensors_.at(kWeightIdx)->ElementsNum();
|
||||||
|
CHECK_NULL_RETURN(weight);
|
||||||
|
CHECK_NULL_RETURN(m);
|
||||||
|
CHECK_NULL_RETURN(v);
|
||||||
|
CHECK_NULL_RETURN(gradient);
|
||||||
|
|
||||||
|
int stride = UP_DIV(length, thread_count_);
|
||||||
|
int count = MSMIN(stride, length - stride * task_id);
|
||||||
|
int start = stride * task_id;
|
||||||
|
int end = start + count;
|
||||||
|
|
||||||
|
return AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, gradient, start, end);
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdamWeightDecayRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||||
|
auto adam_weight_decay_kernel = reinterpret_cast<AdamWeightDecayCPUKernel *>(cdata);
|
||||||
|
CHECK_NULL_RETURN(adam_weight_decay_kernel);
|
||||||
|
auto error_code = RET_OK;
|
||||||
|
if (adam_weight_decay_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
|
||||||
|
error_code = adam_weight_decay_kernel->ExecuteVirtualBatch(task_id);
|
||||||
|
} else if (adam_weight_decay_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
|
||||||
|
error_code = adam_weight_decay_kernel->ExecuteVirtualBatch(task_id);
|
||||||
|
} else {
|
||||||
|
error_code = adam_weight_decay_kernel->DoExecute(task_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "AdamWeightDecay run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdamWeightDecayCPUKernel::Run() {
|
||||||
|
int error_code = ParallelLaunch(this->ms_context_, AdamWeightDecayRun, this, thread_count_);
|
||||||
|
if (error_code != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "AdamWeightDecay function error error_code[" << error_code << "]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdamWeightDecayCPUKernel::Prepare() {
|
||||||
|
auto ret = OptimizerKernel::Prepare();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Failed to initialize AdamWeightDecay Kernel";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> AdamWeightDecayCPUKernel::GetOptimizerParamsIdxs() const {
|
||||||
|
std::vector<int> indices = {4, 5, 6, 7};
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdamWeightDecayCPUKernel::OptimizerStep() {
|
||||||
|
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_9D - 1);
|
||||||
|
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
|
||||||
|
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
|
||||||
|
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
|
||||||
|
auto learning_rate = lr_;
|
||||||
|
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
|
||||||
|
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
|
||||||
|
auto eps = reinterpret_cast<float *>(in_tensors_.at(kEpsilonIdx)->MutableData())[0];
|
||||||
|
auto decay = reinterpret_cast<float *>(in_tensors_.at(kDecayIdx)->MutableData())[0];
|
||||||
|
int length = in_tensors_.at(kWeightIdx)->ElementsNum();
|
||||||
|
CHECK_NULL_RETURN(weight);
|
||||||
|
CHECK_NULL_RETURN(m);
|
||||||
|
CHECK_NULL_RETURN(v);
|
||||||
|
|
||||||
|
int ret = RET_OK;
|
||||||
|
if (grad_sum_ != nullptr && valid_grad_sum_) {
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = length;
|
||||||
|
ret = AdamWeightDecayFp32(weight, m, v, learning_rate, beta1, beta2, eps, decay, grad_sum_, start, end);
|
||||||
|
std::fill(grad_sum_, grad_sum_ + length, 0);
|
||||||
|
OptimizerKernel::OptimizerStep();
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel::LiteKernel *CpuAdamWeightDecayFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||||
|
const std::vector<lite::Tensor *> &outputs,
|
||||||
|
OpParameter *opParameter, const lite::InnerContext *ctx,
|
||||||
|
const kernel::KernelKey &desc) {
|
||||||
|
MS_CHECK_TRUE_MSG(opParameter != nullptr, nullptr, "Op parameter is nullptr.");
|
||||||
|
MS_ASSERT(desc.type == schema::PrimitiveType_AdamWeightDecay);
|
||||||
|
auto *kernel = new (std::nothrow) AdamWeightDecayCPUKernel(opParameter, inputs, outputs, ctx);
|
||||||
|
if (kernel == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new AdamWeightDecayCPUKernel fail!";
|
||||||
|
free(opParameter);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return kernel;
|
||||||
|
}
|
||||||
|
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AdamWeightDecay, CpuAdamWeightDecayFp32KernelCreator)
|
||||||
|
} // namespace mindspore::kernel
|
|
@ -0,0 +1,50 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "src/train/optimizer_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore::kernel {
|
||||||
|
constexpr static int kLrIdx = 3;
|
||||||
|
constexpr static int kGradientIdx = 8;
|
||||||
|
|
||||||
|
class AdamWeightDecayCPUKernel : public OptimizerKernel {
|
||||||
|
public:
|
||||||
|
explicit AdamWeightDecayCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||||
|
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||||
|
: OptimizerKernel(parameter, inputs, outputs, ctx, kLrIdx, kGradientIdx), thread_count_(ctx->thread_num_) {}
|
||||||
|
~AdamWeightDecayCPUKernel() override {
|
||||||
|
if (grad_sum_ != nullptr) {
|
||||||
|
ms_context_->allocator->Free(grad_sum_);
|
||||||
|
grad_sum_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int Prepare() override;
|
||||||
|
int ReSize() override;
|
||||||
|
int Run() override;
|
||||||
|
int DoExecute(int task_id);
|
||||||
|
int OptimizerStep() override;
|
||||||
|
std::vector<int> GetOptimizerParamsIdxs() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int thread_count_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
#endif // #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_ADAM_WEIGHT_DECAY_H_
|
|
@ -29,6 +29,10 @@ using mindspore::lite::RET_OK;
|
||||||
using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;
|
using mindspore::lite::RET_OUT_OF_TENSOR_RANGE;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
|
constexpr static int kWeightIdx = 0;
|
||||||
|
constexpr static int kMomentVector1stIdx = 1;
|
||||||
|
constexpr static int kMomentVector2stIdx = 2;
|
||||||
|
|
||||||
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
|
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
|
||||||
|
|
||||||
class OptimizerKernel : public LiteKernel {
|
class OptimizerKernel : public LiteKernel {
|
||||||
|
|
|
@ -621,6 +621,8 @@ void PopulateTrainParameters() {
|
||||||
Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
|
Registry SgdParameterRegistry(schema::PrimitiveType_SGD, PopulateSgdParameter, lite::SCHEMA_CUR);
|
||||||
Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
|
Registry BNGradParameterRegistry(schema::PrimitiveType_BatchNormGrad, PopulateBNGradParameter, lite::SCHEMA_CUR);
|
||||||
Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
|
Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter, lite::SCHEMA_CUR);
|
||||||
|
Registry AdamWeightDecayParameterRegistry(schema::PrimitiveType_AdamWeightDecay, lite::DefaultPopulateParameter,
|
||||||
|
lite::SCHEMA_CUR);
|
||||||
Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
|
Registry AssignParameterRegistry(schema::PrimitiveType_Assign, lite::DefaultPopulateParameter, lite::SCHEMA_CUR);
|
||||||
Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
|
Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, lite::DefaultPopulateParameter,
|
||||||
lite::SCHEMA_CUR);
|
lite::SCHEMA_CUR);
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "nnacl/infer/adam_weight_decay_infer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class AdamWeightDecayInfer : public mindspore::CommonTest {
|
||||||
|
public:
|
||||||
|
AdamWeightDecayInfer() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
void AdamWeightDecayInferInitArgs(std::vector<TensorC *> *inputs, std::vector<TensorC *> *outputs) {
|
||||||
|
const size_t inputs_size = 9;
|
||||||
|
for (size_t i = 0; i < inputs_size; i++) {
|
||||||
|
auto *input_x = new TensorC;
|
||||||
|
input_x->shape_size_ = 1;
|
||||||
|
input_x->shape_[0] = 1;
|
||||||
|
inputs->push_back(input_x);
|
||||||
|
}
|
||||||
|
auto *output = new TensorC;
|
||||||
|
outputs->push_back(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdamWeightDecayInferReleaseResources(OpParameter *param, std::vector<TensorC *> inputs,
|
||||||
|
std::vector<TensorC *> outputs) {
|
||||||
|
delete param;
|
||||||
|
for (auto t : inputs) delete t;
|
||||||
|
for (auto t : outputs) delete t;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AdamWeightDecayInfer, OneDim) {
|
||||||
|
std::vector<TensorC *> inputs;
|
||||||
|
std::vector<TensorC *> outputs;
|
||||||
|
AdamWeightDecayInferInitArgs(&inputs, &outputs);
|
||||||
|
auto *param = new OpParameter;
|
||||||
|
int ret = AdamWeightDecayInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(),
|
||||||
|
reinterpret_cast<OpParameter *>(param));
|
||||||
|
ASSERT_EQ(ret, NNACL_OK);
|
||||||
|
ASSERT_EQ(outputs[0]->shape_size_, 1);
|
||||||
|
ASSERT_EQ(outputs[0]->shape_[0], 1);
|
||||||
|
AdamWeightDecayInferReleaseResources(param, inputs, outputs);
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -109,7 +109,7 @@ def init(backend_name=None):
|
||||||
have not been exported when backend is HCCL.
|
have not been exported when backend is HCCL.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -225,7 +225,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
RuntimeError: If HCCL/NCCL is not available.
|
RuntimeError: If HCCL/NCCL is not available.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -322,7 +322,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
RuntimeError: If HCCL/NCCL is not available.
|
RuntimeError: If HCCL/NCCL is not available.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. note::
|
.. note::
|
||||||
|
|
|
@ -95,8 +95,8 @@ class OnRequestExit(Callback):
|
||||||
self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
|
self.train_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_train"))
|
||||||
self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
|
self.eval_file_path = os.path.abspath(os.path.join(directory, f"{file_name}_eval"))
|
||||||
self.sig = Validator.check_isinstance('sig', sig, int)
|
self.sig = Validator.check_isinstance('sig', sig, int)
|
||||||
if self.sig == signal.SIGKILL or self.sig == signal.SIGINT:
|
if hasattr(signal, "SIGKILL") and self.sig == signal.SIGKILL:
|
||||||
raise ValueError("Not support send exit request by signal SIGKILL or SIGINT.")
|
raise ValueError("Not support send exit request by signal SIGKILL.")
|
||||||
self.exit = False
|
self.exit = False
|
||||||
|
|
||||||
def on_train_begin(self, run_context):
|
def on_train_begin(self, run_context):
|
||||||
|
|
|
@ -202,7 +202,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_adagrad_cpu_kernel.cc"
|
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_proximal_adagrad_cpu_kernel.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_cpu_kernel.cc"
|
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_cpu_kernel.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_with_pad_cpu_kernel.cc"
|
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/unique_with_pad_cpu_kernel.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/adam_delta_cpu_kernel.cc"
|
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.cc"
|
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/fused_ada_factor_cpu_kernel.cc"
|
||||||
"../../../mindspore/ccsrc/kernel/akg/*.cc"
|
"../../../mindspore/ccsrc/kernel/akg/*.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/*.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/*.cc"
|
||||||
|
@ -245,7 +244,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/distributed/embedding_cache/*.cc"
|
"../../../mindspore/ccsrc/distributed/embedding_cache/*.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/profiler/*.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/profiler/*.cc"
|
||||||
"../../../mindspore/ccsrc/profiler/device/profiling.cc"
|
"../../../mindspore/ccsrc/profiler/device/profiling.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.c"
|
|
||||||
"../../../mindspore/ccsrc/kernel/kernel.cc"
|
"../../../mindspore/ccsrc/kernel/kernel.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/akg_kernel_metadata.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/akg/akg_kernel_metadata.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/ascend_kernel_mod.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/kernel/ascend_kernel_mod.cc"
|
||||||
|
|
|
@ -1,93 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#define private public
|
|
||||||
#define protected public
|
|
||||||
#include "plugin/device/cpu/kernel/adam_delta_cpu_kernel.h"
|
|
||||||
#undef private
|
|
||||||
#undef protected
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
class AdamDeltaCpuKernelTest : public UT::Common {
|
|
||||||
public:
|
|
||||||
AdamDeltaCpuKernelTest() : adam_delta_(std::make_shared<AdamDeltaCpuKernelMod>()) {}
|
|
||||||
|
|
||||||
void SetUp() override {
|
|
||||||
delta_.clear();
|
|
||||||
m_.clear();
|
|
||||||
v_.clear();
|
|
||||||
grad_.clear();
|
|
||||||
inputs_.clear();
|
|
||||||
workspace_.clear();
|
|
||||||
outputs_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
AddressPtr CreateKernelAddress(void *addr, size_t elem_num) {
|
|
||||||
auto kernel_addr = std::make_shared<Address>();
|
|
||||||
kernel_addr->addr = addr;
|
|
||||||
kernel_addr->size = elem_num * 4;
|
|
||||||
return kernel_addr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateAddress() {
|
|
||||||
inputs_.push_back(CreateKernelAddress(m_.data(), elem_num_));
|
|
||||||
inputs_.push_back(CreateKernelAddress(v_.data(), elem_num_));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&beta1_power_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&beta2_power_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&lr_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&beta1_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&beta2_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(&epsilon_, 1));
|
|
||||||
inputs_.push_back(CreateKernelAddress(grad_.data(), elem_num_));
|
|
||||||
outputs_.push_back(CreateKernelAddress(delta_.data(), elem_num_));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> delta_;
|
|
||||||
std::vector<float> m_;
|
|
||||||
std::vector<float> v_;
|
|
||||||
std::vector<float> grad_;
|
|
||||||
std::vector<AddressPtr> inputs_;
|
|
||||||
std::vector<AddressPtr> workspace_;
|
|
||||||
std::vector<AddressPtr> outputs_;
|
|
||||||
std::shared_ptr<AdamDeltaCpuKernelMod> adam_delta_;
|
|
||||||
float beta1_power_ = 0.9;
|
|
||||||
float beta2_power_ = 0.999;
|
|
||||||
float lr_ = 0.001;
|
|
||||||
float beta1_ = 0.9;
|
|
||||||
float beta2_ = 0.999;
|
|
||||||
float epsilon_ = 1e-8;
|
|
||||||
size_t elem_num_ = 27;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AdamDeltaCpuKernelTest, compute_test) {
|
|
||||||
for (size_t i = 0; i < elem_num_; ++i) {
|
|
||||||
delta_.push_back(1.0);
|
|
||||||
m_.push_back(1.0);
|
|
||||||
v_.push_back(1.0);
|
|
||||||
grad_.push_back(1.0);
|
|
||||||
}
|
|
||||||
adam_delta_->elem_num_ = elem_num_;
|
|
||||||
CreateAddress();
|
|
||||||
adam_delta_->Launch(inputs_, workspace_, outputs_);
|
|
||||||
for (size_t i = 0; i < elem_num_; ++i) {
|
|
||||||
EXPECT_TRUE(std::fabs(delta_[i] + 0.000316) < 1e-6);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
Loading…
Reference in New Issue