[feat][assistant][I40FGD] add new Ascend operator Trunc

This commit is contained in:
echo-yike 2021-09-13 10:52:21 +08:00
parent a0613df349
commit 785558cdd2
9 changed files with 188 additions and 1 deletions

View File

@ -64,6 +64,7 @@ constexpr auto kTile = "Tile";
constexpr auto kBiasAddGrad = "BiasAddGrad";
constexpr auto kCos = "Cos";
constexpr auto kAbs = "Abs";
constexpr auto kTrunc = "Trunc";
constexpr auto kSquare = "Square";
// Arrays
@ -123,6 +124,7 @@ inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_l
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
inline const PrimitivePtr kPrimTrunc = std::make_shared<Primitive>(kTrunc);
// Comparisons
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/trunc.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TruncInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x_shape);
auto output_shape = x_shape->cast<abstract::ShapePtr>();
return output_shape;
}
TypePtr TruncInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
auto input_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_type, check_list, prim->name());
return input_type;
}
} // namespace
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(TruncInferShape(primitive, input_args), TruncInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Trunc, prim::kPrimTrunc, TruncInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TRUNC_H_
#define MINDSPORE_CORE_OPS_TRUNC_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTrunc = "Trunc";
class Trunc : public PrimitiveC {
public:
Trunc() : PrimitiveC(kNameTrunc) { InitIOName({"input_x"}, {"output_y"}); }
~Trunc() = default;
MS_DECLARE_PARENT(Trunc, PrimitiveC);
};
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimTruncPtr = std::shared_ptr<Trunc>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TRUNC_H_

View File

@ -21,6 +21,7 @@ from .. import functional as F
from .. import operations as P
from .._grad.grad_base import bprop_getters
from .._grad.grad_math_ops import binop_grad_common
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..primitive import constexpr
@ -96,3 +97,14 @@ def get_bprop_erfinv(self):
return (dx,)
return bprop
@bprop_getters.register(P.Trunc)
def get_bprop_trunc(self):
"""Grad definition for `Trunc` operation."""
def bprop(input_x, output_y, dout):
bc_x = zeros_like(input_x)
return (bc_x,)
return bprop

View File

@ -418,3 +418,4 @@ from .hshrink import _hshrink_tbe
from .hshrink_grad import _hshrink_grad_tbe
from .new_im2col import _new_im2col_tbe
from .non_zero_ds import _non_zero_ds_tbe
from .trunc import _trunc_tbe

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trunc op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
trunc_op_info = TBERegOp("Trunc") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("trunc.so") \
.compute_cost(10) \
.kernel_name("trunc") \
.partial_flag(True) \
.input(0, "input_x", False, "required", "all") \
.output(0, "output_y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(trunc_op_info)
def _trunc_tbe():
"""Trunc TBE register"""
return

View File

@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex)
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc,)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -483,6 +483,7 @@ __all__ = [
"Conj",
"Real",
"Imag",
"Trunc",
"Complex"
]

View File

@ -5516,3 +5516,31 @@ class Imag(PrimitiveWithInfer):
elif input_dtype == mstype.tensor_type(mstype.complex128):
output_dtype = mstype.float64
return output_dtype
class Trunc(Primitive):
"""
Returns a new tensor with the truncated integer values of the elements of input.
Inputs:
- **input_x** (Tensor) - Input_x is a tensor.
Outputs:
Tensor, the same shape and data type as the input.
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend``
Examples:
>>> trunc = ops.Trunc()
>>> output = trunc(Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32))
>>> print(output)
[ 3. 0. 0. -3.]
"""
@prim_attr_register
def __init__(self):
"""Initialize Trunc"""

View File

@ -469,6 +469,10 @@ raise_set = [
('AssignAdd_Error', {
'block': (P.AssignAdd(), {'exception': ValueError}),
'desc_inputs': [[1]]}),
('Trunc', {
'block': P.Trunc(),
'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))],
'skip': ['backward']}),
]