[feat][assistant][I3T925] add new math operator Erfinv

This commit is contained in:
hulx 2021-08-24 21:25:04 +08:00
parent dddeb06129
commit ca6f3ef155
10 changed files with 197 additions and 2 deletions

View File

@ -502,6 +502,7 @@ inline const PrimitivePtr kPrimAcoshGrad = std::make_shared<Primitive>("AcoshGra
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate); inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
inline const PrimitivePtr kPrimErfinv = std::make_shared<Primitive>("Erfinv");
inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan"); inline const PrimitivePtr kPrimIsNan = std::make_shared<Primitive>("IsNan");
inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf"); inline const PrimitivePtr kPrimIsInf = std::make_shared<Primitive>("IsInf");
inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite"); inline const PrimitivePtr kPrimIsFinite = std::make_shared<Primitive>("IsFinite");

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/erfinv.h"
#include <map>
#include <string>
#include <set>
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ErfinvInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input_x numbers", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto in_shape = shape_map[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr ErfinvInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("input_x number", input_args.size(), kEqual, 1, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("input_x", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
AbstractBasePtr ErfinvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(ErfinvInferShape(primitive, input_args), ErfinvInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Erfinv, prim::kPrimErfinv, ErfinvInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ERFINV_H_
#define MINDSPORE_CORE_OPS_ERFINV_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameErfinv = "Erfinv";
class Erfinv : public PrimitiveC {
public:
Erfinv() : PrimitiveC(kNameErfinv) { InitIOName({"input_x"}, {"output"}); }
~Erfinv() = default;
MS_DECLARE_PARENT(Erfinv, PrimitiveC);
};
AbstractBasePtr ErfinvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimErfinvPtr = std::shared_ptr<Erfinv>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ERFINV_H_

View File

@ -15,10 +15,10 @@
"""grad experimental impl.""" """grad experimental impl."""
from .._grad.grad_base import get_bprop_fn from .._grad.grad_base import get_bprop_fn
from . import grad_math_ops
from . import grad_array_ops from . import grad_array_ops
from . import grad_inner_ops from . import grad_inner_ops
from . import grad_nn_ops from . import grad_nn_ops
from . import grad_comm_ops from . import grad_comm_ops
from . import grad_math_ops
__all__ = ['get_bprop_fn'] __all__ = ['get_bprop_fn']

View File

@ -16,6 +16,7 @@
"""Define the grad rules of math related operations.""" """Define the grad rules of math related operations."""
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
import numpy as np
from .. import functional as F from .. import functional as F
from .. import operations as P from .. import operations as P
from .._grad.grad_base import bprop_getters from .._grad.grad_base import bprop_getters
@ -45,3 +46,21 @@ def get_bprop_index_lerp(self):
return dstart, dend, dweight return dstart, dend, dweight
return bprop return bprop
@bprop_getters.register(P.Erfinv)
def get_bprop_erfinv(self):
"""Grad definition for `Erfinv` operation."""
exp = P.Exp()
square = P.Square()
sqrt = P.Sqrt()
cast = P.Cast()
dtype = P.DType()
def bprop(input_x, out, dout):
root_pi_over_two = cast(sqrt(F.scalar_to_tensor(np.pi)) / 2, dtype(dout))
dout_square = square(dout)
dx = dout * root_pi_over_two * exp(dout_square)
return (dx,)
return bprop

View File

@ -396,6 +396,7 @@ from .ctc_loss_v2 import _ctc_loss_v2_tbe
from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe
from .roll import _roll_tbe from .roll import _roll_tbe
from .soft_shrink import _soft_shrink_tbe from .soft_shrink import _soft_shrink_tbe
from .erfinv import _erfinv_tbe
from .soft_shrink_grad import _soft_shrink_grad_tbe from .soft_shrink_grad import _soft_shrink_grad_tbe
from .hsigmoid_grad import _hsigmoid_grad_tbe from .hsigmoid_grad import _hsigmoid_grad_tbe
from .hsigmoid import _hsigmoid_tbe from .hsigmoid import _hsigmoid_tbe

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Erfinv op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
erfinv_op_info = TBERegOp("Erfinv") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("erfinv.so") \
.compute_cost(10) \
.kernel_name("erfinv") \
.partial_flag(True) \
.input(0, "input_x", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(erfinv_op_info)
def _erfinv_tbe():
"""Erfinv TBE register"""
return

View File

@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse, IndexAdd) MatrixInverse, IndexAdd, Erfinv)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler, RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -241,6 +241,7 @@ __all__ = [
'ReLU6', 'ReLU6',
'Elu', 'Elu',
'Erf', 'Erf',
"Erfinv",
'Erfc', 'Erfc',
'Sigmoid', 'Sigmoid',
'HSwish', 'HSwish',
@ -531,6 +532,7 @@ __all__ = [
"BufferGetItem", "BufferGetItem",
"BufferSample", "BufferSample",
"NeighborListUpdateNew", "NeighborListUpdateNew",
"Erfinv"
] ]
__all__.sort() __all__.sort()

View File

@ -5269,3 +5269,36 @@ class IndexAdd(PrimitiveWithInfer):
if dim != axis: if dim != axis:
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name) validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
return x_shape return x_shape
class Erfinv(Primitive):
r"""
Computes the inverse error function of input. The inverse error function is defined in the range (-1, 1) as:
.. math::
erfinv(erf(x)) = x
Inputs:
- **input_x** (Tensor) - The input tensor to compute to, with data type float32, float16.
Outputs:
Tensor, has the same shape and dtype as `input_x`.
Raises:
TypeError: If dtype of `input_x` is not one of: float32, float16.
Supported Platforms:
``Ascend``
Examples:
>>> x = Tensor(np.array([0, 0.5, -0.9]), mindspore.float32)
>>> erfinv = P.Erfinv()
>>> output = erfinv(x)
>>> print(output)
[ 0. 0.47695306 -1.1630805 ]
"""
@prim_attr_register
def __init__(self):
"""Initialize Erfinv"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])

View File

@ -1667,6 +1667,10 @@ test_case_math_ops = [
Tensor(np.random.rand(4).astype(np.int32))], Tensor(np.random.rand(4).astype(np.int32))],
'desc_bprop': [], 'desc_bprop': [],
'skip': ['backward']}), 'skip': ['backward']}),
('Erfinv', {
'block': P.Erfinv(),
'desc_inputs': [Tensor(np.array([0.1, 0.1, 0.1]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([1, 1, 1]).astype(np.float16))]}),
] ]
test_case_nn_ops = [ test_case_nn_ops = [