fix bugs for ApplyAdamWithAmsgrad

OwenSec 2022-08-03 10:07:47 +08:00
parent ed1e4cc7b9
commit 815a06f410
12 changed files with 119 additions and 159 deletions

View File

@@ -29,7 +29,7 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kApplyAdamWithAmsgradInputsNum = 8;
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4;
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1;
constexpr size_t kScalarIndex = 0;
constexpr size_t kIndexVar = 0;
constexpr size_t kIndexM = 1;
@@ -65,9 +65,9 @@ bool ApplyAdamWithAmsgradCpuKernelMod::Init(const BaseOperatorPtr &base_operator
return false;
}
beta1_[0] = kernel_ptr->get_beta1();
beta2_[0] = kernel_ptr->get_beta2();
epsilon_[0] = kernel_ptr->get_epsilon();
beta1_ = kernel_ptr->get_beta1();
beta2_ = kernel_ptr->get_beta2();
epsilon_ = kernel_ptr->get_epsilon();
return true;
}
@@ -139,15 +139,14 @@ void ApplyAdamWithAmsgradCpuKernelMod::LaunchApplyAdamWithAmsgrad(const std::vec
T *lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
T *gradient = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
T beta1 = static_cast<T>(beta1_[0]);
T beta2 = static_cast<T>(beta2_[0]);
T epsilon = static_cast<T>(epsilon_[0]);
T beta1 = static_cast<T>(beta1_);
T beta2 = static_cast<T>(beta2_);
T epsilon = static_cast<T>(epsilon_);
constexpr float ONE = 1.0;
T ONE = static_cast<T>(1.0);
for (int64_t b = 0; b < batch_size_; b++) {
// multithreading
T new_lr = static_cast<T>(static_cast<float>(lr[b]) * std::sqrt(ONE - static_cast<float>(beta2_power[b])) /
(ONE - static_cast<float>(beta1_power[b])));
T new_lr = lr[b] * static_cast<T>(std::sqrt(static_cast<double>(ONE - beta2_power[b]))) / (ONE - beta1_power[b]);
auto task = [this, &var, &m, &v, &vhat, &gradient, new_lr, beta1, beta2, epsilon](size_t start, size_t end) {
T one = static_cast<T>(1.0);
for (size_t i = start; i < end; i++) {
@@ -215,13 +214,7 @@ std::vector<KernelAttr> ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() {
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)
.AddOutInRef(2, 2)
.AddOutInRef(3, 3),
.AddOutInRef(0, 0),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
@@ -232,13 +225,7 @@ std::vector<KernelAttr> ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() {
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)
.AddOutInRef(2, 2)
.AddOutInRef(3, 3)};
.AddOutInRef(0, 0)};
return support_list;
}
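For reference, the update that both the CPU and CUDA kernels implement, and that the corrected new_lr computation above restores for low-precision types, is the standard Adam-with-AMSGrad step. The NumPy sketch below is illustrative only; it is not MindSpore code and not part of this commit, and the function and argument names are ours:

import numpy as np

def amsgrad_step(var, m, v, vhat, grad, beta1_power, beta2_power, lr,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Illustrative NumPy model of one ApplyAdamWithAmsgrad step.

    Arrays are updated in place, mirroring the ref semantics of the kernel
    inputs, and only var is returned, matching the new single output.
    """
    # Bias-corrected learning rate: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
    new_lr = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    m += (grad - m) * (1.0 - beta1)            # m = beta1*m + (1-beta1)*grad
    v += (grad * grad - v) * (1.0 - beta2)     # v = beta2*v + (1-beta2)*grad^2
    np.maximum(vhat, v, out=vhat)              # AMSGrad running maximum of v
    var -= new_lr * m / (np.sqrt(vhat) + eps)  # parameter update
    return var

This is the same formula computed by the numpy_apply_adam_with_amsgrad reference used in the test files later in this commit.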

View File

@@ -46,9 +46,9 @@ class ApplyAdamWithAmsgradCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
void LaunchApplyAdamWithAmsgrad(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
std::vector<float> beta1_ = {0.9};
std::vector<float> beta2_ = {0.999};
std::vector<float> epsilon_ = {1e-8};
float beta1_{0.9};
float beta2_{0.999};
float epsilon_{1e-8};
int64_t batch_rank_{0};
int64_t batch_size_{1};

View File

@@ -25,7 +25,7 @@ __device__ __forceinline__ T sqrtFunc(T x) {
template <>
__device__ __forceinline__ half sqrtFunc(half x) {
return sqrt(__half2float(x));
return hsqrt(x);
}
template <typename T>
@@ -35,51 +35,55 @@ __device__ __forceinline__ T maxFunc(T x, T y) {
template <>
__device__ __forceinline__ half maxFunc(half x, half y) {
return __half2float(x) > __half2float(y)? __half2float(x) : __half2float(y);
return x > y? x : y;
}
template <typename T>
__global__ void CalApplyAdamWithAmsgradKernel(const size_t input_elements, const int64_t batch_size, T *var, T *m,
__global__ void CalApplyAdamWithAmsgradKernel(const size_t size, const int64_t batch_size, T *var, T *m,
T *v, T *vhat, T *beta1_power, T *beta2_power, const T *lr,
const T *grad, const float beta1, const float beta2,
const float epsilon) {
auto all_elements = input_elements * batch_size;
const T *grad, const T beta1, const T beta2,
const T epsilon, T *output_var) {
auto all_elements = size * batch_size;
const T one = static_cast<T>(1.0);
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < all_elements; pos += gridDim.x * blockDim.x) {
auto batch = pos / input_elements;
auto batch = pos / size;
auto new_learning_rate = lr[batch] * sqrtFunc(one - beta2_power[batch]) / (one - beta1_power[batch]);
m[pos] += (grad[pos] - m[pos]) * (one - static_cast<T>(beta1));
v[pos] += (grad[pos] * grad[pos] - v[pos]) * (one - static_cast<T>(beta2));
vhat[pos] = maxFunc(vhat[pos], v[pos]);
var[pos] -= new_learning_rate * m[pos] / (sqrtFunc(vhat[pos]) + static_cast<T>(epsilon));
output_var[pos] = var[pos];
}
}
template <typename T>
void CalApplyAdamWithAmsgrad(const size_t input_elements, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
T *beta1_power, T *beta2_power, const T *lr, const T *grad, const float beta1,
const float beta2, const float epsilon, const uint32_t &device_id,
void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
T *beta1_power, T *beta2_power, const T *lr, const T *grad, const T beta1,
const T beta2, const T epsilon, T *output_var, const uint32_t &device_id,
cudaStream_t stream_ptr) {
CalApplyAdamWithAmsgradKernel<<<CUDA_BLOCKS(device_id, input_elements * batch_size), CUDA_THREADS(device_id), 0,
stream_ptr>>>(input_elements, batch_size, var, m, v, vhat, beta1_power, beta2_power,
lr, grad, beta1, beta2, epsilon);
CalApplyAdamWithAmsgradKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
stream_ptr>>>(size, batch_size, var, m, v, vhat, beta1_power, beta2_power,
lr, grad, beta1, beta2, epsilon, output_var);
}
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<double>(const size_t size, const int64_t batch_size, double *var,
double *m, double *v, double *vhat, double *beta1_power,
double *beta2_power, const double *lr,
const double *grad, const float beta1, const float beta2,
const float epsilon, const uint32_t &device_id,
const double *grad, const double beta1,
const double beta2, const double epsilon,
double *output_var, const uint32_t &device_id,
cudaStream_t stream_ptr);
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<float>(const size_t size, const int64_t batch_size, float *var,
float *m, float *v, float *vhat, float *beta1_power,
float *beta2_power, const float *lr, const float *grad,
const float beta1, const float beta2, const float epsilon,
const uint32_t &device_id, cudaStream_t stream_ptr);
float *output_var, const uint32_t &device_id,
cudaStream_t stream_ptr);
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<half>(const size_t size, const int64_t batch_size, half *var,
half *m, half *v, half *vhat, half *beta1_power,
half *beta2_power, const half *lr, const half *grad,
const float beta1, const float beta2, const float epsilon,
const uint32_t &device_id, cudaStream_t stream_ptr);
const half beta1, const half beta2, const half epsilon,
half *output_var, const uint32_t &device_id,
cudaStream_t stream_ptr);

View File

@@ -21,7 +21,7 @@
template <typename T>
CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
T *beta1_power, T *beta2_power, const T *lr, const T *grad,
const float beta1, const float beta2, const float epsilon,
const T beta1, const T beta2, const T epsilon, T *output_var,
const uint32_t &device_id, cudaStream_t stream_ptr);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADAM_WITH_AMSGRAD_IMPL_CUH_

View File

@@ -20,14 +20,12 @@
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "include/curand.h"
#include "mindspore/core/ops/apply_adam_with_amsgrad.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kApplyAdamWithAmsgradInputsNum = 8;
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4;
constexpr size_t kScalarIndex = 0;
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1;
constexpr size_t kIndexVar = 0;
constexpr size_t kIndexM = 1;
constexpr size_t kIndexV = 2;
@@ -41,6 +39,13 @@ constexpr size_t kIndexGrad = 7;
bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum
<< " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and "
<< outputs.size();
return false;
}
kernel_name_ = base_operator->name();
batch_rank_ = base_operator->get_batch_rank();
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAdamWithAmsgrad>(base_operator);
@@ -48,12 +53,6 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator
beta1_ = kernel_ptr_->get_beta1();
beta2_ = kernel_ptr_->get_beta2();
epsilon_ = kernel_ptr_->get_epsilon();
if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum
<< " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and "
<< outputs.size();
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
@@ -63,7 +62,7 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndexVar).first);
return true;
}
@@ -75,6 +74,7 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato
if (ret != 0) {
return ret;
}
input_elements_ = 0;
std::vector<int64_t> var_shape = inputs[kIndexVar]->GetShapeVector();
std::vector<int64_t> m_shape = inputs[kIndexM]->GetShapeVector();
@@ -150,24 +150,31 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato
bool ApplyAdamWithAmsgradGpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs, void *stream_ptr) {
MS_EXCEPTION_IF_NULL(stream_ptr);
kernel_func_(this, inputs, outputs, stream_ptr);
return true;
}
template <typename T>
bool ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &, void *stream_ptr) {
auto var = reinterpret_cast<T *>(inputs[kIndexVar]->addr);
auto m = reinterpret_cast<T *>(inputs[kIndexM]->addr);
auto v = reinterpret_cast<T *>(inputs[kIndexV]->addr);
auto vhat = reinterpret_cast<T *>(inputs[kIndexVhat]->addr);
auto beta1_power = reinterpret_cast<T *>(inputs[kIndexBeta1Power]->addr);
auto beta2_power = reinterpret_cast<T *>(inputs[kIndexBeta2Power]->addr);
auto lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
auto grad = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto *var = reinterpret_cast<T *>(inputs[kIndexVar]->addr);
auto *m = reinterpret_cast<T *>(inputs[kIndexM]->addr);
auto *v = reinterpret_cast<T *>(inputs[kIndexV]->addr);
auto *vhat = reinterpret_cast<T *>(inputs[kIndexVhat]->addr);
auto *beta1_power = reinterpret_cast<T *>(inputs[kIndexBeta1Power]->addr);
auto *beta2_power = reinterpret_cast<T *>(inputs[kIndexBeta2Power]->addr);
auto *lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
auto *grad = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1_,
beta2_, epsilon_, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
T beta1 = static_cast<T>(beta1_);
T beta2 = static_cast<T>(beta2_);
T epsilon = static_cast<T>(epsilon_);
auto *output_var = reinterpret_cast<T *>(outputs[kIndexVar]->addr);
CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1,
beta2, epsilon, output_var, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -183,13 +190,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)
.AddOutInRef(2, 2)
.AddOutInRef(3, 3),
.AddOutInRef(0, 0),
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
@@ -201,13 +202,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)
.AddOutInRef(2, 2)
.AddOutInRef(3, 3),
.AddOutInRef(0, 0),
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
@@ -219,13 +214,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)
.AddOutInRef(2, 2)
.AddOutInRef(3, 3),
.AddOutInRef(0, 0),
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<half>}};
std::vector<KernelAttr> ApplyAdamWithAmsgradGpuKernelMod::GetOpSupport() {

View File

@@ -30,7 +30,6 @@
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -60,14 +59,14 @@ class ApplyAdamWithAmsgradGpuKernelMod : public NativeGpuKernelMod {
KernelFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, KernelFunc>> func_list_;
int unit_size_;
size_t input_elements_;
size_t unit_size_{0};
size_t input_elements_{0};
int64_t batch_rank_;
int64_t batch_size_;
float beta1_ = 0.9;
float beta2_ = 0.999;
float epsilon_ = 1e-8;
float beta1_{0.9};
float beta2_{0.999};
float epsilon_{1e-8};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -30,21 +30,20 @@
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
abstract::ShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto prim_name = primitive->name();
auto var_shape = input_args[0]->BuildShape();
auto m_shape = input_args[1]->BuildShape();
auto v_shape = input_args[2]->BuildShape();
auto vhat_shape = input_args[3]->BuildShape();
auto grad_shape = input_args[7]->BuildShape();
auto beta1_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
auto beta2_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->BuildShape())[kShape];
auto grad_shape = input_args[7]->BuildShape();
int64_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
@@ -60,8 +59,7 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi
if (var_shape->IsDynamic() || m_shape->IsDynamic() || v_shape->IsDynamic() || vhat_shape->IsDynamic() ||
grad_shape->IsDynamic()) {
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
return var_shape->cast<abstract::ShapePtr>();
}
// shape of var, m, v, vhat must be the same
@@ -78,14 +76,12 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi
<< ".";
}
}
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
auto shape_ptr = var_shape->cast<abstract::ShapePtr>();
return shape_ptr;
}
TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
TypePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
// get all input_args' shape
auto var_type = input_args[0]->BuildType();
auto m_type = input_args[1]->BuildType();
auto v_type = input_args[2]->BuildType();
@@ -102,15 +98,22 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect
(void)args.insert(std::make_pair("v_type", v_type));
(void)args.insert(std::make_pair("vhat_type", vhat_type));
(void)args.insert(std::make_pair("grad_type", grad_type));
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
// beta1_power, beta2_power, lr type valid
CheckAndConvertUtils::CheckTensorTypeValid("beta1_power_type", beta1_power_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("beta2_power_type", beta2_power_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("lr_type", lr_type, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type, vhat_type});
return var_type;
}
} // namespace
void ApplyAdamWithAmsgrad::Init(const float beta1, const float beta2, const float epsilon, const bool use_locking) {
this->set_beta1(beta1);
this->set_beta2(beta2);
this->set_epsilon(epsilon);
this->set_use_locking(use_locking);
}
void ApplyAdamWithAmsgrad::set_beta1(const float beta1) { (void)this->AddAttr(kBeta1, api::MakeValue(beta1)); }
void ApplyAdamWithAmsgrad::set_beta2(const float beta2) { (void)this->AddAttr(kBeta2, api::MakeValue(beta2)); }
@@ -146,7 +149,7 @@ AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, c
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 8;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = ApplyAdamWithAmsgradInferType(primitive, input_args);
auto infer_shape = ApplyAdamWithAmsgradInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);

View File

@@ -29,9 +29,12 @@ class MIND_API ApplyAdamWithAmsgrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad);
ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) {
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var"});
}
void Init(const float beta1 = 0.9, const float beta2 = 0.999, const float epsilon = 1e-8,
const bool use_locking = false);
void set_beta1(const float beta1);
void set_beta2(const float beta2);

View File

@@ -507,7 +507,7 @@ class Adam(Optimizer):
self.use_amsgrad = use_amsgrad
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self.vhat = self._parameters.clone(prefix="vhat", init=-100000)
self.vhat = self._parameters.clone(prefix="vhat", init='zeros')
self._is_device = True
if use_amsgrad:

View File

@@ -1198,8 +1198,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
ValueError("The source axis of `var` is None, "
"but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/grad` is not None. "
"The execution of operator `{}` cannot be guaranteed.".format(prim_name))
out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
return ((out_var, None), (out_m, None), (out_v, None), (out_vhat, None))
out_var = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
return (out_var, None)
if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
@@ -1211,8 +1211,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
lr = _bdim_at_front(lr, lr_dim, axis_size)
grad = _bdim_at_front(grad, grad_dim, axis_size)
out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
return ((out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0))
out_var = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
return (out_var, 0)
return vmap_rule

View File

@@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999,
v = v * beta2 + grad * grad * (1 - beta2)
vhat = np.maximum(vhat, v)
var = var - new_lr * m / (np.sqrt(vhat) + eps)
return var, m, v, vhat
return var
@pytest.mark.level0
@@ -72,14 +72,10 @@ def test_apply_adam_with_amsgrad_op(data_type):
grad = Tensor(grad_np)
output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
amsgrad.v_np, amsgrad.vhat_np, grad_np)
ms_var = output.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, amsgrad.v_np, amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
class AmsgradNetVmap(nn.Cell):
@@ -121,15 +117,11 @@ def test_apply_adam_witm_amsgrad_op_vmap():
vmap_amsgrad = AmsgradNetVmap(cal_amsgrad)
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
ms_var = vmap_amsgrad.var.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
class AmsgradNetVmap2(nn.Cell):
@@ -173,12 +165,8 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2():
vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad)
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
ms_var = vmap_amsgrad.var.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)

View File

@@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999,
v = v * beta2 + grad * grad * (1 - beta2)
vhat = np.maximum(vhat, v)
var = var - new_lr * m / (np.sqrt(vhat) + eps)
return var, m, v, vhat
return var
@pytest.mark.level0
@@ -71,14 +71,11 @@ def test_apply_adam_with_amsgrad_op(data_type):
grad = Tensor(grad_np)
output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
amsgrad.v_np, amsgrad.vhat_np, grad_np)
ms_var = output.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
amsgrad.v_np, amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
class AmsgradNetVmap(nn.Cell):
@@ -119,15 +116,11 @@ def test_apply_adam_witm_amsgrad_op_vmap():
vmap_amsgrad = AmsgradNetVmap(cal_amsgrad)
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
ms_var = vmap_amsgrad.var.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
class AmsgradNetVmap2(nn.Cell):
@@ -167,15 +160,9 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2():
grad_np = np.random.randn(*shape).astype(np.float32)
grad = Tensor(grad_np)
vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad)
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
ms_var = vmap_amsgrad.var.asnumpy()
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)