[feat] [assistant] [I4809E] add new data operator

This commit is contained in:
unknown 2021-12-07 14:09:07 +08:00
parent e1259078e6
commit 9969ef7260
6 changed files with 231 additions and 33 deletions

View File

@ -302,6 +302,7 @@ inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
inline const PrimitivePtr kPrimApplyGradientDescent = std::make_shared<Primitive>("ApplyGradientDescent");
inline const PrimitivePtr kPrimApplyPowerSignD = std::make_shared<Primitive>("ApplyPowerSign");
inline const PrimitivePtr kPrimBesselI0e = std::make_shared<Primitive>("BesselI0e");
inline const PrimitivePtr kPrimBesselI1e = std::make_shared<Primitive>("BesselI1e");
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");

View File

@ -0,0 +1,120 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_power_sign_d.h"
#include <vector>
#include <memory>
#include <set>
#include <map>
#include <string>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
const int64_t kInputNum = 7;
abstract::TupleShapePtr ApplyPowerSignDInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_shape = input_args[kInputIndex0]->BuildShape();
auto m_shape = input_args[kInputIndex1]->BuildShape();
auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
auto logbase_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->GetShapeTrack())[kShape];
auto sign_decay_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->GetShapeTrack())[kShape];
auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->GetShapeTrack())[kShape];
auto grad_shape = input_args[kInputIndex6]->BuildShape();
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kLessEqual, 1, prim_name);
if (lr_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("lr_shape[0] size", lr_shape[0], kEqual, 1, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("logbase_shape size", logbase_shape.size(), kLessEqual, 1, prim_name);
if (logbase_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("logbase_shape[0] size", logbase_shape[0], kEqual, 1, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("sign_decay_shape size", sign_decay_shape.size(), kLessEqual, 1, prim_name);
if (sign_decay_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("sign_decay_shape[0] size", sign_decay_shape[0], kEqual, 1, prim_name);
}
(void)CheckAndConvertUtils::CheckInteger("beta_shape size", beta_shape.size(), kLessEqual, 1, prim_name);
if (beta_shape.size() == 1) {
(void)CheckAndConvertUtils::CheckInteger("beta_shape[0] size", beta_shape[0], kEqual, 1, prim_name);
}
if (grad_shape->IsDynamic() || var_shape->IsDynamic() || m_shape->IsDynamic()) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
}
// var, m and grad must have the same shape
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
same_shape_args_map.insert({"m", m_shape});
same_shape_args_map.insert({"grad", grad_shape});
for (auto &elem : same_shape_args_map) {
if (*elem.second != *var_shape) {
MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
<< " are not consistent with var shape " << var_shape->ToString();
}
}
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, m_shape});
}
TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto var_type = input_args[kInputIndex0]->BuildType();
auto m_type = input_args[kInputIndex1]->BuildType();
auto lr_type = input_args[kInputIndex2]->BuildType();
auto logbase_type = input_args[kInputIndex3]->BuildType();
auto sign_decay_type = input_args[kInputIndex4]->BuildType();
auto beta_type = input_args[kInputIndex5]->BuildType();
auto grad_type = input_args[kInputIndex6]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> args;
args.insert({"var", var_type});
args.insert({"m", m_type});
args.insert({"grad", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("lr_dtype", lr_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("logbase_dtype", logbase_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("sign_decay_dtype", sign_decay_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("beta_dtype", beta_type, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type});
}
} // namespace
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = ApplyPowerSignDInferType(primitive, input_args);
auto infer_shape = ApplyPowerSignDInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyPowerSign, prim::kPrimApplyPowerSignD, ApplyPowerSignDInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
#define MINDSPORE_CORE_OPS_APPLY_POWER_SIGN_D_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyPowerSign = "ApplyPowerSign";
class ApplyPowerSign : public PrimitiveC {
public:
ApplyPowerSign() : PrimitiveC(kNameApplyPowerSign) {
InitIOName({"var", "m", "lr", "logbase", "sign_decay", "beta", "grad"}, {"var", "m"});
}
~ApplyPowerSign() = default;
MS_DECLARE_PARENT(ApplyPowerSign, PrimitiveC);
};
AbstractBasePtr ApplyPowerSignDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimApplyPowerSignDPtr = std::shared_ptr<ApplyPowerSign>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_POWERSIGND_H_

View File

@ -56,6 +56,7 @@ from .apply_adagrad_d_a import _apply_adagrad_d_a_tbe
from .apply_add_sign import _apply_add_sign_tbe
from .apply_add_sign_ds import _apply_add_sign_ds_tbe
from .apply_power_sign import _apply_power_sign_tbe
from .apply_power_sign_ds import _apply_power_sign_ds_tbe
from .apply_gradient_descent import _apply_gradient_descent_tbe
from .apply_gradient_descent_ds import _apply_gradient_descent_ds_tbe
from .apply_proximal_gradient_descent import _apply_proximal_gradient_descent_tbe

View File

@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyPowerSignD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_power_sign_d_op_info = TBERegOp("ApplyPowerSign") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_power_sign_d.so") \
.compute_cost(10) \
.dynamic_shape(True)\
.kernel_name("apply_power_sign_d") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "m", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "logbase", False, "required", "all") \
.input(4, "sign_decay", False, "required", "all") \
.input(5, "beta", False, "required", "all") \
.input(6, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "m", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()
@op_info_register(apply_power_sign_d_op_info)
def _apply_power_sign_ds_tbe():
"""ApplyPowerSignD TBE register"""
return

View File

@ -6209,7 +6209,7 @@ class ApplyAddSign(PrimitiveWithInfer):
"""Initialize ApplyAddSign."""
self.add_prim_attr('side_effect_mem', True)
class ApplyPowerSign(PrimitiveWithInfer):
class ApplyPowerSign(Primitive):
r"""
Updates relevant entries according to the AddSign algorithm.
@ -6303,38 +6303,6 @@ class ApplyPowerSign(PrimitiveWithInfer):
"""Initialize ApplyPowerSign."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,
beta_shape, grad_shape):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape)
validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shape_len == 1:
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
logbase_shape_len = len(logbase_shape)
validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
if logbase_shape_len == 1:
validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
sign_decay_shape_len = len(sign_decay_shape)
validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
if sign_decay_shape_len == 1:
validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
beta_shape_len = len(beta_shape)
validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
if beta_shape_len == 1:
validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype,
beta_dtype, grad_dtype):
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype
class ApplyGradientDescent(Primitive):