!1939 Complete vm ops for DataFormatDimMap and HistogramFixedWidthD

Merge pull request !1939 from lihongkang/lhk_master
This commit is contained in:
mindspore-ci-bot 2020-06-12 11:53:21 +08:00 committed by Gitee
commit ac5878b4b7
8 changed files with 180 additions and 4 deletions

View File

@ -111,6 +111,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"reduce_prod", "reduce_prod_d"},
{"a_cos", "acos"},
{"a_cos_grad", "acos_grad"},
{"histogram_fixed_width", "histogram_fixed_width_d"},
{"broadcast_to", "broadcast_to_d"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) {

View File

@ -249,3 +249,5 @@ from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe
from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe
from .lamb_next_right import _lamb_next_right_tbe
from .sparse_gather_v2 import _sparse_gather_v2_tbe
from .data_format_dim_map import _data_format_dim_map_tbe
from .histogram_fixed_width import _histogram_fixed_width_tbe

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DataFormatDimMap op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
data_format_dim_map_op_info = TBERegOp("DataFormatDimMap") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("data_format_dim_map.so") \
.compute_cost(10) \
.kernel_name("data_format_dim_map") \
.partial_flag(True) \
.attr("dst_format", "optional", "str", "all") \
.attr("src_format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(data_format_dim_map_op_info)
def _data_format_dim_map_tbe():
"""DataFormatDimMap TBE register"""
return

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HistogramFixedWidth op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
histogram_fixed_width_op_info = TBERegOp("HistogramFixedWidth") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("histogram_fixed_width_d.so") \
.compute_cost(10) \
.kernel_name("histogram_fixed_width_d") \
.partial_flag(True) \
.attr("nbins", "required", "int", "all") \
.attr("dtype", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "range", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(histogram_fixed_width_op_info)
def _histogram_fixed_width_tbe():
"""HistogramFixedWidth TBE register"""
return

View File

@ -49,7 +49,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum,
Reciprocal, CumSum, HistogramFixedWidth,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh)
@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
LogSoftmax,
MaxPool,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
@ -207,6 +207,7 @@ __all__ = [
'ScatterNd',
'ScatterMax',
'ResizeNearestNeighbor',
'HistogramFixedWidth',
'Pad',
'MirrorPad',
'GatherNd',
@ -298,7 +299,8 @@ __all__ = [
"BasicLSTMCell",
"ConfusionMatrix",
"BroadcastTo",
"Range"
"Range",
"DataFormatDimMap"
]
__all__.extend(_quant_ops.__all__)

View File

@ -1050,6 +1050,50 @@ class Expm1(PrimitiveWithInfer):
return x_type
class HistogramFixedWidth(PrimitiveWithInfer):
"""
Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal
width and determined by the arguments range and nbins.
Args:
dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
nbins (Tensor): Number of histogram bins, the type is int32.
Inputs:
- **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16.
- **range** (Tensor) - Must have the same type as x. Shape [2] Tensor of same dtype as x.
x <= range[0] will be mapped to hist[0], x >= range[1] will be mapped to hist[-1].
Outputs:
Tensor, the type is int32.
Examples:
>>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16)
>>> range = Tensor([0.0, 5.0], mindspore.float16)
>>> hist = P.HistogramFixedWidth(5)
>>> hist(x, range)
[2 1 1 0 2]
"""
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
def infer_shape(self, x_shape, range_shape):
return (self.nbins,)
def infer_dtype(self, x_dtype, range_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
valid_types = (mstype.float16, mstype.float32, mstype.int32)
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name)
y_dtype = mstype.int32
return y_dtype
class Log(PrimitiveWithInfer):
"""
Returns the natural logarithm of a tensor element-wise.

View File

@ -1613,6 +1613,45 @@ class L2Loss(PrimitiveWithInfer):
return x_type
class DataFormatDimMap(PrimitiveWithInfer):
"""
Returns the dimension index in the destination data format given the one in the source data format.
Args:
src_format (string): An optional value for source data format. Default: 'NHWC'.
dst_format (string): An optional value for destination data format. Default: 'NCHW'.
Inputs:
- **input_x** (Tensor) - A Tensor with each element as a dimension index in source data format.
Must be in the range [-4, 4). It's type is int32.
Outputs:
Tensor, has the same type as the `input_x`.
Examples:
>>> x = Tensor([0, 1, 2, 3], mindspore.int32)
>>> dfdm = P.DataFormatDimMap()
>>> dfdm(x)
[0 3 1 2]
"""
@prim_attr_register
def __init__(self, src_format='NHWC', dst_format='NCHW'):
valid_values = ['NHWC', 'NCHW']
self.src_format = validator.check_string("src_format", src_format, valid_values, self.name)
self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name)
valid_types = [mstype.int32]
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
return x_type
class SGD(PrimitiveWithInfer):
"""
Computes stochastic gradient descent (optionally with momentum).
@ -3735,7 +3774,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4*h_shape[1], Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
ct_shape = c_shape
ht_shape = h_shape
it_shape = h_shape

View File

@ -764,6 +764,11 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[24, 4, 13, 9], [1, 5, 10, 8]]).astype(np.int16))],
'desc_bprop': [],
'skip': ['backward']}),
('HistogramFixedWidth', {
'block': P.HistogramFixedWidth(5),
'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)],
'desc_bprop': [],
'skip': ['backward']}),
]
test_case_nn_ops = [
@ -1203,6 +1208,11 @@ test_case_nn_ops = [
Tensor([[0.5, 0.4], [0.6, 0.1]], mstype.float32), Tensor([1, 1], mstype.int32)],
'desc_bprop': [Tensor([[0.7, 0.2], [0.1, 0.07]], mstype.float32)],
'skip': ['backward']}),
('DataFormatDimMap', {
'block': P.DataFormatDimMap(),
'desc_inputs': [Tensor([0, 1, 2, 3], mstype.int32)],
'desc_bprop': [],
'skip': ['backward']}),
]
test_case_array_ops = [