!40296 SparseApplyRMSProp CPU version

Merge pull request !40296 from ivanshan_8170/sparseApplyRMSProp
i-robot 2022-09-16 07:23:33 +00:00 committed by Gitee
commit 45e91de765
7 changed files with 677 additions and 5 deletions

@@ -0,0 +1,268 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_apply_r_m_s_prop_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include "kernel/common_utils.h"
#include "mindspore/core/ops/sparse_apply_r_m_s_prop.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseApplyRMSPropOutputsNum = 3;
constexpr size_t kSparseApplyRMSPropInputsNum = 6;
constexpr size_t kIndicesDim = 1;
constexpr size_t kSparseApplyRMSPropWorkspaceSize = 4;
constexpr char kKernelName[] = "SparseApplyRMSProp";
using KernelRunFunc = SparseApplyRMSPropCpuKernelMod::KernelRunFunc;
#define ADD_INPUT_ATTR(var_type, indices_type) \
.AddInputAttr(var_type) \
.AddInputAttr(var_type) \
.AddInputAttr(var_type) \
.AddInputAttr(var_type) \
.AddInputAttr(var_type) \
.AddInputAttr(indices_type)
#define ADD_OI_REF_SAME_PLACE(ind1, ind2, ind3) .AddOutInRef(ind1, ind1).AddOutInRef(ind2, ind2).AddOutInRef(ind3, ind3)
#define CPU_FUNLIST_KERNEL_REGISTER(var_type, var_fun_type, indices_type, indices_fun_type) \
{ \
KernelAttr() ADD_INPUT_ATTR(var_type, indices_type) \
.AddOutputAttr(var_type) \
.AddOutputAttr(var_type) \
.AddOutputAttr(var_type) ADD_OI_REF_SAME_PLACE(0, 1, 2), \
&SparseApplyRMSPropCpuKernelMod::LaunchKernel<indices_fun_type, var_fun_type> \
}
} // namespace
bool SparseApplyRMSPropCpuKernelMod::ResizedInputSize(const std::vector<KernelTensorPtr> &inputs) {
var_shape_ = inputs.at(kIndex0)->GetShapeVector();
if (var_shape_.empty()) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the dimension of 'var' must be at least 1, but got scalar or None.";
return false;
}
var_first_dim_size_ = var_shape_[kDim0];
auto ms_shape = inputs.at(kIndex1)->GetShapeVector();
if (!IsSameShape(var_shape_, ms_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of 'ms' must be the same as the shape of 'var', "
"but got the shape of 'ms': "
<< Vector2Str(ms_shape) << " and the shape of 'var': " << Vector2Str(var_shape_);
return false;
}
auto mom_shape = inputs.at(kIndex2)->GetShapeVector();
if (!IsSameShape(var_shape_, mom_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of 'mom' must be the same as the shape of 'var', "
"but got the shape of 'mom': "
<< Vector2Str(mom_shape) << " and the shape of 'var': " << Vector2Str(var_shape_);
return false;
}
// scalar
auto lr_shape = inputs.at(kIndex3)->GetShapeVector();
if (!lr_shape.empty()) {
MS_EXCEPTION(ValueError)
<< "For '" << kKernelName
<< "', 'lr' must be a scalar; thus, its dimension must be 0, but got the dimension of 'lr': "
<< Vector2Str(lr_shape);
return false;
}
auto grad_shape = inputs.at(kIndex4)->GetShapeVector();
// 'grad' must have exactly the same shape as 'var'; the kernel iterates over the full var buffer.
if (!IsSameShape(var_shape_, grad_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of 'grad' must be the same as the shape of 'var', "
"but got the shape of 'grad': "
<< Vector2Str(grad_shape) << " and the shape of 'var': " << Vector2Str(var_shape_);
return false;
}
for (size_t i = 1; i < var_shape_.size(); ++i) {
var_outer_dim_size_ *= var_shape_[i];
}
auto indices_shape = inputs.at(kIndex5)->GetShapeVector();
if (indices_shape.size() != kIndicesDim) {
MS_LOG(EXCEPTION) << "For '" << kKernelName
<< "', the 'indices' must be a 1-D Tensor, but got shape: " << Vector2Str(indices_shape);
return false;
}
if (indices_shape[kDim0] != var_shape_[kDim0]) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the indices.shape[0] must be equal to var.shape[0], but got 'var_shape[0]': "
<< var_shape_[kDim0] << " and 'indices_shape[0]': " << indices_shape[kDim0];
return false;
}
indices_size_ = indices_shape[kDim0];
return true;
}
bool SparseApplyRMSPropCpuKernelMod::ResizedOutputSize(const std::vector<KernelTensorPtr> &outputs) {
auto output_var_shape = outputs[kIndex0]->GetShapeVector();
if (!IsSameShape(var_shape_, output_var_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of output 'var' must be the same as the shape of input 'var', but got "
"the shape of output 'var': "
<< Vector2Str(output_var_shape)
<< ", and the shape of input 'var': " << Vector2Str(var_shape_);
return false;
}
auto output_ms_shape = outputs[kIndex1]->GetShapeVector();
if (!IsSameShape(var_shape_, output_ms_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of output 'ms' must be the same as the shape of input 'ms', "
"but got the shape of output 'ms': "
<< Vector2Str(output_ms_shape)
<< " and the shape of input 'ms': " << Vector2Str(var_shape_);
return false;
}
auto output_mom_shape = outputs[kIndex2]->GetShapeVector();
if (!IsSameShape(var_shape_, output_mom_shape)) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the shape of output 'mom' must be the same as the shape of input 'mom', "
"but got the shape of output 'mom': "
<< Vector2Str(output_mom_shape)
<< " and the shape of input 'mom': " << Vector2Str(var_shape_);
return false;
}
}
return true;
}
void SparseApplyRMSPropCpuKernelMod::ResetResource() noexcept {
input_size_list_.clear();
output_size_list_.clear();
indices_size_ = 0;
var_first_dim_size_ = 0;
var_outer_dim_size_ = 1;
}
int SparseApplyRMSPropCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
MS_EXCEPTION_IF_NULL(base_operator);
ResetResource();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyRMSPropInputsNum, kKernelName);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyRMSPropOutputsNum, kKernelName);
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseApplyRMSProp>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "Cast op from BaseOperator to SparseApplyRMSProp failed.";
return KRET_RESIZE_FAILED;
}
if (!ResizedInputSize(inputs)) {
return KRET_RESIZE_FAILED;
}
if (!ResizedOutputSize(outputs)) {
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
bool SparseApplyRMSPropCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseApplyRMSProp>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "Cast op from BaseOperator to SparseApplyRMSProp failed.";
return false;
}
rho_ = kernel_ptr->get_rho();
if (rho_ > 1 || rho_ < 0) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the argument rho should be between 0 and 1, but got the value of rho: " << rho_;
return false;
}
momentum_ = kernel_ptr->get_momentum();
if (momentum_ < 0) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the argument momentum should be no less than 0, but got the value of momentum: "
<< momentum_;
return false;
}
epsilon_ = kernel_ptr->get_epsilon();
if (epsilon_ <= 0) {
MS_EXCEPTION(ValueError) << "For '" << kKernelName
<< "', the argument epsilon should be greater than 0, but got the value of epsilon: "
<< epsilon_;
return false;
}
return MatchKernelFunc(base_operator, inputs, outputs);
}
template <typename I, typename T>
bool SparseApplyRMSPropCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
auto *var = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
auto *ms = reinterpret_cast<T *>(inputs.at(kIndex1)->addr);
auto *mom = reinterpret_cast<T *>(inputs.at(kIndex2)->addr);
auto lr = reinterpret_cast<T *>(inputs.at(kIndex3)->addr)[kDim0];
auto *grad = reinterpret_cast<T *>(inputs.at(kIndex4)->addr);
auto *indices = reinterpret_cast<I *>(inputs.at(kIndex5)->addr);
const auto rho = this->rho_;
const auto momentum = this->momentum_;
const auto epsilon = this->epsilon_;
auto var_first_dim_size = static_cast<size_t>(this->var_first_dim_size_);
auto var_outer_dim_size = this->var_outer_dim_size_;
auto task = [var, ms, mom, grad, indices, &lr, &rho, &momentum, &epsilon, &var_first_dim_size, &var_outer_dim_size](
size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const size_t indices_pos = i / var_outer_dim_size;
const size_t inner_pos = i % var_outer_dim_size;
const auto index = indices[indices_pos];
if (index < 0 || LongToSize(index) >= var_first_dim_size) {
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', each element in 'indices' must be in range [0, "
<< SizeToLong(var_first_dim_size) << "), but got " << index;
}
const size_t cur_pos = LongToSize(index) * var_outer_dim_size + inner_pos;
const float grad_t = static_cast<float>(grad[i]);
float msf = static_cast<float>(ms[cur_pos]);
if (grad_t != 0) {
msf = msf * rho + grad_t * grad_t * (1.0f - rho);
ms[cur_pos] = static_cast<T>(msf);
}
mom[cur_pos] = static_cast<T>(static_cast<float>(mom[cur_pos]) * momentum +
static_cast<float>(lr) * grad_t / std::sqrt(msf + epsilon));
var[cur_pos] -= mom[cur_pos];
}
};
ParallelLaunchAutoSearch(task, var_first_dim_size * var_outer_dim_size, this, &parallel_search_info_);
return true;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SparseApplyRMSPropCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
CPU_FUNLIST_KERNEL_REGISTER(kNumberTypeFloat32, float, kNumberTypeInt32, int),
CPU_FUNLIST_KERNEL_REGISTER(kNumberTypeFloat32, float, kNumberTypeInt64, int64_t),
CPU_FUNLIST_KERNEL_REGISTER(kNumberTypeFloat16, float16, kNumberTypeInt32, int),
CPU_FUNLIST_KERNEL_REGISTER(kNumberTypeFloat16, float16, kNumberTypeInt64, int64_t),
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyRMSProp, SparseApplyRMSPropCpuKernelMod);
} // namespace kernel
} // namespace mindspore
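The LaunchKernel loop above applies, row by row over indices, the usual RMSProp update, with the ms refresh skipped where the gradient is zero. As a cross-check against the tests at the end of this change, here is a minimal NumPy sketch of that per-index rule; the function name and signature are illustrative only, and it assumes grad has the same shape as var and that indices has var.shape[0] elements, as enforced in ResizedInputSize:

import numpy as np

def sparse_apply_rms_prop_reference(var, ms, mom, lr, grad, indices, rho, momentum, epsilon):
    # Mirrors the per-element update in LaunchKernel.
    for row, idx in enumerate(indices):
        g = grad[row]
        # ms is only refreshed where the gradient is non-zero (the grad_t != 0 branch above).
        ms[idx] = np.where(g != 0, rho * ms[idx] + (1.0 - rho) * g * g, ms[idx])
        mom[idx] = momentum * mom[idx] + lr * g / np.sqrt(ms[idx] + epsilon)
        var[idx] -= mom[idx]
    return var, ms, mom

With the float32 inputs of test_sparse_apply_rms_prop below (rho=0.2, momentum=0.01, epsilon=1e-6, lr=0.01), this sketch should reproduce expect_var, expect_ms and expect_mom to within 1e-6.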

@@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_APPLY_R_M_S_PROP_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_APPLY_R_M_S_PROP_H_
#include <map>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class SparseApplyRMSPropCpuKernelMod : public SparseOptimizerCpuKernelMod,
public MatchKernelHelper<SparseApplyRMSPropCpuKernelMod> {
public:
SparseApplyRMSPropCpuKernelMod() { ResetResource(); }
~SparseApplyRMSPropCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool ResizedInputSize(const std::vector<KernelTensorPtr> &inputs);
bool ResizedOutputSize(const std::vector<KernelTensorPtr> &outputs);
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
void ResetResource() noexcept;
private:
template <typename I, typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
float rho_{0.0};
float momentum_{0.0};
float epsilon_{0.0};
ShapeVector var_shape_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SPARSE_APPLY_R_M_S_PROP_H_

@@ -192,6 +192,7 @@ constexpr auto kPreNmsTopn = "pre_nms_topn";
constexpr auto kRankSize = "rank_size";
constexpr auto kRatio = "ratio";
constexpr auto kReduction = "reduction";
constexpr auto kRho = "rho";
constexpr auto kRootRank = "root_rank";
constexpr auto kRoundMode = "round_mode";
constexpr auto kRtol = "rtol";

@@ -15,10 +15,9 @@
*/
#include "ops/sparse_apply_r_m_s_prop.h"
#include <algorithm>
#include <set>
#include <utility>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
@@ -104,6 +103,46 @@ TuplePtr SparseApplyRMSPropInferType(const PrimitivePtr &prim, const std::vector
}
} // namespace
// SparseApplyRMSProp Rho getter method
float SparseApplyRMSProp::get_rho() const {
auto value_ptr = this->GetAttr(kRho);
return GetValue<float>(value_ptr);
}
// SparseApplyRMSProp Rho setter method
void SparseApplyRMSProp::set_rho(const float rho) { (void)this->AddAttr(kRho, api::MakeValue(rho)); }
// SparseApplyRMSProp Momentum getter method
float SparseApplyRMSProp::get_momentum() const {
auto value_ptr = this->GetAttr(kMomentum);
return GetValue<float>(value_ptr);
}
// SparseApplyRMSProp Momentum setter method
void SparseApplyRMSProp::set_momentum(const float momentum) {
(void)this->AddAttr(kMomentum, api::MakeValue(momentum));
}
// SparseApplyRMSProp Epsilon getter method
float SparseApplyRMSProp::get_epsilon() const {
auto value_ptr = this->GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
// SparseApplyRMSProp Epsilon setter method
void SparseApplyRMSProp::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, api::MakeValue(epsilon)); }
// SparseApplyRMSProp Use_Locking getter method
bool SparseApplyRMSProp::get_use_locking() const {
auto value_ptr = this->GetAttr(kUseLocking);
return GetValue<bool>(value_ptr);
}
// SparseApplyRMSProp Use_Locking setter method
void SparseApplyRMSProp::set_use_locking(const bool use_locking) {
(void)this->AddAttr(kUseLocking, api::MakeValue(use_locking));
}
MIND_API_OPERATOR_IMPL(SparseApplyRMSProp, BaseOperator);
AbstractBasePtr SparseApplyRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {

@@ -36,6 +36,32 @@ class MIND_API SparseApplyRMSProp : public BaseOperator {
SparseApplyRMSProp() : BaseOperator(kNameSparseApplyRMSProp) {
InitIOName({"var", "ms", "mom", "lr", "grad", "indices"}, {"var", "ms", "mom"});
}
/// \brief Set rho, the decay rate.
void set_rho(const float rho);
/// \brief Get rho.
///
/// \return rho.
float get_rho() const;
/// \brief Set momentum.
void set_momentum(const float momentum);
/// \brief Get momentum.
///
/// \return momentum.
float get_momentum() const;
/// \brief Set epsilon, a small value (float) added for numerical stability.
void set_epsilon(const float epsilon);
/// \brief Get epsilon.
///
/// \return epsilon.
float get_epsilon() const;
/// \brief Set use_locking; if True, updating of var, ms and mom is protected by a lock. Default: False.
void set_use_locking(const bool use_locking);
/// \brief Get use_locking.
///
/// \return use_locking.
bool get_use_locking() const;
};
abstract::AbstractBasePtr SparseApplyRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -8768,11 +8768,11 @@ class SparseApplyRMSProp(Primitive):
the relatively highest priority data type.
Args:
rho (float): Decay rate. The value should between 0 and 1, otherwise the behavior is undefined.
rho (float): Decay rate. The value should be between 0 and 1, otherwise the behavior is undefined.
momentum (float): Momentum. The value should be greater or equal to 0, otherwise the behavior is undefined.
epsilon (float): A small value added for numerical stability. The value should be greater than 0,
otherwise the behavior is undefined.
use_locking (bool): If `True`, updating of the var, ms, and mom tensors is protected by a lock;
use_locking (bool): If `True`, updates of the var, ms, and mom tensors are protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention. Default: False.
Inputs:
@@ -8811,7 +8811,7 @@
RuntimeError: If the data type of `var`, `ms`, `mom` and `grad` conversion of Parameter is not supported.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> class SparseApplyRMSPropNet(nn.Cell):
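The Args block above parameterizes the update; a minimal call sketch consistent with the signature documented here and with the first CPU test case below (the attribute and input values are borrowed from that test, and this mirrors what the test's SparseApplyRMSPropNet cell does rather than introducing any new API; it is typically run in PyNative mode or wrapped in a Cell as in the tests):

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor, Parameter

op = P.SparseApplyRMSProp(rho=0.2, momentum=0.01, epsilon=1e-6, use_locking=False)
var = Parameter(Tensor(np.array([[0.6, 0.3], [0.1, 0.5]], np.float32)), name="var")
ms = Parameter(Tensor(np.array([[0.2, 0.4], [0.1, 0.3]], np.float32)), name="ms")
mom = Parameter(Tensor(np.array([[0.3, 0.1], [0.3, 0.6]], np.float32)), name="mom")
grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]], np.float32))
indices = Tensor(np.array([0, 1], np.int32))
# Returns the updated (var, ms, mom); the tests also assert that the Parameters are updated in place.
out_var, out_ms, out_mom = op(var, ms, mom, 0.01, grad, indices)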

@@ -0,0 +1,275 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops.operations as P
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class SparseApplyRMSPropNet(nn.Cell):
def __init__(self, rho, momentum, epsilon, use_locking=False):
super(SparseApplyRMSPropNet, self).__init__()
self.sparse_apply_r_m_s_prop = P.SparseApplyRMSProp(rho, momentum, epsilon, use_locking)
self.var = Parameter(Tensor(np.array([[0.6, 0.3], [0.1, 0.5]]).astype(np.float32)), name="var")
self.ms = Parameter(Tensor(np.array([[0.2, 0.4], [0.1, 0.3]]).astype(np.float32)), name="ms")
self.mom = Parameter(Tensor(np.array([[0.3, 0.1], [0.3, 0.6]]).astype(np.float32)), name="mom")
def construct(self, learning_rate, grad, indices):
out = self.sparse_apply_r_m_s_prop(self.var, self.ms, self.mom, learning_rate, grad, indices)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_apply_rms_prop():
"""
Feature: test SparseApplyRMSProp on CPU
Description: params, attrs and inputs taken from the docs example
Expectation: the results match the expected values within 1e-6
"""
rho = 0.2
momentum = 0.01
epsilon = 1e-6
net = SparseApplyRMSPropNet(rho, momentum, epsilon)
learning_rate = 0.01
tol = 1e-6
grad = np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)
indices = np.array([0, 1], dtype=np.int32)
net.var = Parameter(Tensor(np.array([[0.6, 0.3], [0.1, 0.5]]).astype(np.float32)), name="var")
net.ms = Parameter(Tensor(np.array([[0.2, 0.4], [0.1, 0.3]]).astype(np.float32)), name="ms")
net.mom = Parameter(Tensor(np.array([[0.3, 0.1], [0.3, 0.6]]).astype(np.float32)), name="mom")
output_var, output_ms, output_mom = net(learning_rate, Tensor(grad), Tensor(indices))
expect_var = np.array([[0.5880358, 0.28881112], [0.09102397, 0.48342228]])
expect_ms = np.array([[0.112, 0.472], [0.028, 0.572]])
expect_mom = np.array([[0.01196417, 0.01118888], [0.00897604, 0.01657771]])
assert (abs(output_var.asnumpy() - expect_var) <= tol).all()
assert (abs(output_ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(output_mom.asnumpy() - expect_mom) <= tol).all()
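# A quick hand check of the first element, assuming the per-index rule implemented by the CPU kernel:
#   ms'  = rho * ms + (1 - rho) * grad**2              = 0.2 * 0.2  + 0.8 * 0.3 ** 2                    = 0.112
#   mom' = momentum * mom + lr * grad / sqrt(ms' + eps) = 0.01 * 0.3 + 0.01 * 0.3 / sqrt(0.112 + 1e-6)   ~ 0.0119642
#   var' = var - mom'                                   = 0.6 - 0.0119642                                ~ 0.5880358
# which matches expect_ms[0][0], expect_mom[0][0] and expect_var[0][0] above.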
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_apply_rms_prop_fp32():
"""
Feature: test SparseApplyRMSProp on CPU
Description: normal params, attrs and inputs in float32
Expectation: the results match the expected values within 1e-6
"""
var = Tensor(
np.array(
[
[
[[1.7584051, 7.845357, 9.487755, 11.609518], [6.3358746, 9.710918, 10.127965, 10.117655]],
[[12.163624, 5.494794, 3.8711822, 1.3894155], [8.985711, 0.6518214, 7.3151374, 16.33593]],
[[8.341027, 5.162506, 8.352797, 5.554555], [4.9117146, 4.477907, 13.811077, 0.54865116]],
],
[
[[11.817743, 14.965637, 8.13786, 12.019079], [13.102469, 15.835658, 13.591752, 9.972791]],
[[17.454584, 11.351265, 13.24484, 3.8717928], [17.244823, 12.653173, 19.387028, 5.45228]],
[[18.595354, 0.32980376, 12.503356, 5.3955374], [0.47630417, 12.696551, 6.7440767, 12.151557]],
],
]
)
).astype(np.float32)
ms = Tensor(
np.array(
[
[
[[13.247066, 3.0132513, 15.529863, 7.0405197], [15.222864, 17.862719, 14.253433, 8.52769]],
[[4.603761, 7.4978523, 15.64114, 3.4454918], [8.88428, 14.043913, 2.6531525, 1.7218554]],
[[6.9842176, 4.660216, 12.589785, 11.106893], [17.857334, 1.9999982, 2.2025642, 13.055216]],
],
[
[[8.858172, 18.533686, 5.48135, 16.584848], [3.5365322, 2.140122, 11.01436, 1.4174879]],
[[18.309923, 12.984872, 16.118517, 2.7294059], [12.451426, 5.4134645, 16.591896, 4.5551147]],
[[5.5329094, 8.667258, 12.109718, 6.447345], [12.299871, 10.31546, 16.994408, 18.751486]],
],
]
)
).astype(np.float32)
mom = Tensor(
np.array(
[
[
[[1.8185945, 9.377954, 0.10671406, 19.155134], [10.460225, 15.26945, 18.154474, 3.1047785]],
[[14.950758, 2.8664052, 9.1753845, 13.3002205], [5.3172884, 4.909375, 5.1808786, 16.881796]],
[[11.970335, 3.5992355, 8.939086, 10.23226], [2.2149224, 11.196065, 5.0415382, 13.498018]],
],
[
[[19.054583, 8.202999, 5.3966255, 9.038197], [13.197036, 19.272615, 15.766206, 8.0324135]],
[[12.263951, 14.052368, 14.865421, 14.657042], [13.552727, 0.70198125, 2.8945522, 7.790198]],
[[2.3330674, 0.64346105, 19.878948, 14.215902], [18.90649, 4.7782664, 6.36722, 18.578365]],
],
]
)
).astype(np.float32)
rho = 0.2
momentum = 0.01
epsilon = 1e-6
net = SparseApplyRMSPropNet(rho, momentum, epsilon, True)
net.var = Parameter(var, name="var")
net.ms = Parameter(ms, name="ms")
net.mom = Parameter(mom, name="mom")
learning_rate = 0.01
tol = 1e-6
grad = np.array(
[
[
[[4.425984, 17.72997, 3.6272728, 14.553083], [7.809875, 1.0404425, 0.4167797, 1.4313234]],
[[15.876797, 19.840714, 0.19511667, 8.967148], [5.1575384, 9.222021, 6.7389107, 13.391502]],
[[3.3068883, 18.009441, 3.2276564, 8.246849], [12.699854, 18.070751, 7.0316415, 18.188854]],
],
[
[[15.942688, 10.274351, 10.572657, 6.9661407], [13.754183, 16.018494, 6.9371862, 2.9460514]],
[[16.671234, 17.091852, 7.828639, 4.098937], [8.028752, 9.3316345, 15.868357, 1.5713477]],
[[10.281095, 6.8612375, 0.5492036, 10.575689], [11.136571, 6.750351, 10.062054, 14.244425]],
],
]
).astype(np.float32)
indices = np.array([0, 1], dtype=np.int64)
output_var, output_ms, output_mom = net(learning_rate, Tensor(grad), Tensor(indices))
expect_var = np.array(
[
[
[[1.7298788, 7.7404103, 9.476863, 11.406833], [6.220425, 9.553286, 9.94401, 10.07878]],
[[12.002961, 5.454976, 3.7783306, 1.2452924], [8.921798, 0.5917712, 7.252229, 16.155945]],
[[8.210941, 5.1153536, 8.253608, 5.4412737], [4.8785367, 4.354775, 13.749543, 0.4025454]],
],
[
[[11.616066, 14.872664, 8.072782, 11.917966], [12.959345, 15.631763, 13.423217, 9.881508]],
[[17.320856, 11.199622, 13.085356, 3.7142625], [17.098377, 12.635059, 19.346992, 5.365129]],
[[18.560915, 0.31243756, 12.301201, 5.2422776], [0.276195, 12.637892, 6.6694517, 11.95472]],
],
]
).astype(np.float32)
expect_ms = np.array(
[
[
[[18.32088, 252.08414, 13.63166, 170.84189], [51.83989, 4.4385605, 2.989651, 3.3444874]],
[[202.57889, 316.4227, 3.1586843, 65.01689], [23.057018, 70.84532, 36.860966, 143.81024]],
[[10.145252, 260.40402, 10.852169, 56.629795], [132.60051, 261.64166, 39.9957, 267.2786]],
],
[
[[205.10707, 88.15657, 90.521126, 42.138668], [152.04935, 205.70174, 40.70252, 7.2268724]],
[[226.00603, 236.3021, 52.253777, 13.986909], [54.058975, 70.746216, 204.76218, 2.88633]],
[[85.667305, 39.39472, 2.6632433, 90.76563], [101.67854, 38.516884, 84.39482, 166.07321]],
],
]
).astype(np.float32)
expect_mom = np.array(
[
[
[[0.02852633, 0.1049465, 0.01089154, 0.2026855], [0.11544931, 0.15763302, 0.18395518, 0.03887438]],
[[0.16066249, 0.03981787, 0.09285168, 0.14412314], [0.06391378, 0.06005022, 0.06290836, 0.1799849]],
[[0.13008551, 0.04715267, 0.09918867, 0.11328146], [0.03317797, 0.12313244, 0.06153398, 0.14610578]],
],
[
[[0.20167777, 0.09297276, 0.06507868, 0.10111325], [0.14312467, 0.20389485, 0.16853563, 0.09128299]],
[[0.13372889, 0.15164241, 0.15948418, 0.15753041], [0.14644705, 0.01811427, 0.0400349, 0.08715107]],
[[0.03443857, 0.0173662, 0.20215482, 0.15325965], [0.20010917, 0.05865945, 0.07462509, 0.19683704]],
],
]
).astype(np.float32)
assert (abs(output_var.asnumpy() - expect_var) <= tol).all()
assert (abs(output_ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(output_mom.asnumpy() - expect_mom) <= tol).all()
assert (abs(net.var.asnumpy() - expect_var) <= tol).all()
assert (abs(net.ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(net.mom.asnumpy() - expect_mom) <= tol).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_apply_rms_prop_update_fp16():
"""
Feature: test SparseApplyRMSProp on CPU
Description: random params, attrs and inputs in float16; the net's parameters are updated in place
Expectation: the results and updated parameters match the expected values within 1e-3
"""
var = np.array([[[0.2048, 2.107], [3.395, 3.107]], [[1.971, 3.18], [2.648, 1.034]]])
ms = np.array([[[4.93, 3.984], [4.25, 3.662]], [[0.6567, 4.86], [3.867, 2.898]]])
mom = np.array([[[1.537, 1.1], [4.668, 4.03]], [[0.5044, 1.44], [3.336, 3.855]]])
rho = 0.2
momentum = 0.01
epsilon = 1e-6
tol = 1e-3
net = SparseApplyRMSPropNet(rho, momentum, epsilon, True)
net.var = Parameter(Tensor(var, dtype=mindspore.float16), name="var")
net.ms = Parameter(Tensor(ms, dtype=mindspore.float16), name="ms")
net.mom = Parameter(Tensor(mom, dtype=mindspore.float16), name="mom")
learning_rate = Tensor(0.01, dtype=mindspore.float16)
grad = np.array([[[4.105, 1.056], [4.773, 1.278]], [[0.5186, 1.605], [2.549, 1.029]]]).astype(np.float16)
indices = np.array([0, 1], dtype=np.int32)
output_var, output_ms, output_mom = net(learning_rate, Tensor(grad, dtype=mindspore.float16), Tensor(indices))
expect_var = np.array(
[[[0.1787, 2.08787379], [3.336, 3.05774736]], [[1.95714428, 3.15638097], [2.605, 0.98683219]]]
).astype(np.float16)
expect_ms = np.array(
[[[14.46989893, 1.68834129], [19.07856445, 2.03968226]], [[0.34645917, 3.03402393], [5.97061985, 1.42716165]]]
).astype(np.float16)
expect_mom = np.array(
[[[0.026165, 0.01912621], [0.05761078, 0.04925264]], [[0.01385572, 0.02361903], [0.04379335, 0.04716781]]]
).astype(np.float16)
assert (abs(output_ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(output_var.asnumpy() - expect_var) <= tol).all()
assert (abs(output_mom.asnumpy() - expect_mom) <= tol).all()
assert (abs(net.var.asnumpy() - expect_var) <= tol).all()
assert (abs(net.ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(net.mom.asnumpy() - expect_mom) <= tol).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_apply_rms_prop_grad0():
"""
Feature: test SparseApplyRMSProp on CPU
Description: input grad is zero
Expectation: parameter ms is not updated, but var and mom are
"""
rho = 0.2
momentum = 0.01
epsilon = 1e-6
net = SparseApplyRMSPropNet(rho, momentum, epsilon)
learning_rate = 0.01
tol = 1e-6
grad = np.array([[0, 0], [0, 0]]).astype(np.float32)
indices = np.array([0, 1], dtype=np.int32)
var = np.array([[0.6, 0.3], [0.1, 0.5]]).astype(np.float32)
ms = np.array([[0.2, 0.4], [0.1, 0.3]]).astype(np.float32)
mom = np.array([[0.3, 0.1], [0.3, 0.6]]).astype(np.float32)
net.var = Parameter(Tensor(var, dtype=mindspore.float32), name="var")
net.ms = Parameter(Tensor(ms, dtype=mindspore.float32), name="ms")
net.mom = Parameter(Tensor(mom, dtype=mindspore.float32), name="mom")
output_var, output_ms, output_mom = net(learning_rate, Tensor(grad), Tensor(indices))
expect_var = np.array([[0.597, 0.29900002], [0.097, 0.494]]).astype(np.float32)
expect_ms = np.array([[0.2, 0.4], [0.1, 0.3]]).astype(np.float32)
expect_mom = np.array([[0.003, 0.001], [0.003, 0.006]]).astype(np.float32)
assert (abs(output_ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(output_var.asnumpy() - expect_var) <= tol).all()
assert (abs(output_mom.asnumpy() - expect_mom) <= tol).all()
assert (abs(net.var.asnumpy() - expect_var) <= tol).all()
assert (abs(net.ms.asnumpy() - expect_ms) <= tol).all()
assert (abs(net.mom.asnumpy() - expect_mom) <= tol).all()
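# With grad == 0 the kernel skips the ms refresh (the grad_t != 0 branch), so assuming the same rule as
# in the hand check above, for the first element: ms' = 0.2 (unchanged), mom' = momentum * mom = 0.01 * 0.3 = 0.003,
# and var' = 0.6 - 0.003 = 0.597, matching expect_ms, expect_mom and expect_var.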