forked from mindspore-Ecosystem/mindspore
[feat][assistant][I40FGD] add new Ascend operator Trunc
This commit is contained in:
parent
a0613df349
commit
785558cdd2
|
@ -64,6 +64,7 @@ constexpr auto kTile = "Tile";
|
||||||
constexpr auto kBiasAddGrad = "BiasAddGrad";
|
constexpr auto kBiasAddGrad = "BiasAddGrad";
|
||||||
constexpr auto kCos = "Cos";
|
constexpr auto kCos = "Cos";
|
||||||
constexpr auto kAbs = "Abs";
|
constexpr auto kAbs = "Abs";
|
||||||
|
constexpr auto kTrunc = "Trunc";
|
||||||
constexpr auto kSquare = "Square";
|
constexpr auto kSquare = "Square";
|
||||||
|
|
||||||
// Arrays
|
// Arrays
|
||||||
|
@ -123,6 +124,7 @@ inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_l
|
||||||
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||||
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
||||||
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
||||||
|
inline const PrimitivePtr kPrimTrunc = std::make_shared<Primitive>(kTrunc);
|
||||||
|
|
||||||
// Comparisons
|
// Comparisons
|
||||||
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/trunc.h"
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr TruncInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
auto x_shape = input_args[0]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(x_shape);
|
||||||
|
auto output_shape = x_shape->cast<abstract::ShapePtr>();
|
||||||
|
return output_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr TruncInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||||
|
std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
|
||||||
|
auto input_type = input_args[0]->BuildType();
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_type, check_list, prim->name());
|
||||||
|
return input_type;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
return abstract::MakeAbstract(TruncInferShape(primitive, input_args), TruncInferType(primitive, input_args));
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Trunc, prim::kPrimTrunc, TruncInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_TRUNC_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_TRUNC_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
constexpr auto kNameTrunc = "Trunc";
|
||||||
|
|
||||||
|
class Trunc : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
Trunc() : PrimitiveC(kNameTrunc) { InitIOName({"input_x"}, {"output_y"}); }
|
||||||
|
~Trunc() = default;
|
||||||
|
MS_DECLARE_PARENT(Trunc, PrimitiveC);
|
||||||
|
};
|
||||||
|
|
||||||
|
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args);
|
||||||
|
|
||||||
|
using PrimTruncPtr = std::shared_ptr<Trunc>;
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_TRUNC_H_
|
|
@ -21,6 +21,7 @@ from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from .._grad.grad_base import bprop_getters
|
from .._grad.grad_base import bprop_getters
|
||||||
from .._grad.grad_math_ops import binop_grad_common
|
from .._grad.grad_math_ops import binop_grad_common
|
||||||
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
|
|
||||||
|
@ -96,3 +97,14 @@ def get_bprop_erfinv(self):
|
||||||
return (dx,)
|
return (dx,)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.Trunc)
|
||||||
|
def get_bprop_trunc(self):
|
||||||
|
"""Grad definition for `Trunc` operation."""
|
||||||
|
|
||||||
|
def bprop(input_x, output_y, dout):
|
||||||
|
bc_x = zeros_like(input_x)
|
||||||
|
return (bc_x,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
|
@ -418,3 +418,4 @@ from .hshrink import _hshrink_tbe
|
||||||
from .hshrink_grad import _hshrink_grad_tbe
|
from .hshrink_grad import _hshrink_grad_tbe
|
||||||
from .new_im2col import _new_im2col_tbe
|
from .new_im2col import _new_im2col_tbe
|
||||||
from .non_zero_ds import _non_zero_ds_tbe
|
from .non_zero_ds import _non_zero_ds_tbe
|
||||||
|
from .trunc import _trunc_tbe
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Trunc op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
trunc_op_info = TBERegOp("Trunc") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("trunc.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("trunc") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "input_x", False, "required", "all") \
|
||||||
|
.output(0, "output_y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(trunc_op_info)
|
||||||
|
def _trunc_tbe():
|
||||||
|
"""Trunc TBE register"""
|
||||||
|
return
|
|
@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
||||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
||||||
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
|
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
|
||||||
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex)
|
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc,)
|
||||||
|
|
||||||
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
||||||
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
||||||
|
@ -483,6 +483,7 @@ __all__ = [
|
||||||
"Conj",
|
"Conj",
|
||||||
"Real",
|
"Real",
|
||||||
"Imag",
|
"Imag",
|
||||||
|
"Trunc",
|
||||||
"Complex"
|
"Complex"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -5516,3 +5516,31 @@ class Imag(PrimitiveWithInfer):
|
||||||
elif input_dtype == mstype.tensor_type(mstype.complex128):
|
elif input_dtype == mstype.tensor_type(mstype.complex128):
|
||||||
output_dtype = mstype.float64
|
output_dtype = mstype.float64
|
||||||
return output_dtype
|
return output_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class Trunc(Primitive):
|
||||||
|
"""
|
||||||
|
Returns a new tensor with the truncated integer values of the elements of input.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - Input_x is a tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the same shape and data type as the input.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `input_x` is not a Tensor.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> trunc = ops.Trunc()
|
||||||
|
>>> output = trunc(Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32))
|
||||||
|
>>> print(output)
|
||||||
|
[ 3. 0. 0. -3.]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize Trunc"""
|
||||||
|
|
|
@ -469,6 +469,10 @@ raise_set = [
|
||||||
('AssignAdd_Error', {
|
('AssignAdd_Error', {
|
||||||
'block': (P.AssignAdd(), {'exception': ValueError}),
|
'block': (P.AssignAdd(), {'exception': ValueError}),
|
||||||
'desc_inputs': [[1]]}),
|
'desc_inputs': [[1]]}),
|
||||||
|
('Trunc', {
|
||||||
|
'block': P.Trunc(),
|
||||||
|
'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))],
|
||||||
|
'skip': ['backward']}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue