fix bugs for applyadamwithamsgrad
This commit is contained in:
parent
ed1e4cc7b9
commit
815a06f410
|
@ -29,7 +29,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kApplyAdamWithAmsgradInputsNum = 8;
|
||||
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4;
|
||||
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1;
|
||||
constexpr size_t kScalarIndex = 0;
|
||||
constexpr size_t kIndexVar = 0;
|
||||
constexpr size_t kIndexM = 1;
|
||||
|
@ -65,9 +65,9 @@ bool ApplyAdamWithAmsgradCpuKernelMod::Init(const BaseOperatorPtr &base_operator
|
|||
return false;
|
||||
}
|
||||
|
||||
beta1_[0] = kernel_ptr->get_beta1();
|
||||
beta2_[0] = kernel_ptr->get_beta2();
|
||||
epsilon_[0] = kernel_ptr->get_epsilon();
|
||||
beta1_ = kernel_ptr->get_beta1();
|
||||
beta2_ = kernel_ptr->get_beta2();
|
||||
epsilon_ = kernel_ptr->get_epsilon();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -139,15 +139,14 @@ void ApplyAdamWithAmsgradCpuKernelMod::LaunchApplyAdamWithAmsgrad(const std::vec
|
|||
T *lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
|
||||
T *gradient = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
|
||||
|
||||
T beta1 = static_cast<T>(beta1_[0]);
|
||||
T beta2 = static_cast<T>(beta2_[0]);
|
||||
T epsilon = static_cast<T>(epsilon_[0]);
|
||||
T beta1 = static_cast<T>(beta1_);
|
||||
T beta2 = static_cast<T>(beta2_);
|
||||
T epsilon = static_cast<T>(epsilon_);
|
||||
|
||||
constexpr float ONE = 1.0;
|
||||
T ONE = static_cast<T>(1.0);
|
||||
for (int64_t b = 0; b < batch_size_; b++) {
|
||||
// multithreading
|
||||
T new_lr = static_cast<T>(static_cast<float>(lr[b]) * std::sqrt(ONE - static_cast<float>(beta2_power[b])) /
|
||||
(ONE - static_cast<float>(beta1_power[b])));
|
||||
T new_lr = lr[b] * static_cast<T>(std::sqrt(static_cast<double>(ONE - beta2_power[b]))) / (ONE - beta1_power[b]);
|
||||
auto task = [this, &var, &m, &v, &vhat, &gradient, new_lr, beta1, beta2, epsilon](size_t start, size_t end) {
|
||||
T one = static_cast<T>(1.0);
|
||||
for (size_t i = start; i < end; i++) {
|
||||
|
@ -215,13 +214,7 @@ std::vector<KernelAttr> ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() {
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0)
|
||||
.AddOutInRef(1, 1)
|
||||
.AddOutInRef(2, 2)
|
||||
.AddOutInRef(3, 3),
|
||||
.AddOutInRef(0, 0),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@ -232,13 +225,7 @@ std::vector<KernelAttr> ApplyAdamWithAmsgradCpuKernelMod::GetOpSupport() {
|
|||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutInRef(0, 0)
|
||||
.AddOutInRef(1, 1)
|
||||
.AddOutInRef(2, 2)
|
||||
.AddOutInRef(3, 3)};
|
||||
.AddOutInRef(0, 0)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,9 +46,9 @@ class ApplyAdamWithAmsgradCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
void LaunchApplyAdamWithAmsgrad(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
std::vector<float> beta1_ = {0.9};
|
||||
std::vector<float> beta2_ = {0.999};
|
||||
std::vector<float> epsilon_ = {1e-8};
|
||||
float beta1_{0.9};
|
||||
float beta2_{0.999};
|
||||
float epsilon_{1e-8};
|
||||
|
||||
int64_t batch_rank_{0};
|
||||
int64_t batch_size_{1};
|
||||
|
|
|
@ -25,7 +25,7 @@ __device__ __forceinline__ T sqrtFunc(T x) {
|
|||
|
||||
template <>
|
||||
__device__ __forceinline__ half sqrtFunc(half x) {
|
||||
return sqrt(__half2float(x));
|
||||
return hsqrt(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -35,51 +35,55 @@ __device__ __forceinline__ T maxFunc(T x, T y) {
|
|||
|
||||
template <>
|
||||
__device__ __forceinline__ half maxFunc(half x, half y) {
|
||||
return __half2float(x) > __half2float(y)? __half2float(x) : __half2float(y);
|
||||
return x > y? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalApplyAdamWithAmsgradKernel(const size_t input_elements, const int64_t batch_size, T *var, T *m,
|
||||
__global__ void CalApplyAdamWithAmsgradKernel(const size_t size, const int64_t batch_size, T *var, T *m,
|
||||
T *v, T *vhat, T *beta1_power, T *beta2_power, const T *lr,
|
||||
const T *grad, const float beta1, const float beta2,
|
||||
const float epsilon) {
|
||||
auto all_elements = input_elements * batch_size;
|
||||
const T *grad, const T beta1, const T beta2,
|
||||
const T epsilon, T *output_var) {
|
||||
auto all_elements = size * batch_size;
|
||||
const T one = static_cast<T>(1.0);
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < all_elements; pos += gridDim.x * blockDim.x) {
|
||||
auto batch = pos / input_elements;
|
||||
auto batch = pos / size;
|
||||
auto new_learning_rate = lr[batch] * sqrtFunc(one - beta2_power[batch]) / (one - beta1_power[batch]);
|
||||
m[pos] += (grad[pos] - m[pos]) * (one - static_cast<T>(beta1));
|
||||
v[pos] += (grad[pos] * grad[pos] - v[pos]) * (one - static_cast<T>(beta2));
|
||||
vhat[pos] = maxFunc(vhat[pos], v[pos]);
|
||||
var[pos] -= new_learning_rate * m[pos] / (sqrtFunc(vhat[pos]) + static_cast<T>(epsilon));
|
||||
output_var[pos] = var[pos];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalApplyAdamWithAmsgrad(const size_t input_elements, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
|
||||
T *beta1_power, T *beta2_power, const T *lr, const T *grad, const float beta1,
|
||||
const float beta2, const float epsilon, const uint32_t &device_id,
|
||||
void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
|
||||
T *beta1_power, T *beta2_power, const T *lr, const T *grad, const T beta1,
|
||||
const T beta2, const T epsilon, T *output_var, const uint32_t &device_id,
|
||||
cudaStream_t stream_ptr) {
|
||||
CalApplyAdamWithAmsgradKernel<<<CUDA_BLOCKS(device_id, input_elements * batch_size), CUDA_THREADS(device_id), 0,
|
||||
stream_ptr>>>(input_elements, batch_size, var, m, v, vhat, beta1_power, beta2_power,
|
||||
lr, grad, beta1, beta2, epsilon);
|
||||
CalApplyAdamWithAmsgradKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
|
||||
stream_ptr>>>(size, batch_size, var, m, v, vhat, beta1_power, beta2_power,
|
||||
lr, grad, beta1, beta2, epsilon, output_var);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<double>(const size_t size, const int64_t batch_size, double *var,
|
||||
double *m, double *v, double *vhat, double *beta1_power,
|
||||
double *beta2_power, const double *lr,
|
||||
const double *grad, const float beta1, const float beta2,
|
||||
const float epsilon, const uint32_t &device_id,
|
||||
const double *grad, const double beta1,
|
||||
const double beta2, const double epsilon,
|
||||
double *output_var, const uint32_t &device_id,
|
||||
cudaStream_t stream_ptr);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<float>(const size_t size, const int64_t batch_size, float *var,
|
||||
float *m, float *v, float *vhat, float *beta1_power,
|
||||
float *beta2_power, const float *lr, const float *grad,
|
||||
const float beta1, const float beta2, const float epsilon,
|
||||
const uint32_t &device_id, cudaStream_t stream_ptr);
|
||||
float *output_var, const uint32_t &device_id,
|
||||
cudaStream_t stream_ptr);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad<half>(const size_t size, const int64_t batch_size, half *var,
|
||||
half *m, half *v, half *vhat, half *beta1_power,
|
||||
half *beta2_power, const half *lr, const half *grad,
|
||||
const float beta1, const float beta2, const float epsilon,
|
||||
const uint32_t &device_id, cudaStream_t stream_ptr);
|
||||
const half beta1, const half beta2, const half epsilon,
|
||||
half *output_var, const uint32_t &device_id,
|
||||
cudaStream_t stream_ptr);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalApplyAdamWithAmsgrad(const size_t size, const int64_t batch_size, T *var, T *m, T *v, T *vhat,
|
||||
T *beta1_power, T *beta2_power, const T *lr, const T *grad,
|
||||
const float beta1, const float beta2, const float epsilon,
|
||||
const T beta1, const T beta2, const T epsilon, T *output_var,
|
||||
const uint32_t &device_id, cudaStream_t stream_ptr);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_APPLY_ADAM_WITH_AMSGRAD_IMPL_CUH_
|
||||
|
|
|
@ -20,14 +20,12 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "include/curand.h"
|
||||
#include "mindspore/core/ops/apply_adam_with_amsgrad.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kApplyAdamWithAmsgradInputsNum = 8;
|
||||
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 4;
|
||||
constexpr size_t kScalarIndex = 0;
|
||||
constexpr size_t kApplyAdamWithAmsgradOutputsNum = 1;
|
||||
constexpr size_t kIndexVar = 0;
|
||||
constexpr size_t kIndexM = 1;
|
||||
constexpr size_t kIndexV = 2;
|
||||
|
@ -41,6 +39,13 @@ constexpr size_t kIndexGrad = 7;
|
|||
bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum
|
||||
<< " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and "
|
||||
<< outputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
kernel_name_ = base_operator->name();
|
||||
batch_rank_ = base_operator->get_batch_rank();
|
||||
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAdamWithAmsgrad>(base_operator);
|
||||
|
@ -48,12 +53,6 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator
|
|||
beta1_ = kernel_ptr_->get_beta1();
|
||||
beta2_ = kernel_ptr_->get_beta2();
|
||||
epsilon_ = kernel_ptr_->get_epsilon();
|
||||
if (inputs.size() != kApplyAdamWithAmsgradInputsNum || outputs.size() != kApplyAdamWithAmsgradOutputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kApplyAdamWithAmsgradInputsNum
|
||||
<< " and " << kApplyAdamWithAmsgradOutputsNum << ", but got " << inputs.size() << " and "
|
||||
<< outputs.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
|
@ -63,7 +62,7 @@ bool ApplyAdamWithAmsgradGpuKernelMod::Init(const BaseOperatorPtr &base_operator
|
|||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndexVar).first);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -75,6 +74,7 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato
|
|||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
input_elements_ = 0;
|
||||
|
||||
std::vector<int64_t> var_shape = inputs[kIndexVar]->GetShapeVector();
|
||||
std::vector<int64_t> m_shape = inputs[kIndexM]->GetShapeVector();
|
||||
|
@ -150,24 +150,31 @@ int ApplyAdamWithAmsgradGpuKernelMod::Resize(const BaseOperatorPtr &base_operato
|
|||
bool ApplyAdamWithAmsgradGpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs, void *stream_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(stream_ptr);
|
||||
kernel_func_(this, inputs, outputs, stream_ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &, void *stream_ptr) {
|
||||
auto var = reinterpret_cast<T *>(inputs[kIndexVar]->addr);
|
||||
auto m = reinterpret_cast<T *>(inputs[kIndexM]->addr);
|
||||
auto v = reinterpret_cast<T *>(inputs[kIndexV]->addr);
|
||||
auto vhat = reinterpret_cast<T *>(inputs[kIndexVhat]->addr);
|
||||
auto beta1_power = reinterpret_cast<T *>(inputs[kIndexBeta1Power]->addr);
|
||||
auto beta2_power = reinterpret_cast<T *>(inputs[kIndexBeta2Power]->addr);
|
||||
auto lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
|
||||
auto grad = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto *var = reinterpret_cast<T *>(inputs[kIndexVar]->addr);
|
||||
auto *m = reinterpret_cast<T *>(inputs[kIndexM]->addr);
|
||||
auto *v = reinterpret_cast<T *>(inputs[kIndexV]->addr);
|
||||
auto *vhat = reinterpret_cast<T *>(inputs[kIndexVhat]->addr);
|
||||
auto *beta1_power = reinterpret_cast<T *>(inputs[kIndexBeta1Power]->addr);
|
||||
auto *beta2_power = reinterpret_cast<T *>(inputs[kIndexBeta2Power]->addr);
|
||||
auto *lr = reinterpret_cast<T *>(inputs[kIndexLr]->addr);
|
||||
auto *grad = reinterpret_cast<T *>(inputs[kIndexGrad]->addr);
|
||||
|
||||
CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1_,
|
||||
beta2_, epsilon_, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
T beta1 = static_cast<T>(beta1_);
|
||||
T beta2 = static_cast<T>(beta2_);
|
||||
T epsilon = static_cast<T>(epsilon_);
|
||||
|
||||
auto *output_var = reinterpret_cast<T *>(outputs[kIndexVar]->addr);
|
||||
|
||||
CalApplyAdamWithAmsgrad(input_elements_, batch_size_, var, m, v, vhat, beta1_power, beta2_power, lr, grad, beta1,
|
||||
beta2, epsilon, output_var, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -183,13 +190,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
|
|||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutInRef(0, 0)
|
||||
.AddOutInRef(1, 1)
|
||||
.AddOutInRef(2, 2)
|
||||
.AddOutInRef(3, 3),
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -201,13 +202,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0)
|
||||
.AddOutInRef(1, 1)
|
||||
.AddOutInRef(2, 2)
|
||||
.AddOutInRef(3, 3),
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
|
@ -219,13 +214,7 @@ std::vector<std::pair<KernelAttr, ApplyAdamWithAmsgradGpuKernelMod::KernelFunc>>
|
|||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutInRef(0, 0)
|
||||
.AddOutInRef(1, 1)
|
||||
.AddOutInRef(2, 2)
|
||||
.AddOutInRef(3, 3),
|
||||
.AddOutInRef(0, 0),
|
||||
&ApplyAdamWithAmsgradGpuKernelMod::LaunchKernel<half>}};
|
||||
|
||||
std::vector<KernelAttr> ApplyAdamWithAmsgradGpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -30,7 +30,6 @@
|
|||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/apply_adam_with_amsgrad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -60,14 +59,14 @@ class ApplyAdamWithAmsgradGpuKernelMod : public NativeGpuKernelMod {
|
|||
KernelFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, KernelFunc>> func_list_;
|
||||
|
||||
int unit_size_;
|
||||
size_t input_elements_;
|
||||
size_t unit_size_{0};
|
||||
size_t input_elements_{0};
|
||||
int64_t batch_rank_;
|
||||
int64_t batch_size_;
|
||||
|
||||
float beta1_ = 0.9;
|
||||
float beta2_ = 0.999;
|
||||
float epsilon_ = 1e-8;
|
||||
float beta1_{0.9};
|
||||
float beta2_{0.999};
|
||||
float epsilon_{1e-8};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,21 +30,20 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
abstract::ShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto prim_name = primitive->name();
|
||||
auto var_shape = input_args[0]->BuildShape();
|
||||
auto m_shape = input_args[1]->BuildShape();
|
||||
auto v_shape = input_args[2]->BuildShape();
|
||||
auto vhat_shape = input_args[3]->BuildShape();
|
||||
auto grad_shape = input_args[7]->BuildShape();
|
||||
auto beta1_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
|
||||
auto beta2_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
|
||||
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->BuildShape())[kShape];
|
||||
auto grad_shape = input_args[7]->BuildShape();
|
||||
|
||||
int64_t batch_rank = 0;
|
||||
if (primitive->HasAttr(kBatchRank)) {
|
||||
|
@ -60,8 +59,7 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi
|
|||
|
||||
if (var_shape->IsDynamic() || m_shape->IsDynamic() || v_shape->IsDynamic() || vhat_shape->IsDynamic() ||
|
||||
grad_shape->IsDynamic()) {
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
|
||||
return var_shape->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
// shape of var, m, v, vhat must be the same
|
||||
|
@ -78,14 +76,12 @@ abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primi
|
|||
<< ".";
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
|
||||
auto shape_ptr = var_shape->cast<abstract::ShapePtr>();
|
||||
return shape_ptr;
|
||||
}
|
||||
|
||||
TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
TypePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
// get all input_args' shape
|
||||
auto var_type = input_args[0]->BuildType();
|
||||
auto m_type = input_args[1]->BuildType();
|
||||
auto v_type = input_args[2]->BuildType();
|
||||
|
@ -102,15 +98,22 @@ TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vect
|
|||
(void)args.insert(std::make_pair("v_type", v_type));
|
||||
(void)args.insert(std::make_pair("vhat_type", vhat_type));
|
||||
(void)args.insert(std::make_pair("grad_type", grad_type));
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
// beta1_power, beta2_power, lr type valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("beta1_power_type", beta1_power_type, valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("beta2_power_type", beta2_power_type, valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("lr_type", lr_type, valid_types, prim_name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type, vhat_type});
|
||||
return var_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ApplyAdamWithAmsgrad::Init(const float beta1, const float beta2, const float epsilon, const bool use_locking) {
|
||||
this->set_beta1(beta1);
|
||||
this->set_beta2(beta2);
|
||||
this->set_epsilon(epsilon);
|
||||
this->set_use_locking(use_locking);
|
||||
}
|
||||
|
||||
void ApplyAdamWithAmsgrad::set_beta1(const float beta1) { (void)this->AddAttr(kBeta1, api::MakeValue(beta1)); }
|
||||
|
||||
void ApplyAdamWithAmsgrad::set_beta2(const float beta2) { (void)this->AddAttr(kBeta2, api::MakeValue(beta2)); }
|
||||
|
@ -146,7 +149,7 @@ AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, c
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 8;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ApplyAdamWithAmsgradInferType(primitive, input_args);
|
||||
auto infer_shape = ApplyAdamWithAmsgradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
|
|
|
@ -29,9 +29,12 @@ class MIND_API ApplyAdamWithAmsgrad : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(ApplyAdamWithAmsgrad);
|
||||
ApplyAdamWithAmsgrad() : BaseOperator(kNameApplyAdamWithAmsgrad) {
|
||||
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
|
||||
InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var"});
|
||||
}
|
||||
|
||||
void Init(const float beta1 = 0.9, const float beta2 = 0.999, const float epsilon = 1e-8,
|
||||
const bool use_locking = false);
|
||||
|
||||
void set_beta1(const float beta1);
|
||||
|
||||
void set_beta2(const float beta2);
|
||||
|
|
|
@ -507,7 +507,7 @@ class Adam(Optimizer):
|
|||
self.use_amsgrad = use_amsgrad
|
||||
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
|
||||
self.vhat = self._parameters.clone(prefix="vhat", init=-100000)
|
||||
self.vhat = self._parameters.clone(prefix="vhat", init='zeros')
|
||||
|
||||
self._is_device = True
|
||||
if use_amsgrad:
|
||||
|
|
|
@ -1198,8 +1198,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
|
|||
ValueError("The source axis of `var` is None, "
|
||||
"but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/grad` is not None. "
|
||||
"The execution of operator `{}` cannot be guaranteed.".format(prim_name))
|
||||
out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
|
||||
return ((out_var, None), (out_m, None), (out_v, None), (out_vhat, None))
|
||||
out_var = prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
|
||||
return (out_var, None)
|
||||
|
||||
if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
|
||||
ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
|
||||
|
@ -1211,8 +1211,8 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
|
|||
lr = _bdim_at_front(lr, lr_dim, axis_size)
|
||||
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
||||
|
||||
out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
|
||||
return ((out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0))
|
||||
out_var = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, grad, u_monad)
|
||||
return (out_var, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999,
|
|||
v = v * beta2 + grad * grad * (1 - beta2)
|
||||
vhat = np.maximum(vhat, v)
|
||||
var = var - new_lr * m / (np.sqrt(vhat) + eps)
|
||||
return var, m, v, vhat
|
||||
return var
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -72,14 +72,10 @@ def test_apply_adam_with_amsgrad_op(data_type):
|
|||
grad = Tensor(grad_np)
|
||||
|
||||
output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
|
||||
amsgrad.v_np, amsgrad.vhat_np, grad_np)
|
||||
ms_var = output.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np, amsgrad.v_np, amsgrad.vhat_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
||||
|
||||
class AmsgradNetVmap(nn.Cell):
|
||||
|
@ -121,15 +117,11 @@ def test_apply_adam_witm_amsgrad_op_vmap():
|
|||
|
||||
vmap_amsgrad = AmsgradNetVmap(cal_amsgrad)
|
||||
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
|
||||
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
ms_var = vmap_amsgrad.var.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
||||
|
||||
class AmsgradNetVmap2(nn.Cell):
|
||||
|
@ -173,12 +165,8 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2():
|
|||
|
||||
vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad)
|
||||
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
|
||||
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
ms_var = vmap_amsgrad.var.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
|
|
@ -49,7 +49,7 @@ def numpy_apply_adam_with_amsgrad(var, m, v, vhat, grad, beta1=0.9, beta2=0.999,
|
|||
v = v * beta2 + grad * grad * (1 - beta2)
|
||||
vhat = np.maximum(vhat, v)
|
||||
var = var - new_lr * m / (np.sqrt(vhat) + eps)
|
||||
return var, m, v, vhat
|
||||
return var
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -71,14 +71,11 @@ def test_apply_adam_with_amsgrad_op(data_type):
|
|||
grad = Tensor(grad_np)
|
||||
|
||||
output = amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m, ms_v, ms_vhat = output[0].asnumpy(), output[1].asnumpy(), output[2].asnumpy(), output[3].asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
|
||||
amsgrad.v_np, amsgrad.vhat_np, grad_np)
|
||||
ms_var = output.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(amsgrad.var_np, amsgrad.m_np,
|
||||
amsgrad.v_np, amsgrad.vhat_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
||||
|
||||
class AmsgradNetVmap(nn.Cell):
|
||||
|
@ -119,15 +116,11 @@ def test_apply_adam_witm_amsgrad_op_vmap():
|
|||
|
||||
vmap_amsgrad = AmsgradNetVmap(cal_amsgrad)
|
||||
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
|
||||
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
ms_var = vmap_amsgrad.var.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
||||
|
||||
class AmsgradNetVmap2(nn.Cell):
|
||||
|
@ -167,15 +160,9 @@ def test_apply_adam_with_amsgrad_grad_op_vmap2():
|
|||
grad_np = np.random.randn(*shape).astype(np.float32)
|
||||
grad = Tensor(grad_np)
|
||||
|
||||
|
||||
vmap_amsgrad = AmsgradNetVmap2(cal_amsgrad)
|
||||
_ = vmap_amsgrad(Tensor(0.9), Tensor(0.999), Tensor(0.01), grad)
|
||||
ms_var, ms_m = vmap_amsgrad.var.asnumpy(), vmap_amsgrad.m.asnumpy()
|
||||
ms_v, ms_vhat = vmap_amsgrad.v.asnumpy(), vmap_amsgrad.vhat.asnumpy()
|
||||
np_var, np_m, np_v, np_vhat = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
|
||||
ms_var = vmap_amsgrad.var.asnumpy()
|
||||
np_var = numpy_apply_adam_with_amsgrad(vmap_amsgrad.var_np, vmap_amsgrad.m_np,
|
||||
vmap_amsgrad.v_np, vmap_amsgrad.vhat_np, grad_np)
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_m, np_m, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_v, np_v, rtol=error, atol=error)
|
||||
np.testing.assert_allclose(ms_vhat, np_vhat, rtol=error, atol=error)
|
||||
|
|
Loading…
Reference in New Issue