!33061 [assistant][ops] Add new aicpu operator SparseApplyProximalGradientDescent

Merge pull request !33061 from 杨旭华/SparseApplyProximalGradientDescent
i-robot 2022-07-24 15:08:11 +00:00 committed by Gitee
commit 460ecac5ff
9 changed files with 658 additions and 0 deletions


@@ -0,0 +1,239 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_apply_proximal_gradient_descent_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <memory>
#include <map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseApplyProximalGradientDescentInputsNum = 6;
constexpr size_t kSparseApplyProximalGradientDescentOutputsNum = 1;
using KernelRunFunc = SparseApplyProximalGradientDescentCpuKernelMod::KernelRunFunc;
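// Expands to a KernelAttr whose six input dtypes are (var, alpha, l1, l2, grad, indices)
// and whose single output dtype is the updated var.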
#define ADD_KERNEL(t1, t2, t3, t4, t5, t6, t7) \
KernelAttr() \
.AddInputAttr(kNumberType##t1) \
.AddInputAttr(kNumberType##t2) \
.AddInputAttr(kNumberType##t3) \
.AddInputAttr(kNumberType##t4) \
.AddInputAttr(kNumberType##t5) \
.AddInputAttr(kNumberType##t6) \
.AddOutputAttr(kNumberType##t7)
} // namespace
bool SparseApplyProximalGradientDescentCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
if (inputs.size() != kSparseApplyProximalGradientDescentInputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be " << kSparseApplyProximalGradientDescentInputsNum
<< ", but got " << inputs.size() << ".";
return false;
}
if (outputs.size() != kSparseApplyProximalGradientDescentOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size must be "
<< kSparseApplyProximalGradientDescentOutputsNum << ", but got " << outputs.size() << ".";
return false;
}
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
void SparseApplyProximalGradientDescentCpuKernelMod::ResetResource() noexcept {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
indices_data_type_ = kNumberTypeInt32;
indices_size_ = 0;
var_first_dim_size_ = 0;
var_outer_dim_size_ = 1;
}
int SparseApplyProximalGradientDescentCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
enum input_index : size_t { Var_no, Alpha_no, L1_no, L2_no, Grad_no, Indices_no };
ShapeVector var_shape = inputs[Var_no]->GetShapeVector();
ShapeVector alpha_shape = inputs[Alpha_no]->GetShapeVector();
ShapeVector l1_shape = inputs[L1_no]->GetShapeVector();
ShapeVector l2_shape = inputs[L2_no]->GetShapeVector();
ShapeVector grad_shape = inputs[Grad_no]->GetShapeVector();
ShapeVector indices_shape = inputs[Indices_no]->GetShapeVector();
if (var_shape.empty()) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, var must be at least 1D.";
} else {
var_first_dim_size_ = var_shape[0];
}
if (var_shape.size() != grad_shape.size()) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, rank(grad) should be same as rank(var), but "
"got rank(grad): "
<< grad_shape.size() << ", rank(var): " << var_shape.size() << ".";
}
for (size_t i = 1; i < var_shape.size(); ++i) {
if (var_shape[i] != grad_shape[i]) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, the shape of var and grad must equal in dimension "
<< i << ".";
}
var_outer_dim_size_ *= var_shape[i];
}
if (indices_shape.size() != 1) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, indices must be 1D, but got " << indices_shape.size()
<< "D.";
}
indices_size_ = indices_shape[0];
if (grad_shape[0] != SizeToLong(indices_size_)) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, grad.shape[0] must be equal to indices.shape[0], but "
"got grad.shape[0]: "
<< grad_shape[0] << " indices.shape[0]: " << indices_size_ << ".";
}
if (!alpha_shape.empty()) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, alpha must be a scalar, but got shape: "
<< Vector2Str(alpha_shape) << ".";
}
if (!l1_shape.empty()) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l1 must be a scalar, but got shape: "
<< Vector2Str(l1_shape) << ".";
}
if (!l2_shape.empty()) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, l2 must be a scalar, but got shape: "
<< Vector2Str(l2_shape) << ".";
}
return KRET_OK;
}
template <typename I, typename T>
bool SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseApplyProximalGradientDescentInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseApplyProximalGradientDescentOutputsNum, kernel_name_);
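// Input order: var(0), alpha(1), l1(2), l2(3), grad(4), indices(5); the single output is the updated var.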
auto var = reinterpret_cast<T *>(inputs[0]->addr);
auto grad = reinterpret_cast<T *>(inputs[4]->addr);
auto indices = reinterpret_cast<I *>(inputs[5]->addr);
auto alpha_scalar = reinterpret_cast<T *>(inputs[1]->addr)[0];
auto l1_scalar = reinterpret_cast<T *>(inputs[2]->addr)[0];
auto l2_scalar = reinterpret_cast<T *>(inputs[3]->addr)[0];
auto output = reinterpret_cast<T *>(outputs[0]->addr);
for (size_t i = 0; i < indices_size_; i++) {
I index = indices[i];
if (index < 0 || LongToSize(index) >= var_first_dim_size_) {
MS_LOG(EXCEPTION)
<< "For SparseApplyProximalGradientDescent, each value in indices should be in [0, var.shape[0]), but got "
<< index << ".";
}
size_t start_index = var_outer_dim_size_ * static_cast<size_t>(index);
size_t end_index = start_index + var_outer_dim_size_;
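// Proximal (FOBOS) update for the selected row:
//   prox_v = var - alpha * grad
//   var    = sign(prox_v) * max(|prox_v| - alpha * l1, 0) / (1 + alpha * l2)
// The sign/threshold term goes through double so the same code works for
// integer, half and float element types.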
for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) {
auto learning_rate = alpha_scalar;
auto prox_v = var[j];
prox_v -= grad[k] * learning_rate;
if (l1_scalar > static_cast<T>(0.0)) {
var[j] = (T)Sign(static_cast<double>(prox_v)) *
(T)std::fmax(std::fabs(static_cast<double>(prox_v)) -
static_cast<double>(learning_rate) * static_cast<double>(l1_scalar),
static_cast<double>(0.0)) /
(static_cast<T>(1.0) + l2_scalar * learning_rate);
} else {
var[j] = prox_v / (static_cast<T>(1.0) + l2_scalar * learning_rate);
}
}
}
size_t copy_size = var_first_dim_size_ * var_outer_dim_size_ * sizeof(T);
auto ret = memcpy_s(output, copy_size, var, copy_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For SparseApplyProximalGradientDescent, memcpy_s error, errorno: " << ret << ".";
}
return true;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SparseApplyProximalGradientDescentCpuKernelMod::GetFuncList()
const {
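// Each entry pairs a dtype signature with a LaunchKernel<I, T> instantiation,
// where I is the indices type (int32/int64) and T the element type.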
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list_ = {
{ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int32, Int8),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, int8_t>},
{ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int32, Int16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, int16_t>},
{ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int32, Int32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int32, Int64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int32, UInt8),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, uint8_t>},
{ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int32, UInt16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, uint16_t>},
{ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int32, UInt32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, uint32_t>},
{ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int32, UInt64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, uint64_t>},
{ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int32, Float16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, float16>},
{ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int32, Float32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, float>},
{ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int32, Float64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int32_t, double>},
{ADD_KERNEL(Int8, Int8, Int8, Int8, Int8, Int64, Int8),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, int8_t>},
{ADD_KERNEL(Int16, Int16, Int16, Int16, Int16, Int64, Int16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, int16_t>},
{ADD_KERNEL(Int32, Int32, Int32, Int32, Int32, Int64, Int32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{ADD_KERNEL(Int64, Int64, Int64, Int64, Int64, Int64, Int64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{ADD_KERNEL(UInt8, UInt8, UInt8, UInt8, UInt8, Int64, UInt8),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, uint8_t>},
{ADD_KERNEL(UInt16, UInt16, UInt16, UInt16, UInt16, Int64, UInt16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, uint16_t>},
{ADD_KERNEL(UInt32, UInt32, UInt32, UInt32, UInt32, Int64, UInt32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, uint32_t>},
{ADD_KERNEL(UInt64, UInt64, UInt64, UInt64, UInt64, Int64, UInt64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, uint64_t>},
{ADD_KERNEL(Float16, Float16, Float16, Float16, Float16, Int64, Float16),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, float16>},
{ADD_KERNEL(Float32, Float32, Float32, Float32, Float32, Int64, Float32),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, float>},
{ADD_KERNEL(Float64, Float64, Float64, Float64, Float64, Int64, Float64),
&SparseApplyProximalGradientDescentCpuKernelMod::LaunchKernel<int64_t, double>}};
return func_list_;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseApplyProximalGradientDescent,
SparseApplyProximalGradientDescentCpuKernelMod);
} // namespace kernel
} // namespace mindspore
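
To sanity-check the update rule implemented above, here is a minimal NumPy sketch of the same row-wise proximal step; the helper name sparse_proximal_gd_reference and the standalone form are illustrative, not part of this change:

import numpy as np

def sparse_proximal_gd_reference(var, alpha, l1, l2, grad, indices):
    """Row-wise FOBOS update mirroring LaunchKernel above (illustrative only)."""
    var = var.copy()
    for i, idx in enumerate(indices):
        # prox_v = var - alpha * grad, for the row selected by indices[i]
        prox_v = var[idx] - alpha * grad[i]
        if l1 > 0.0:
            # var = sign(prox_v) * max(|prox_v| - alpha * l1, 0) / (1 + alpha * l2)
            var[idx] = (np.sign(prox_v)
                        * np.maximum(np.abs(prox_v) - alpha * l1, 0.0)
                        / (1.0 + alpha * l2))
        else:
            var[idx] = prox_v / (1.0 + alpha * l2)
    return var

# With var=[[4.1, 7.2], [1.1, 3.0]], alpha=1.0, l1=1.0, l2=0.0, grad of ones
# and indices=[0, 1], this returns [[2.1, 5.2], [0.0, 1.0]], matching the
# operator docstring example added later in this PR.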


@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_
#include <utility>
#include <vector>
#include <map>
#include "mindspore/core/ops/sparse_apply_proximal_gradient_descent.h"
#include "plugin/device/cpu/kernel/sparse_optimizer_cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SparseApplyProximalGradientDescentCpuKernelMod
: public SparseOptimizerCpuKernelMod,
public MatchKernelHelper<SparseApplyProximalGradientDescentCpuKernelMod> {
public:
SparseApplyProximalGradientDescentCpuKernelMod() = default;
~SparseApplyProximalGradientDescentCpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
template <typename I, typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
void ResetResource() noexcept;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_CPU_KERNEL_H_


@@ -249,6 +249,7 @@ constexpr auto kSparseApplyAdagradDA = "SparseApplyAdagradDA";
constexpr auto kMaxPool3DWithArgmax = "MaxPool3DWithArgmax";
constexpr auto kUpsampleTrilinear3DGrad = "UpsampleTrilinear3DGrad";
constexpr auto kIFMR = "IFMR";
constexpr auto kSparseApplyProximalGradientDescent = "SparseApplyProximalGradientDescent";
// Random
constexpr auto kStandardNormal = "StandardNormal";
@@ -814,6 +815,8 @@ GVAR_DEF(PrimitivePtr, kPrimAdaptiveMaxPool2DGrad, std::make_shared<Primitive>("
GVAR_DEF(PrimitivePtr, kPrimUpsampleNearest3DGrad, std::make_shared<Primitive>("UpsampleNearest3DGrad"));
GVAR_DEF(PrimitivePtr, kPrimUpsampleTrilinear3DGrad, std::make_shared<Primitive>("UpsampleTrilinear3DGrad"));
GVAR_DEF(PrimitivePtr, kPrimIFMR, std::make_shared<Primitive>(kIFMR));
GVAR_DEF(PrimitivePtr, kPrimSparseApplyProximalGradientDescent,
std::make_shared<Primitive>(kSparseApplyProximalGradientDescent));
// Comm ops
GVAR_DEF(PrimitivePtr, kPrimMirror, std::make_shared<Primitive>("_MirrorOperator"));


@@ -0,0 +1,121 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sparse_apply_proximal_gradient_descent.h"
#include <algorithm>
#include <set>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseApplyProximalGradientDescentInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto l1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
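// alpha, l1 and l2 must be scalars, i.e. rank-0 tensors.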
auto scalar_shape = 0;
(void)CheckAndConvertUtils::CheckInteger("alpha_shape size", alpha_shape.size(), kEqual, scalar_shape, prim_name);
(void)CheckAndConvertUtils::CheckInteger("l1_shape size", l1_shape.size(), kEqual, scalar_shape, prim_name);
(void)CheckAndConvertUtils::CheckInteger("l2_shape size", l2_shape.size(), kEqual, scalar_shape, prim_name);
// Var dimension must be equal or greater than 1.
(void)CheckAndConvertUtils::CheckInteger("var dimension", var_shape.size(), kGreaterEqual, 1, prim_name);
if (var_shape.size() != grad_shape.size()) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', rank(grad) should be same as rank(var), but got rank(grad): " << grad_shape.size()
<< ", rank(var): " << var_shape.size() << ".";
}
for (size_t i = 1; i < var_shape.size(); ++i) {
if (var_shape[i] != grad_shape[i]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "'. the shape of var and grad must equal in dimension " << i
<< ".";
}
}
// Indices must be rank 1.
(void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kEqual, 1, prim_name);
if (indices_shape[0] != grad_shape[0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', grad.shape[0] must be equal to indices.shape[0], but got grad.shape[0]: "
<< grad_shape[0] << ", indices.shape[0]: " << indices_shape[0] << ".";
}
return std::make_shared<abstract::Shape>(var_shape);
}
TypePtr SparseApplyProximalGradientDescentInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto var_type = input_args[0]->BuildType();
auto alpha_type = input_args[1]->BuildType();
auto l1_type = input_args[2]->BuildType();
auto l2_type = input_args[3]->BuildType();
auto grad_type = input_args[4]->BuildType();
auto indices_type = input_args[5]->BuildType();
std::map<std::string, TypePtr> args;
(void)args.insert({"var", var_type});
(void)args.insert({"alpha", alpha_type});
(void)args.insert({"l1", l1_type});
(void)args.insert({"l2", l2_type});
(void)args.insert({"grad", grad_type});
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, common_valid_types, prim_name);
const std::set<TypePtr> valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, valid_types, prim_name);
return var_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseApplyProximalGradientDescent, BaseOperator);
void SparseApplyProximalGradientDescent::Init(const bool use_locking) { this->set_use_locking(use_locking); }
void SparseApplyProximalGradientDescent::set_use_locking(const bool use_locking) {
(void)this->AddAttr(kUseLocking, api::MakeValue(use_locking));
}
bool SparseApplyProximalGradientDescent::get_use_locking() const {
auto value_ptr = GetAttr(kUseLocking);
return GetValue<bool>(value_ptr);
}
AbstractBasePtr SparseApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr int kInputsNum = 6;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = SparseApplyProximalGradientDescentInferType(primitive, input_args);
auto infer_shape = SparseApplyProximalGradientDescentInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyProximalGradientDescent, prim::kPrimSparseApplyProximalGradientDescent,
SparseApplyProximalGradientDescentInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_
#define MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseApplyProximalGradientDescent = "SparseApplyProximalGradientDescent";
class MIND_API SparseApplyProximalGradientDescent : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseApplyProximalGradientDescent);
SparseApplyProximalGradientDescent() : BaseOperator(kNameSparseApplyProximalGradientDescent) {}
void Init(const bool use_locking = false);
void set_use_locking(const bool use_locking);
bool get_use_locking() const;
};
abstract::AbstractBasePtr SparseApplyProximalGradientDescentInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimSparseApplyProximalGradientDescentPtr = std::shared_ptr<SparseApplyProximalGradientDescent>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_APPLY_PROXIMAL_GRADIENT_DESCENT_H_


@@ -278,3 +278,4 @@ from .sparse_matrix_transpose import _sparse_matrix_transpose_aicpu
from .sparse_tensor_to_csr_sparse_matrix import _sparse_tensor_to_csr_sparse_matrix_aicpu
from .csr_sparse_matrix_to_sparse_tensor import _csr_sparse_matrix_to_sparse_tensor_aicpu
from .split import _split_aicpu
from .sparse_apply_proximal_gradient_descent import _sparse_apply_proximal_gradient_descent_aicpu


@@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyProximalGradientDescent"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
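# Each dtype_format row lists the dtypes for (var, alpha, l1, l2, grad, indices, var_out);
# all value inputs share one dtype, while indices is int32 or int64.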
sparse_apply_proximal_gradient_descent_op_info = AiCPURegOp("SparseApplyProximalGradientDescent") \
.fusion_type("OPAQUE") \
.attr("use_locking", "bool") \
.input(0, "var", "required") \
.input(1, "alpha", "required") \
.input(2, "l1", "required") \
.input(3, "l2", "required") \
.input(4, "grad", "required") \
.input(5, "indices", "required") \
.output(0, "var", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \
DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \
DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \
DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \
DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \
DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \
DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \
DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, \
DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, \
DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, \
DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, \
DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \
DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \
DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \
DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(sparse_apply_proximal_gradient_descent_op_info)
def _sparse_apply_proximal_gradient_descent_aicpu():
"""SparseApplyProximalGradientDescent"""
return


@@ -10454,3 +10454,86 @@ class SparseApplyAdagradDA(Primitive):
'grad', 'indices', 'lr', 'l1', 'l2', 'global_step'],
outputs=['var'])
validator.check_value_type("use_locking", use_locking, [bool], self.name)
class SparseApplyProximalGradientDescent(Primitive):
r"""
Sparse update '*var' as FOBOS algorithm with fixed learning rate.

.. math::
    \begin{array}{ll} \\
        \text{prox_v} = var - alpha * grad \\
        var = sign(\text{prox_v}) / (1 + alpha * l2) * \max(\left| \text{prox_v} \right| - alpha * l1, 0)
    \end{array}

Inputs of `var` and `grad` comply with the implicit type conversion rules to make the data types consistent.
If they have different data types, the lower priority data type will be converted to
the relatively highest priority data type.

Args:
use_locking (bool): If `True`, the `var` tensors will be protected from being updated.
Default: False.

Inputs:
- **var** (Parameter) - Variable tensor to be updated. The data type must be int8, int16, int32, int64,
uint8, uint16, uint32, uint64, float16, float32 or float64.
The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
- **alpha** (Union[Number, Tensor]) - Scaling factor. Must be a scalar with same type as `var`.
- **l1** (Union[Number, Tensor]) - L1 regularization. Must be a scalar with same type as `var`.
- **l2** (Union[Number, Tensor]) - L2 regularization. Must be a scalar with same type as `var`.
- **grad** (Tensor) - A tensor for gradient, has the same type as `var`,
and grad.shape[1:] = var.shape[1:] if rank(var) > 1.
- **indices** (Tensor) - A tensor of indices into the first dimension of `var`.
If there are duplicates in `indices`, the behavior is undefined. Must be one of the
following types: int32, int64. indices.shape[0] must equal grad.shape[0].

Outputs:
- **var** (Tensor) - Tensor, has the same shape and data type as `var`.

Raises:
TypeError: If `var`, `grad` or `indices` is not a Parameter.
TypeError: If `alpha`, `l1` or `l2` is neither a Number nor a Tensor.
TypeError: If `use_locking` is not a bool.
TypeError: If dtype of `var`, `alpha`, `l1`, `l2` or `grad` is not one of int8, int16,
int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64.
TypeError: If dtype of `indices` is neither int32 nor int64.
ValueError: If the shape of `var` or `grad` is rank 0.
ValueError: If shape of `grad` is not same as `var`.
ValueError: If the shape of `alpha`, `l1` or `l2` is not rank 0.
ValueError: If shape of `indices` is not same as the shape of first dimension of `grad`.
RuntimeError: If the data type of `var`, `alpha`, `l1`, `l2`, `grad` conversion of Parameter
is not supported.

Supported Platforms:
``Ascend`` ``CPU``

Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> import mindspore.ops.operations.nn_ops as nn_ops
>>> var = Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float32))
>>> alpha = Tensor(1.0, mstype.float32)
>>> l1 = Tensor(1.0, mstype.float32)
>>> l2 = Tensor(0.0, mstype.float32)
>>> grad = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
>>> sparse_apply_proximal_gradient_descent = nn_ops.SparseApplyProximalGradientDescent()
>>> output = sparse_apply_proximal_gradient_descent(var, alpha, l1, l2, grad, indices)
>>> print(output)
[[2.1 5.2]
[0. 1. ]]
"""
__mindspore_signature__ = (
sig.make_sig('var', dtype=sig.sig_dtype.T),
sig.make_sig('alpha', dtype=sig.sig_dtype.T),
sig.make_sig('l1', dtype=sig.sig_dtype.T),
sig.make_sig('l2', dtype=sig.sig_dtype.T),
sig.make_sig('grad', dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1)
)
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize SparseApplyProximalGradientDescent."""
self.init_prim_io_names(inputs=['var', 'alpha', 'l1', 'l2', 'grad', 'indices'],
outputs=['var'])
validator.check_value_type("use_locking", use_locking, [bool], self.name)


@@ -114,6 +114,7 @@ from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
from mindspore.ops.operations.other_ops import BlackmanWindow
from mindspore.ops.operations.nn_ops import SparseApplyCenteredRMSProp
from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
from mindspore.ops.operations.array_ops import LeftShift
@@ -1289,6 +1290,16 @@ class SparseApplyAdagradDANet(nn.Cell):
return out
class SparseApplyProximalGradientDescentNet(nn.Cell):
def __init__(self, use_locking=False):
super(SparseApplyProximalGradientDescentNet, self).__init__()
self.sparse_apply_proximal_gradient_descent = SparseApplyProximalGradientDescent(use_locking)
def construct(self, var, alpha, l1, l2, grad, indices):
out = self.sparse_apply_proximal_gradient_descent(var, alpha, l1, l2, grad, indices)
return out
test_case_math_ops = [
('Betainc', {
'block': Betainc(),
@@ -3028,6 +3039,15 @@ test_case_nn_ops = [
Tensor(0.001, mstype.float32),
Tensor(1, mstype.int64)],
'skip': ['backward']}),
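# Forward-only smoke test of the new primitive; backward is skipped below.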
('SparseApplyProximalGradientDescent', {
'block': SparseApplyProximalGradientDescentNet(),
'desc_inputs': [Tensor(np.array([[0.4, 0.5], [0.3, 0.1]]).astype(np.float32)),
Tensor(0.01, mstype.float32),
Tensor(0.88, mstype.float32),
Tensor(0.3, mstype.float32),
Tensor(np.array([[0.2, 0.5], [0.3, 0.2]]).astype(np.float32)),
Tensor(np.array([0, 1]).astype(np.int32))],
'skip': ['backward']}),
]
test_case_array_ops = [
('LeftShift', {