From 815a06f410a33a6770b88a1fd2f11dfa734de323 Mon Sep 17 00:00:00 2001 From: OwenSec Date: Wed, 3 Aug 2022 10:07:47 +0800 Subject: [PATCH] fix bugs for applyadamwithamsgrad --- .../apply_adam_with_amsgrad_cpu_kernel.cc | 35 +++------ .../apply_adam_with_amsgrad_cpu_kernel.h | 6 +- .../cuda_ops/apply_adam_with_amsgrad_impl.cu | 40 +++++----- .../cuda_ops/apply_adam_with_amsgrad_impl.cuh | 2 +- .../nn/apply_adam_with_amsgrad_gpu_kernel.cc | 73 ++++++++----------- .../nn/apply_adam_with_amsgrad_gpu_kernel.h | 11 ++- mindspore/core/ops/apply_adam_with_amsgrad.cc | 33 +++++---- mindspore/core/ops/apply_adam_with_amsgrad.h | 5 +- mindspore/python/mindspore/nn/optim/adam.py | 2 +- .../python/mindspore/ops/_vmap/vmap_nn_ops.py | 8 +- .../cpu/test_apply_adam_with_amsgrad_op.py | 30 +++----- .../gpu/test_apply_adam_with_amsgrad_op.py | 33 +++------ 12 files changed, 119 insertions(+), 159 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.cc index 3ad3782a1c7..186a90ad3be 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.cc @@ -29,7 +29,7 @@ namespace mindspore { namespace kernel { namespace { constexpr size_t kApplyAdamWithAmsgradInputsNum = 8; -constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4; +constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1; constexpr size_t kScalarIndex = 0; constexpr size_t kIndexVar = 0; constexpr size_t kIndexM = 1; @@ -65,9 +65,9 @@ bool ApplyAdamWithAmsgradCpuKernelMod::Init(const BaseOperatorPtr &base_operator return false; } - beta1_[0] = kernel_ptr->get_beta1(); - beta2_[0] = kernel_ptr->get_beta2(); - epsilon_[0] = kernel_ptr->get_epsilon(); + beta1_ = kernel_ptr->get_beta1(); + beta2_ = kernel_ptr->get_beta2(); + epsilon_ = kernel_ptr->get_epsilon(); return true; } @@ -139,15 +139,14 @@ void ApplyAdamWithAmsgradCpuKernelMod::LaunchApplyAdamWithAmsgrad(const std::vec T *lr = reinterpret_cast(inputs[kIndexLr]->addr); T *gradient = reinterpret_cast(inputs[kIndexGrad]->addr); - T beta1 = static_cast(beta1_[0]); - T beta2 = static_cast(beta2_[0]); - T epsilon = static_cast(epsilon_[0]); + T beta1 = static_cast(beta1_); + T beta2 = static_cast(beta2_); + T epsilon = static_cast(epsilon_); - constexpr float ONE = 1.0; + T ONE = static_cast(1.0); for (int64_t b = 0; b < batch_size_; b++) { // multithreading - T new_lr = static_cast(static_cast(lr[b]) * std::sqrt(ONE - static_cast(beta2_power[b])) / - (ONE - static_cast(beta1_power[b]))); + T new_lr = lr[b] * static_cast(std::sqrt(static_cast(ONE - beta2_power[b]))) / (ONE - beta1_power[b]); auto task = [this, &var, &m, &v, &vhat, &gradient, new_lr, beta1, beta2, epsilon](size_t start, size_t end) { T one = static_cast(1.0); for (size_t i = start; i < end; i++) { @@ -215,13 +214,7 @@ std::vector ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() { .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1) - .AddOutInRef(2, 2) - .AddOutInRef(3, 3), + .AddOutInRef(0, 0), KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16) @@ -232,13 +225,7 @@ std::vector ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() { .AddInputAttr(kNumberTypeFloat16) 
.AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1) - .AddOutInRef(2, 2) - .AddOutInRef(3, 3)}; + .AddOutInRef(0, 0)}; return support_list; } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.h index 5b50c82050d..d333436de76 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/apply_adam_with_amsgrad_cpu_kernel.h @@ -46,9 +46,9 @@ class ApplyAdamWithAmsgradCpuKernelMod : public NativeCpuKernelMod { template void LaunchApplyAdamWithAmsgrad(const std::vector &inputs, const std::vector &outputs); - std::vector beta1_ = {0.9}; - std::vector beta2_ = {0.999}; - std::vector epsilon_ = {1e-8}; + float beta1_{0.9}; + float beta2_{0.999}; + float epsilon_{1e-8}; int64_t batch_rank_{0}; int64_t batch_size_{1}; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cu index 0c03b7b9ef0..418f73adfbb 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cu @@ -25,7 +25,7 @@ __device__ __forceinline__ T sqrtFunc(T x) { template <> __device__ __forceinline__ half sqrtFunc(half x) { - return sqrt(__half2float(x)); + return hsqrt(x); } template @@ -35,51 +35,55 @@ __device__ __forceinline__ T maxFunc(T x, T y) { template <> __device__ __forceinline__ half maxFunc(half x, half y) { - return __half2float(x) > __half2float(y)? __half2float(x) : __half2float(y); + return x > y? 
x : y; } template -__global__ void CalApplyAdamWithAmsgradKernel(const size_t input_elements, const int64_t batch_size, T *var, T *m, +__global__ void CalApplyAdamWithAmsgradKernel(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat, T *beta1_power, T *beta2_power, const T *lr, - const T *grad, const float beta1, const float beta2, - const float epsilon) { - auto all_elements = input_elements * batch_size; + const T *grad, const T beta1, const T beta2, + const T epsilon, T *output_var) { + auto all_elements = size * batch_size; const T one = static_cast(1.0); for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < all_elements; pos += gridDim.x * blockDim.x) { - auto batch = pos / input_elements; + auto batch = pos / size; auto new_learning_rate = lr[batch] * sqrtFunc(one - beta2_power[batch]) / (one - beta1_power[batch]); m[pos] += (grad[pos] - m[pos]) * (one - static_cast(beta1)); v[pos] += (grad[pos] * grad[pos] - v[pos]) * (one - static_cast(beta2)); vhat[pos] = maxFunc(vhat[pos], v[pos]); var[pos] -= new_learning_rate * m[pos] / (sqrtFunc(vhat[pos]) + static_cast(epsilon)); + output_var[pos] = var[pos]; } } template -void CalApplyAdamWithAmsgrad(const size_t input_elements, const int64_t batch_size, T *var, T *m, T *v, T *vhat, - T *beta1_power, T *beta2_power, const T *lr, const T *grad, const float beta1, - const float beta2, const float epsilon, const uint32_t &device_id, +void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat, + T *beta1_power, T *beta2_power, const T *lr, const T *grad, const T beta1, + const T beta2, const T epsilon, T *output_var, const uint32_t &device_id, cudaStream_t stream_ptr) { - CalApplyAdamWithAmsgradKernel<<>>(input_elements, batch_size, var, m, v, vhat, beta1_power, beta2_power, - lr, grad, beta1, beta2, epsilon); + CalApplyAdamWithAmsgradKernel<<>>(size, batch_size, var, m, v, vhat, beta1_power, beta2_power, + lr, grad, beta1, beta2, epsilon, output_var); } template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, double *var, double *m, double *v, double *vhat, double *beta1_power, double *beta2_power, const double *lr, - const double *grad, const float beta1, const float beta2, - const float epsilon, const uint32_t &device_id, + const double *grad, const double beta1, + const double beta2, const double epsilon, + double *output_var, const uint32_t &device_id, cudaStream_t stream_ptr); template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, float *var, float *m, float *v, float *vhat, float *beta1_power, float *beta2_power, const float *lr, const float *grad, const float beta1, const float beta2, const float epsilon, - const uint32_t &device_id, cudaStream_t stream_ptr); + float *output_var, const uint32_t &device_id, + cudaStream_t stream_ptr); template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, half *var, half *m, half *v, half *vhat, half *beta1_power, half *beta2_power, const half *lr, const half *grad, - const float beta1, const float beta2, const float epsilon, - const uint32_t &device_id, cudaStream_t stream_ptr); + const half beta1, const half beta2, const half epsilon, + half *output_var, const uint32_t &device_id, + cudaStream_t stream_ptr); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh 
b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh index e99ca94d0fd..cbfc1d9e41b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh @@ -21,7 +21,7 @@ template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat, T *beta1_power, T *beta2_power, const T *lr, const T *grad, - const float beta1, const float beta2, const float epsilon, + const T beta1, const T beta2, const T epsilon, T *output_var, const uint32_t &device_id, cudaStream_t stream_ptr); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADAM_WITH_AMSGRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.cc index 1540bdafcba..630016574b7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.cc @@ -20,14 +20,12 @@ #include "abstract/utils.h" #include "kernel/common_utils.h" #include "include/curand.h" -#include "mindspore/core/ops/apply_adam_with_amsgrad.h" namespace mindspore { namespace kernel { namespace { constexpr size_t kApplyAdamWithAmsgradInputsNum = 8; -constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4; -constexpr size_t kScalarIndex = 0; +constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1; constexpr size_t kIndexVar = 0; constexpr size_t kIndexM = 1; constexpr size_t kIndexV = 2; @@ -41,6 +39,13 @@ constexpr size_t kIndexGrad = 7; bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) { + if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum + << " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and " + << outputs.size(); + return false; + } + kernel_name_ = base_operator->name(); batch_rank_ = base_operator->get_batch_rank(); auto kernel_ptr_ = std::dynamic_pointer_cast(base_operator); @@ -48,12 +53,6 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator beta1_ = kernel_ptr_->get_beta1(); beta2_ = kernel_ptr_->get_beta2(); epsilon_ = kernel_ptr_->get_epsilon(); - if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum - << " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and " - << outputs.size(); - return false; - } auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); @@ -63,7 +62,7 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator } kernel_func_ = func_list_[index].second; - unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndexVar).first); return true; } @@ -75,6 +74,7 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato if (ret != 0) { 
return ret; } + input_elements_ = 0; std::vector var_shape = inputs[kIndexVar]->GetShapeVector(); std::vector m_shape = inputs[kIndexM]->GetShapeVector(); @@ -150,24 +150,31 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato bool ApplyAdamWithAmsgradGpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) { + MS_EXCEPTION_IF_NULL(stream_ptr); kernel_func_(this, inputs, outputs, stream_ptr); return true; } template bool ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &, void *stream_ptr) { - auto var = reinterpret_cast(inputs[kIndexVar]->addr); - auto m = reinterpret_cast(inputs[kIndexM]->addr); - auto v = reinterpret_cast(inputs[kIndexV]->addr); - auto vhat = reinterpret_cast(inputs[kIndexVhat]->addr); - auto beta1_power = reinterpret_cast(inputs[kIndexBeta1Power]->addr); - auto beta2_power = reinterpret_cast(inputs[kIndexBeta2Power]->addr); - auto lr = reinterpret_cast(inputs[kIndexLr]->addr); - auto grad = reinterpret_cast(inputs[kIndexGrad]->addr); + const std::vector &outputs, void *stream_ptr) { + auto *var = reinterpret_cast(inputs[kIndexVar]->addr); + auto *m = reinterpret_cast(inputs[kIndexM]->addr); + auto *v = reinterpret_cast(inputs[kIndexV]->addr); + auto *vhat = reinterpret_cast(inputs[kIndexVhat]->addr); + auto *beta1_power = reinterpret_cast(inputs[kIndexBeta1Power]->addr); + auto *beta2_power = reinterpret_cast(inputs[kIndexBeta2Power]->addr); + auto *lr = reinterpret_cast(inputs[kIndexLr]->addr); + auto *grad = reinterpret_cast(inputs[kIndexGrad]->addr); - CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1_, - beta2_, epsilon_, device_id_, reinterpret_cast(stream_ptr)); + T beta1 = static_cast(beta1_); + T beta2 = static_cast(beta2_); + T epsilon = static_cast(epsilon_); + + auto *output_var = reinterpret_cast(outputs[kIndexVar]->addr); + + CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1, + beta2, epsilon, output_var, device_id_, reinterpret_cast(stream_ptr)); return true; } @@ -183,13 +190,7 @@ std::vector> .AddInputAttr(kNumberTypeFloat64) .AddInputAttr(kNumberTypeFloat64) .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1) - .AddOutInRef(2, 2) - .AddOutInRef(3, 3), + .AddOutInRef(0, 0), &ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -201,13 +202,7 @@ std::vector> .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1) - .AddOutInRef(2, 2) - .AddOutInRef(3, 3), + .AddOutInRef(0, 0), &ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeFloat16) @@ -219,13 +214,7 @@ std::vector> .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutInRef(0, 0) - .AddOutInRef(1, 1) - .AddOutInRef(2, 2) - .AddOutInRef(3, 3), + .AddOutInRef(0, 0), &ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel}}; std::vector 
ApplyAdamWithAmsgradGpuKernelMod::GetOpSupport() { diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.h index 3fda5aec11a..d370bccbe9a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/apply_adam_with_amsgrad_gpu_kernel.h @@ -30,7 +30,6 @@ #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/factory/ms_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh" namespace mindspore { namespace kernel { @@ -60,14 +59,14 @@ class ApplyAdamWithAmsgradGpuKernelMod : public NativeGpuKernelMod { KernelFunc kernel_func_{}; static std::vector> func_list_; - int unit_size_; - size_t input_elements_; + size_t unit_size_{0}; + size_t input_elements_{0}; int64_t batch_rank_; int64_t batch_size_; - float beta1_ = 0.9; - float beta2_ = 0.999; - float epsilon_ = 1e-8; + float beta1_{0.9}; + float beta2_{0.999}; + float epsilon_{1e-8}; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/core/ops/apply_adam_with_amsgrad.cc b/mindspore/core/ops/apply_adam_with_amsgrad.cc index cf44547eb53..9d3876a482c 100644 --- a/mindspore/core/ops/apply_adam_with_amsgrad.cc +++ b/mindspore/core/ops/apply_adam_with_amsgrad.cc @@ -30,21 +30,20 @@ namespace mindspore { namespace ops { namespace { -abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); +abstract::ShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } + auto prim_name = primitive->name(); auto var_shape = input_args[0]->BuildShape(); auto m_shape = input_args[1]->BuildShape(); auto v_shape = input_args[2]->BuildShape(); auto vhat_shape = input_args[3]->BuildShape(); + auto grad_shape = input_args[7]->BuildShape(); auto beta1_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape]; auto beta2_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape]; auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->BuildShape())[kShape]; - auto grad_shape = input_args[7]->BuildShape(); int64_t batch_rank = 0; if (primitive->HasAttr(kBatchRank)) { @@ -60,8 +59,7 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi if (var_shape->IsDynamic() || m_shape->IsDynamic() || v_shape->IsDynamic() || vhat_shape->IsDynamic() || grad_shape->IsDynamic()) { - return std::make_shared( - std::vector{var_shape, m_shape, v_shape, vhat_shape}); + return var_shape->cast(); } // shape of var, m, v, vhat must be the same @@ -78,14 +76,12 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi << "."; } } - return std::make_shared( - std::vector{var_shape, m_shape, v_shape, vhat_shape}); + auto shape_ptr = var_shape->cast(); + return shape_ptr; } -TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); +TypePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector &input_args) { auto prim_name = prim->name(); - // get all input_args' shape auto 
var_type = input_args[0]->BuildType(); auto m_type = input_args[1]->BuildType(); auto v_type = input_args[2]->BuildType(); @@ -102,15 +98,22 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect (void)args.insert(std::make_pair("v_type", v_type)); (void)args.insert(std::make_pair("vhat_type", vhat_type)); (void)args.insert(std::make_pair("grad_type", grad_type)); - CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); + (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); // beta1_power, beta2_power, lr type valid CheckAndConvertUtils::CheckTensorTypeValid("beta1_power_type", beta1_power_type, valid_types, prim_name); CheckAndConvertUtils::CheckTensorTypeValid("beta2_power_type", beta2_power_type, valid_types, prim_name); CheckAndConvertUtils::CheckTensorTypeValid("lr_type", lr_type, valid_types, prim_name); - return std::make_shared(std::vector{var_type, m_type, v_type, vhat_type}); + return var_type; } } // namespace +void ApplyAdamWithAmsgrad::Init(const float beta1, const float beta2, const float epsilon, const bool use_locking) { + this->set_beta1(beta1); + this->set_beta2(beta2); + this->set_epsilon(epsilon); + this->set_use_locking(use_locking); +} + void ApplyAdamWithAmsgrad::set_beta1(const float beta1) { (void)this->AddAttr(kBeta1, api::MakeValue(beta1)); } void ApplyAdamWithAmsgrad::set_beta2(const float beta2) { (void)this->AddAttr(kBeta2, api::MakeValue(beta2)); } @@ -146,7 +149,7 @@ AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, c const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); const int64_t input_num = 8; - CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); auto infer_type = ApplyAdamWithAmsgradInferType(primitive, input_args); auto infer_shape = ApplyAdamWithAmsgradInferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); diff --git a/mindspore/core/ops/apply_adam_with_amsgrad.h b/mindspore/core/ops/apply_adam_with_amsgrad.h index 7b8fa232adc..4b287fe1142 100644 --- a/mindspore/core/ops/apply_adam_with_amsgrad.h +++ b/mindspore/core/ops/apply_adam_with_amsgrad.h @@ -29,9 +29,12 @@ class MIND_API ApplyAdamWithAmsgrad : public BaseOperator { public: MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad); ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) { - InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"}); + InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var"}); } + void Init(const float beta1 = 0.9, const float beta2 = 0.999, const float epsilon = 1e-8, + const bool use_locking = false); + void set_beta1(const float beta1); void set_beta2(const float beta2); diff --git a/mindspore/python/mindspore/nn/optim/adam.py b/mindspore/python/mindspore/nn/optim/adam.py index 343818c3369..af7256eb45e 100755 --- a/mindspore/python/mindspore/nn/optim/adam.py +++ b/mindspore/python/mindspore/nn/optim/adam.py @@ -507,7 +507,7 @@ class Adam(Optimizer): self.use_amsgrad = use_amsgrad self.moment1 = self._parameters.clone(prefix="moment1", init='zeros') self.moment2 = self._parameters.clone(prefix="moment2", init='zeros') - self.vhat = self._parameters.clone(prefix="vhat", init=-100000) + self.vhat = self._parameters.clone(prefix="vhat", init='zeros') self._is_device = True if use_amsgrad: diff 
--git a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py index 4bc87664c0a..ec77c1849a4 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py @@ -1198,8 +1198,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size): ValueError("The source axis of `var` is None, " "but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/grad` is not None. " "The execution of operator `{}` cannot be guaranteed.".format(prim_name)) - out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad) - return ((out_var, None), (out_m, None), (out_v, None), (out_vhat, None)) + out_var = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad) + return (out_var, None) if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]): ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, " @@ -1211,8 +1211,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size): lr = _bdim_at_front(lr, lr_dim, axis_size) grad = _bdim_at_front(grad, grad_dim, axis_size) - out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad) - return ((out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)) + out_var = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad) + return (out_var, 0) return vmap_rule diff --git a/tests/st/ops/cpu/test_apply_adam_with_amsgrad_op.py b/tests/st/ops/cpu/test_apply_adam_with_amsgrad_op.py index e0a62302e33..e925f546103 100644 --- a/tests/st/ops/cpu/test_apply_adam_with_amsgrad_op.py +++ b/tests/st/ops/cpu/test_apply_adam_with_amsgrad_op.py @@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999, v = v * beta2 + grad * grad * (1 - beta2) vhat = np.maximum(vhat, v) var = var - new_lr * m / (np.sqrt(vhat) + eps) - return var, m, v, vhat + return var @pytest.mark.level0 @@ -72,14 +72,10 @@ def test_apply_adam_with_amsgrad_op(data_type): grad = Tensor(grad_np) output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad) - ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy() - np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, - amsgrad.v_np, amsgrad.vhat_np, grad_np) + ms_var = output.asnumpy() + np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, amsgrad.v_np, amsgrad.vhat_np, grad_np) np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error) - np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error) - np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error) - np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error) class AmsgradNetVmap(nn.Cell): @@ -121,15 +117,11 @@ def test_apply_adam_witm_amsgrad_op_vmap(): vmap_amsgrad = AmsgradNetVmap(cal_amsgrad) _ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad) - ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy() - ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy() - np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, - vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) + ms_var = vmap_amsgrad.var.asnumpy() + np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, + vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error) - 
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error) - np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error) - np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error) class AmsgradNetVmap2(nn.Cell): @@ -173,12 +165,8 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2(): vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad) _ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad) - ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy() - ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy() - np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, - vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) + ms_var = vmap_amsgrad.var.asnumpy() + np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, + vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error) - np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error) - np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error) - np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error) diff --git a/tests/st/ops/gpu/test_apply_adam_with_amsgrad_op.py b/tests/st/ops/gpu/test_apply_adam_with_amsgrad_op.py index 70682a9d44b..ec2eecc90b2 100644 --- a/tests/st/ops/gpu/test_apply_adam_with_amsgrad_op.py +++ b/tests/st/ops/gpu/test_apply_adam_with_amsgrad_op.py @@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999, v = v * beta2 + grad * grad * (1 - beta2) vhat = np.maximum(vhat, v) var = var - new_lr * m / (np.sqrt(vhat) + eps) - return var, m, v, vhat + return var @pytest.mark.level0 @@ -71,14 +71,11 @@ def test_apply_adam_with_amsgrad_op(data_type): grad = Tensor(grad_np) output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad) - ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy() - np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, - amsgrad.v_np, amsgrad.vhat_np, grad_np) + ms_var = output.asnumpy() + np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, + amsgrad.v_np, amsgrad.vhat_np, grad_np) np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error) - np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error) - np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error) - np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error) class AmsgradNetVmap(nn.Cell): @@ -119,15 +116,11 @@ def test_apply_adam_witm_amsgrad_op_vmap(): vmap_amsgrad = AmsgradNetVmap(cal_amsgrad) _ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad) - ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy() - ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy() - np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, - vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) + ms_var = vmap_amsgrad.var.asnumpy() + np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np, + vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np) np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error) - np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error) - np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error) - np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error) class AmsgradNetVmap2(nn.Cell): @@ -167,15 +160,9 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2(): 
     grad_np = np.random.randn(*shape).astype(np.float32)
     grad = Tensor(grad_np)
-
     vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad)
     _ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
 
-    ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
-    ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
-    np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
-                                                                vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
-
+    ms_var = vmap_amsgrad.var.asnumpy()
+    np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
+                                           vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
     np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
-    np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
-    np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
-    np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
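
For reference, the update rule implemented by the CPU and CUDA kernels in this patch (and checked by the tests' numpy_apply_adam_with_amsgrad) is, in standalone NumPy form, the following sketch. The function name amsgrad_reference and the sample shapes are illustrative only; the math mirrors the kernel code above, including the per-batch bias-corrected step size.

import numpy as np

def amsgrad_reference(var, m, v, vhat, grad, lr=0.01,
                      beta1=0.9, beta2=0.999, eps=1e-8,
                      beta1_power=0.9, beta2_power=0.999):
    # Bias-corrected step size, computed once per batch in the kernels.
    new_lr = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    # First and second moment updates, written in the kernels' incremental form.
    m = m + (grad - m) * (1.0 - beta1)
    v = v + (grad * grad - v) * (1.0 - beta2)
    # AMSGrad keeps the running maximum of the second moment.
    vhat = np.maximum(vhat, v)
    var = var - new_lr * m / (np.sqrt(vhat) + eps)
    # After this patch the op emits only the updated var; m, v and vhat are
    # still written back through the op's input buffers.
    return var

# One update step on a small tensor.
rng = np.random.default_rng(0)
shape = (2, 3)
var = rng.standard_normal(shape).astype(np.float32)
m = np.zeros(shape, np.float32)
v = np.zeros(shape, np.float32)
vhat = np.zeros(shape, np.float32)
grad = rng.standard_normal(shape).astype(np.float32)
print(amsgrad_reference(var, m, v, vhat, grad))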
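
A minimal usage sketch of the primitive after this change, assuming it is exposed on the Python side as ops.ApplyAdamWithAmsgrad with the attributes and input order declared in apply_adam_with_amsgrad.h above; the AmsgradNet harness below is hypothetical and only illustrates the single-output behaviour.

import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter, nn, ops

class AmsgradNet(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        # Attribute defaults follow ApplyAdamWithAmsgrad::Init in this patch.
        self.apply_amsgrad = ops.ApplyAdamWithAmsgrad(beta1=0.9, beta2=0.999, epsilon=1e-8)
        self.var = Parameter(Tensor(np.random.rand(*shape).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.zeros(shape, np.float32)), name="m")
        self.v = Parameter(Tensor(np.zeros(shape, np.float32)), name="v")
        self.vhat = Parameter(Tensor(np.zeros(shape, np.float32)), name="vhat")

    def construct(self, beta1_power, beta2_power, lr, grad):
        # Single output after this patch: the updated var. The kernels still
        # update m, v and vhat in place through the Parameter buffers.
        return self.apply_amsgrad(self.var, self.m, self.v, self.vhat,
                                  beta1_power, beta2_power, lr, grad)

net = AmsgradNet((2, 2))
out_var = net(Tensor(0.9, ms.float32), Tensor(0.999, ms.float32),
              Tensor(0.01, ms.float32), Tensor(np.ones((2, 2), np.float32)))
print(out_var.shape)  # (2, 2): same shape as var, matching the new infer shape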