forked from mindspore-Ecosystem/mindspore
[feat][assistant][I40FGD] add new Ascend operator Trunc
This commit is contained in:
parent
a0613df349
commit
785558cdd2
|
@ -64,6 +64,7 @@ constexpr auto kTile = "Tile";
|
|||
constexpr auto kBiasAddGrad = "BiasAddGrad";
|
||||
constexpr auto kCos = "Cos";
|
||||
constexpr auto kAbs = "Abs";
|
||||
constexpr auto kTrunc = "Trunc";
|
||||
constexpr auto kSquare = "Square";
|
||||
|
||||
// Arrays
|
||||
|
@ -123,6 +124,7 @@ inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_l
|
|||
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
||||
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
||||
inline const PrimitivePtr kPrimTrunc = std::make_shared<Primitive>(kTrunc);
|
||||
|
||||
// Comparisons
|
||||
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/trunc.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr TruncInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_shape = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x_shape);
|
||||
auto output_shape = x_shape->cast<abstract::ShapePtr>();
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
TypePtr TruncInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||
std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
|
||||
auto input_type = input_args[0]->BuildType();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_type, check_list, prim->name());
|
||||
return input_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(TruncInferShape(primitive, input_args), TruncInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Trunc, prim::kPrimTrunc, TruncInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TRUNC_H_
|
||||
#define MINDSPORE_CORE_OPS_TRUNC_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
constexpr auto kNameTrunc = "Trunc";
|
||||
|
||||
class Trunc : public PrimitiveC {
|
||||
public:
|
||||
Trunc() : PrimitiveC(kNameTrunc) { InitIOName({"input_x"}, {"output_y"}); }
|
||||
~Trunc() = default;
|
||||
MS_DECLARE_PARENT(Trunc, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr TruncInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
|
||||
using PrimTruncPtr = std::shared_ptr<Trunc>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TRUNC_H_
|
|
@ -21,6 +21,7 @@ from .. import functional as F
|
|||
from .. import operations as P
|
||||
from .._grad.grad_base import bprop_getters
|
||||
from .._grad.grad_math_ops import binop_grad_common
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations import _grad_ops as G
|
||||
from ..primitive import constexpr
|
||||
|
||||
|
@ -96,3 +97,14 @@ def get_bprop_erfinv(self):
|
|||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Trunc)
|
||||
def get_bprop_trunc(self):
|
||||
"""Grad definition for `Trunc` operation."""
|
||||
|
||||
def bprop(input_x, output_y, dout):
|
||||
bc_x = zeros_like(input_x)
|
||||
return (bc_x,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -418,3 +418,4 @@ from .hshrink import _hshrink_tbe
|
|||
from .hshrink_grad import _hshrink_grad_tbe
|
||||
from .new_im2col import _new_im2col_tbe
|
||||
from .non_zero_ds import _non_zero_ds_tbe
|
||||
from .trunc import _trunc_tbe
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Trunc op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
trunc_op_info = TBERegOp("Trunc") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("trunc.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("trunc") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.output(0, "output_y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(trunc_op_info)
|
||||
def _trunc_tbe():
|
||||
"""Trunc TBE register"""
|
||||
return
|
|
@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|||
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
||||
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
|
||||
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex)
|
||||
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc,)
|
||||
|
||||
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
|
||||
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
||||
|
@ -483,6 +483,7 @@ __all__ = [
|
|||
"Conj",
|
||||
"Real",
|
||||
"Imag",
|
||||
"Trunc",
|
||||
"Complex"
|
||||
]
|
||||
|
||||
|
|
|
@ -5516,3 +5516,31 @@ class Imag(PrimitiveWithInfer):
|
|||
elif input_dtype == mstype.tensor_type(mstype.complex128):
|
||||
output_dtype = mstype.float64
|
||||
return output_dtype
|
||||
|
||||
|
||||
class Trunc(Primitive):
|
||||
"""
|
||||
Returns a new tensor with the truncated integer values of the elements of input.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Input_x is a tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the same shape and data type as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> trunc = ops.Trunc()
|
||||
>>> output = trunc(Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32))
|
||||
>>> print(output)
|
||||
[ 3. 0. 0. -3.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Trunc"""
|
||||
|
|
|
@ -469,6 +469,10 @@ raise_set = [
|
|||
('AssignAdd_Error', {
|
||||
'block': (P.AssignAdd(), {'exception': ValueError}),
|
||||
'desc_inputs': [[1]]}),
|
||||
('Trunc', {
|
||||
'block': P.Trunc(),
|
||||
'desc_inputs': [Tensor(np.array([[1.1, 2.2, -4.1]], np.float32))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue