!37773 support apply_proximal_gradient_descent cpu&vmap

Merge pull request !37773 from Yanzhi_YI/apply_proximal_gradient_descent
This commit is contained in:
i-robot 2022-07-15 08:55:52 +00:00 committed by Gitee
commit 294cfd8e26
6 changed files with 525 additions and 18 deletions
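
For reference, the element-wise update implemented by the new CPU kernel below (LaunchKernel in apply_proximal_gradient_descent_cpu_kernel.cc) is the standard proximal gradient descent step:

$$\mathrm{prox} = \mathrm{var} - \alpha \cdot \delta$$

$$\mathrm{var} \leftarrow \begin{cases} \dfrac{\operatorname{sign}(\mathrm{prox}) \cdot \max(|\mathrm{prox}| - \alpha\, l_1,\ 0)}{1 + \alpha\, l_2}, & l_1 > 0 \\[4pt] \dfrac{\mathrm{prox}}{1 + \alpha\, l_2}, & \text{otherwise} \end{cases}$$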

View File

@@ -0,0 +1,197 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/apply_proximal_gradient_descent_cpu_kernel.h"
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace {
constexpr size_t kApplyProximalGradientDescentInputsNum = 5;
constexpr size_t kApplyProximalGradientDescentOutputsNum = 1;
constexpr size_t kVarIndex = 0;
constexpr size_t kAlphaIndex = 1;
constexpr size_t kL1Index = 2;
constexpr size_t kL2Index = 3;
constexpr size_t kDeltaIndex = 4;
template <typename T>
int Sgn(T val) {
if (val > T(0)) {
return 1;
}
if (val < T(0)) {
return -1;
}
return 0;
}
template <typename T>
T Abs(T x) {
if (x >= T(0)) {
return x;
}
return -x;
}
template <typename T>
T Max(T x, T y) {
if (x > y) {
return x;
}
return y;
}
} // namespace
namespace mindspore {
namespace kernel {
bool ApplyProximalGradientDescentCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
dtype_ = inputs[0]->GetDtype();
batch_rank_ = base_operator->get_batch_rank();
return true;
}
int ApplyProximalGradientDescentCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != 0) {
return ret;
}
if (input_size_list_.size() != kApplyProximalGradientDescentInputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be equal 5.";
return KRET_RESIZE_FAILED;
}
std::vector<int64_t> var_shape = inputs[kVarIndex]->GetShapeVector();
std::vector<int64_t> alpha_shape = inputs[kAlphaIndex]->GetShapeVector();
std::vector<int64_t> l1_shape = inputs[kL1Index]->GetShapeVector();
std::vector<int64_t> l2_shape = inputs[kL2Index]->GetShapeVector();
std::vector<int64_t> delta_shape = inputs[kDeltaIndex]->GetShapeVector();
if (var_shape.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the dimension of 'var' must be at least 1-D, but got scalar or None.";
return KRET_RESIZE_FAILED;
}
if (!IsSameShape(var_shape, delta_shape)) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the shape of 'delta' must be the same as the shape of 'var', "
"but got the shape of 'delta': "
<< Vector2Str(delta_shape) << " and the shape of 'var': " << Vector2Str(var_shape);
return KRET_RESIZE_FAILED;
}
if (!IsSameShape(alpha_shape, l1_shape)) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the shape of 'alpha' must be the same as the shape of 'l1', "
"but got the shape of 'alpha': "
<< Vector2Str(alpha_shape) << " and the shape of 'l1': " << Vector2Str(l1_shape);
return KRET_RESIZE_FAILED;
}
if (!IsSameShape(alpha_shape, l2_shape)) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the shape of 'alpha' must be the same as the shape of 'l2', "
"but got the shape of 'alpha': "
<< Vector2Str(alpha_shape) << " and the shape of 'l2': " << Vector2Str(l2_shape);
return KRET_RESIZE_FAILED;
}
if (batch_rank_ < 0 || alpha_shape.size() != static_cast<size_t>(batch_rank_)) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the shape size of 'alpha' must be equal to 'batch_rank', "
"but got the shape of 'alpha': "
<< Vector2Str(alpha_shape) << " and 'batch_rank': " << batch_rank_;
return KRET_RESIZE_FAILED;
}
batch_size_ = 1;
if (!alpha_shape.empty()) {
batch_size_ = std::accumulate(alpha_shape.begin(), alpha_shape.end(), batch_size_, std::multiplies<int64_t>());
}
input_elements_ = std::accumulate(var_shape.begin(), var_shape.end(), int64_t(1), std::multiplies<int64_t>());
if (batch_size_ <= 0) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', batch_size_ must be greater than 0, but got batch_size: " << batch_size_;
return KRET_RESIZE_FAILED;
}
input_elements_ = input_elements_ / batch_size_;
if (batch_rank_ > 1) {
if (var_shape.size() < alpha_shape.size()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the shape size of 'var' must be greater than 'alpha_shape', but got the shape of 'var': "
<< Vector2Str(var_shape) << " and 'alpha_shape': " << Vector2Str(alpha_shape);
return KRET_RESIZE_FAILED;
}
std::vector<int64_t> var_batch_shape(var_shape.begin(), var_shape.begin() + batch_rank_);
if (!IsSameShape(alpha_shape, var_batch_shape)) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the batch shape of 'var' must be the same as the shape of 'alpha', "
"but got the batch shape of 'var': "
<< Vector2Str(var_batch_shape) << " and the shape of 'alpha': " << Vector2Str(alpha_shape);
return KRET_RESIZE_FAILED;
}
}
return ret;
}
bool ApplyProximalGradientDescentCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kApplyProximalGradientDescentInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kApplyProximalGradientDescentOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else {
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', input dtype only support float16 and float32, but got ["
<< dtype_ << "].";
}
return true;
}
template <typename T>
void ApplyProximalGradientDescentCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto var_addr = reinterpret_cast<T *>(inputs[kVarIndex]->addr);
auto alpha_addr = reinterpret_cast<T *>(inputs[kAlphaIndex]->addr);
auto l1_addr = reinterpret_cast<T *>(inputs[kL1Index]->addr);
auto l2_addr = reinterpret_cast<T *>(inputs[kL2Index]->addr);
auto delta_addr = reinterpret_cast<T *>(inputs[kDeltaIndex]->addr);
auto task = [this, &var_addr, &alpha_addr, &l1_addr, &l2_addr, &delta_addr](size_t start, size_t end) {
auto cur_input_elements = end - start;
for (int64_t b = 0; b < batch_size_; b++) {
auto offset = b * input_elements_ + start;
auto var_cur = var_addr + offset;
auto delta_cur = delta_addr + offset;
for (size_t pos = 0; pos < cur_input_elements; pos++) {
T prox_var = var_cur[pos] - alpha_addr[b] * delta_cur[pos];
if (l1_addr[b] > T(0)) {
var_cur[pos] = (T)Sgn(prox_var) * Max(Abs(prox_var) - alpha_addr[b] * l1_addr[b], T(0)) /
(T(1) + alpha_addr[b] * l2_addr[b]);
} else {
var_cur[pos] = prox_var / (T(1) + alpha_addr[b] * l2_addr[b]);
}
}
}
};
ParallelLaunchAutoSearch(task, input_elements_, this, &parallel_search_info_, pool_);
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyProximalGradientDescent, ApplyProximalGradientDescentCpuKernelMod);
} // namespace kernel
} // namespace mindspore
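
A minimal NumPy sketch of the same update, for reference only; the helper name is illustrative and not part of this merge request:

import numpy as np

def apply_proximal_gradient_descent_ref(var, alpha, l1, l2, delta):
    """Reference mirror of the element-wise branch in LaunchKernel above."""
    prox = var - alpha * delta
    if l1 > 0:
        # Soft-thresholding branch used when an L1 penalty is present.
        return np.sign(prox) * np.maximum(np.abs(prox) - alpha * l1, 0.0) / (1.0 + alpha * l2)
    return prox / (1.0 + alpha * l2)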

View File

@@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_
#include <map>
#include <algorithm>
#include <vector>
#include <memory>
#include <string>
#include "mindspore/core/ops/apply_proximal_gradient_descent.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class ApplyProximalGradientDescentCpuKernelMod : public NativeCpuKernelMod {
public:
ApplyProximalGradientDescentCpuKernelMod() = default;
~ApplyProximalGradientDescentCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0)
.AddOutInRef(1, 1)};
return support_list;
}
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
int64_t batch_rank_{0};
int64_t batch_size_{1};
int unit_size_{0};
size_t input_elements_{0};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_

View File

@@ -33,29 +33,26 @@ namespace {
abstract::ShapePtr ApplyProximalGradientDescentInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto var_shape = input_args[kInputIndex0]->BuildShape();
auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
int64_t shp_len = alpha_shape.size();
std::string para_name = input_args[kInputIndex1]->ToString();
(void)CheckAndConvertUtils::CheckInteger(para_name, SizeToLong(shp_len), kLessEqual, 1, primitive->name());
if (shp_len == 1) {
(void)CheckAndConvertUtils::CheckInteger(para_name, alpha_shape[kInputIndex0], kEqual, 1, primitive->name());
}
auto l1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
shp_len = l1_shape.size();
para_name = input_args[kInputIndex2]->ToString();
(void)CheckAndConvertUtils::CheckInteger(para_name, SizeToLong(shp_len), kLessEqual, 1, primitive->name());
if (shp_len == 1) {
(void)CheckAndConvertUtils::CheckInteger(para_name, l1_shape[kInputIndex0], kEqual, 1, primitive->name());
}
auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
shp_len = l2_shape.size();
para_name = input_args[kInputIndex3]->ToString();
(void)CheckAndConvertUtils::CheckInteger(para_name, SizeToLong(shp_len), kLessEqual, 1, primitive->name());
if (shp_len == 1) {
(void)CheckAndConvertUtils::CheckInteger(para_name, l2_shape[kInputIndex0], kEqual, 1, primitive->name());
}
auto delta_shape = input_args[kInputIndex4]->BuildShape();
size_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
auto value_ptr = primitive->GetAttr(kBatchRank);
batch_rank = GetValue<int64_t>(value_ptr);
}
(void)CheckAndConvertUtils::CheckInteger("alpha_shape size", SizeToLong(alpha_shape.size()), kLessEqual, batch_rank,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("l1_shape size", SizeToLong(l1_shape.size()), kLessEqual, batch_rank,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("l2_shape size", SizeToLong(l2_shape.size()), kLessEqual, batch_rank,
prim_name);
// var and delta must have the same shape
auto var_shape_ptr = var_shape->cast<abstract::ShapePtr>();
auto delta_shape_ptr = delta_shape->cast<abstract::ShapePtr>();

View File

@@ -388,6 +388,7 @@ _ops_vmap_clone_prim_dict = {"ApplyAdaMax": P.ApplyAdaMax,
"ApplyAdadelta": P.ApplyAdadelta,
"ApplyFtrl": P.ApplyFtrl,
"ApplyProximalAdagrad": P.ApplyProximalAdagrad,
"ApplyProximalGradientDescent": P.ApplyProximalGradientDescent,
"ApplyAdamWithAmsgrad": P.ApplyAdamWithAmsgrad,
"ApplyPowerSign": P.ApplyPowerSign,
"ApplyAdagradDA": P.ApplyAdagradDA,

View File

@@ -208,6 +208,48 @@ def get_apply_proximal_adagrad_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.ApplyProximalGradientDescent)
def get_apply_proximal_gradient_descent_rule(prim, axis_size):
"""VmapRule for `ApplyProximalGradientDescent` operation."""
if hasattr(prim, 'batch_rank'):
batch_rank = prim.batch_rank + 1
else:
batch_rank = 1
prim_name = prim.name
batch_prim = _vmap_clone_prim(prim)
batch_prim.add_prim_attr('batch_rank', batch_rank)
def vmap_rule(var_bdim, alpha_bdim, l1_bdim, l2_bdim, delta_bdim, u_monad):
var, var_dim = var_bdim
alpha, alpha_dim = alpha_bdim
l1, l1_dim = l1_bdim
l2, l2_dim = l2_bdim
delta, delta_dim = delta_bdim
if var_dim is None:
if any(dim is not None for dim in [alpha_dim, l1_dim, l2_dim, delta_dim]):
ValueError("The source axis of `var` is None, but the source "
"axis of `alpha/l1/l2/delta` is not None. The execution order of "
"operator `{}` cannot be guaranteed.".format(prim_name))
var = prim(var, alpha, l1, l2, delta, u_monad)
return (var, None)
if var_dim != 0:
raise ValueError("For `{}`, the source axis of `var` must not equal to 0, "
"but got the source axis of `var`: {}.".format(prim_name, var_dim))
alpha = _bdim_at_front(alpha, alpha_dim, axis_size)
l1 = _bdim_at_front(l1, l1_dim, axis_size)
l2 = _bdim_at_front(l2, l2_dim, axis_size)
delta = _bdim_at_front(delta, delta_dim, axis_size)
var = batch_prim(var, alpha, l1, l2, delta, u_monad)
return (var, 0)
return vmap_rule
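# Note: each nested vmap level re-enters this rule, so 'batch_rank' grows by one per level.
# The CPU kernel's Resize treats the leading 'batch_rank' dimensions of 'var' as batch
# dimensions and requires alpha/l1/l2 to match that batch shape (see the kernel above).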
@vmap_rules_getters.register(NN.BCEWithLogitsLoss)
def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
"""VmapRule for 'BCEWithLogitsLoss' ."""

View File

@@ -0,0 +1,191 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap
class Net(nn.Cell):
def __init__(self, var, alpha, l1, l2):
super(Net, self).__init__()
self.var = Parameter(var, name="var")
self.alpha = alpha
self.l1 = l1
self.l2 = l2
self.apply_proximal_gradient_descent = P.ApplyProximalGradientDescent()
def construct(self, delta):
return self.apply_proximal_gradient_descent(self.var, self.alpha, self.l1, self.l2, delta)
def run_net(var, alpha, l1, l2, delta, expect):
net = Net(var, alpha, l1, l2)
output = net(delta)
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_proximal_gradient_descent_float32():
"""
Feature: ApplyProximalGradientDescent cpu op.
Description: test data type is float32 in both graph mode and pynative mode.
Expectation: success.
"""
# data preparation
var_np = np.array([[0.1632949, 0.6505809, 0.41898054],
[0.6073093, 0.809577, 0.5305462]])
delta_np = np.array([[0.58472073, 0.5078854, 0.03992645],
[0.58894235, 0.3060052, 0.6934281]])
var = Tensor(var_np.astype(np.float32))
alpha = 0.01
l1 = 0.0
l2 = 0.0
delta = Tensor(delta_np.astype(np.float32))
expect = np.array([[0.1574477, 0.64550203, 0.41858128],
[0.60141987, 0.80651695, 0.5236119]], dtype=np.float32)
# run in graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
run_net(var, alpha, l1, l2, delta, expect)
# run in pynative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
run_net(var, alpha, l1, l2, delta, expect)
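# Note: with l1 == 0 and l2 == 0 the update reduces to plain gradient descent,
# so the 'expect' values above are simply var_np - alpha * delta_np.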
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_proximal_gradient_descent_float16():
"""
Feature: ApplyProximalGradientDescent cpu op.
Description: test data type is float16 in both graph mode and pynative mode.
Expectation: success.
"""
# data preparation
var_np = np.array([[0.6636, 0.902, 0.574],
[0.6167, 0.4993, 0.6987]])
delta_np = np.array([[0.68, 0.749, 0.145],
[0.3599, 0.4841, 0.1714]])
var = Tensor(var_np.astype(np.float16))
alpha = 0.01
l1 = 0.2
l2 = 0.0
delta = Tensor(delta_np.astype(np.float16))
expect = np.array([[0.655, 0.8926, 0.571],
[0.6113, 0.4924, 0.695]], dtype=np.float16)
# run in graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
run_net(var, alpha, l1, l2, delta, expect)
# run in pynative mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
run_net(var, alpha, l1, l2, delta, expect)
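# Note: with l1 > 0 each element goes through the soft-thresholding branch,
# i.e. sign(prox) * max(|prox| - alpha * l1, 0) / (1 + alpha * l2) with
# prox = var_np - alpha * delta_np, which is how the float16 'expect' above was derived.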
class ProximalGradientDescentNetVmap(nn.Cell):
def __init__(self, net):
super(ProximalGradientDescentNetVmap, self).__init__()
self.net = net
self.var = Parameter(
Tensor(np.array([[[0.6, 0.4], [0.1, 0.5]], [[0.6, 0.4], [0.1, 0.5]]]).astype(np.float32)), name="var")
self.vmap_proximal_gradient_descent = vmap(self.net, in_axes=(
0, 0, None, None, 0), out_axes=0)
def construct(self, alpha, l1, l2, delta):
return self.vmap_proximal_gradient_descent(self.var, alpha, l1, l2, delta)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_apply_proximal_gradient_descent_op_vmap():
"""
Feature: ApplyProximalGradientDescent cpu kernel
Description: test the ApplyProximalGradientDescent vmap.
Expectation: match to np benchmark.
"""
def cal_proximal_gradient_descent(var, alpha, l1, l2, delta):
return P.ApplyProximalGradientDescent()(var, alpha, l1, l2, delta)
error = 1e-3
delta = Tensor(np.array([[[0.3, 0.7], [0.1, 0.8]], [
[0.3, 0.7], [0.1, 0.8]]]).astype(np.float32))
alpha = Tensor(np.array([0.01, 0.01]).astype(np.float32))
l1 = 0.0
l2 = 0.0
vmap_func = ProximalGradientDescentNetVmap(cal_proximal_gradient_descent)
output = vmap_func(alpha, l1, l2, delta)
mindspore_var_out = output[0].asnumpy()
print(mindspore_var_out)
expect_var = np.array([[0.597, 0.393], [0.099, 0.492]]).astype(np.float32)
np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error)
class ProximalGradientDescentNetVmap2(nn.Cell):
def __init__(self, net):
super(ProximalGradientDescentNetVmap2, self).__init__()
self.net = net
self.var = Parameter(
Tensor(np.array([[[[0.6, 0.4], [0.1, 0.5]], [[0.7, 0.4], [0.1, 0.5]]],
[[[0.8, 0.4], [0.1, 0.5]], [[0.9, 0.4], [0.1, 0.5]]]]).astype(np.float32)), name="var")
self.vmap_proximal_gradient_descent = vmap(vmap(self.net, in_axes=(
0, None, None, None, 0), out_axes=0), in_axes=(0, None, None, None, 0), out_axes=0)
def construct(self, alpha, l1, l2, delta):
return self.vmap_proximal_gradient_descent(self.var, alpha, l1, l2, delta)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_apply_proximal_gradient_descent_op_vmap2():
"""
Feature: ApplyProximalGradientDescent cpu kernel
Description: test the ApplyProximalGradientDescent vmap.
Expectation: match to np benchmark.
"""
def cal_proximal_gradient_descent(var, alpha, l1, l2, delta):
return P.ApplyProximalGradientDescent()(var, alpha, l1, l2, delta)
error = 1e-3
delta = Tensor(np.array([[[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]], [
[[0.3, 0.7], [0.1, 0.8]], [[0.3, 0.7], [0.1, 0.8]]]]).astype(np.float32))
alpha = Tensor(0.2)
l1 = Tensor(0.1)
l2 = Tensor(0.0)
vmap_func = ProximalGradientDescentNetVmap2(cal_proximal_gradient_descent)
output = vmap_func(alpha, l1, l2, delta)
mindspore_var_out = output[0].asnumpy()
print(mindspore_var_out)
expect_var = np.array([[[0.52000004, 0.24], [0.05999999, 0.31999996]], [
[0.62, 0.24], [0.05999999, 0.31999996]]]).astype(np.float32)
np.testing.assert_allclose(mindspore_var_out, expect_var, rtol=error)