diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index d3fd00d401f..5d8b9f5e40a 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -105,7 +105,8 @@ static std::map<std::string, std::string> tbe_func_adapter_map = {
   {"unsorted_segment_min", "unsorted_segment_min_d"},
   {"reduce_prod", "reduce_prod_d"},
   {"a_cos", "acos"},
-  {"a_cos_grad", "acos_grad"}};
+  {"a_cos_grad", "acos_grad"},
+  {"broadcast_to", "broadcast_to_d"}};
 
 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {
@@ -139,7 +140,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   *func_name = name_tmp;
   auto iter = tbe_func_adapter_map.find(*func_name);
   if (iter != tbe_func_adapter_map.end()) {
-    MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
+    MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op " << iter->second;
     *func_name = iter->second;
   }
 }
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 77fda2162e8..a0b2e5bdb2a 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
         else:
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate
-        
+
         self.fake_quant = quant_fun(num_bits=self.num_bits,
                                     ema=self.ema,
                                     ema_decay=self.ema_decay,
@@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
                                0, self.out_channels)]).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-        
+
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
         else:
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index 1861a4d7265..9ec9b0f0804 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -18,6 +18,7 @@
 from .. import operations as P
 from ..operations import _grad_ops as G
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
+from ..functional import broadcast_gradient_args
 from .. import functional as F
 from .grad_base import bprop_getters
 from ..primitive import constexpr
@@ -580,3 +581,17 @@ def get_bprop_batch_to_space_nd(self):
         dx = batch_to_space_nd_grad(dout)
         return (dx,)
     return bprop
+
+@bprop_getters.register(P.BroadcastTo)
+def get_bprop_broadcast_to(self):
+    """Generate bprop for BroadcastTo"""
+    reduce_keep_dim = P.ReduceSum(keep_dims=True)
+    broadcast_shape = self.shape
+
+    def bprop(x, out, dout):
+        x_shape = shape_op(x)
+        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
+        reduced_grad = reduce_keep_dim(dout, reduction_axes)
+        dx = reshape(reduced_grad, x_shape)
+        return (dx,)
+    return bprop
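For context on the reduction in `get_bprop_broadcast_to` above: broadcasting duplicates each input element along the broadcast axes, so the gradient of each input element is the sum of the gradients of all its copies. Here is a minimal NumPy sketch of that same computation (a hypothetical standalone helper for illustration, not code from this patch):

```python
import numpy as np

def broadcast_to_bprop(x_shape, broadcast_shape, dout):
    """Reduce dout (shaped like the broadcast output) back to x_shape."""
    # Axes prepended by broadcasting behave as if x had size 1 there.
    pad = len(broadcast_shape) - len(x_shape)
    padded = (1,) * pad + tuple(x_shape)
    # Every axis where x had size 1 but the target did not was duplicated,
    # so the gradient is summed over exactly those axes.
    axes = tuple(i for i, (xs, bs) in enumerate(zip(padded, broadcast_shape))
                 if xs == 1 and bs > 1)
    return dout.sum(axis=axes, keepdims=True).reshape(x_shape)

dout = np.ones((2, 3), dtype=np.float32)        # incoming output gradient
print(broadcast_to_bprop((3,), (2, 3), dout))   # [2. 2. 2.]
```

The registered bprop does the same thing, except that `broadcast_gradient_args` computes the reduction axes for it.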
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 10947a535af..0d796eac463 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -217,9 +217,9 @@ from .bessel_i0e import _bessel_i0e_tbe
 from .bessel_i1e import _bessel_i1e_tbe
 from .batch_to_space_nd import _batch_to_space_nd_tbe
 from .space_to_batch_nd import _space_to_batch_nd_tbe
-from .bitwise_and import bitwise_and_op_info
-from .bitwise_or import bitwise_or_op_info
-from .bitwise_xor import bitwise_xor_op_info
+from .bitwise_and import _bitwise_and_tbe
+from .bitwise_or import _bitwise_or_tbe
+from .bitwise_xor import _bitwise_xor_tbe
 from .reduce_all import _reduce_all_tbe
 from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
 from .unsorted_segment_min import _unsorted_segment_min_tbe
@@ -238,3 +238,4 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
 from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
 from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
 from .confusion_matrix import _confusion_matrix_tbe
+from .broadcast_to import _broadcast_to_tbe
diff --git a/mindspore/ops/_op_impl/tbe/broadcast_to.py b/mindspore/ops/_op_impl/tbe/broadcast_to.py
new file mode 100644
index 00000000000..5d4b642017a
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/broadcast_to.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""BroadcastTo op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+broadcast_to_op_info = TBERegOp("BroadcastTo") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("broadcast_to_d.so") \
+    .compute_cost(10) \
+    .kernel_name("broadcast_to_d") \
+    .partial_flag(True) \
+    .attr("shape", "required", "listInt", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(broadcast_to_op_info)
+def _broadcast_to_tbe():
+    """BroadcastTo TBE register"""
+    return
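The `__init__.py` rename above brings the bitwise imports in line with the convention every other entry follows: import the `@op_info_register`-decorated `_*_tbe` function rather than the raw op-info object. A simplified sketch of the decorator-registration pattern these imports rely on (hypothetical registry and names, not MindSpore internals), where registration happens as a side effect of importing the decorated module:

```python
# Simplified sketch of decorator-based op registration; the real
# op_info_register lives in mindspore.ops.op_info_register.
_OP_INFO_REGISTRY = {}

def op_info_register(op_info):
    """Record op_info when the decorated function's module is imported."""
    def deco(func):
        _OP_INFO_REGISTRY[op_info["op_name"]] = op_info  # import-time side effect
        return func
    return deco

@op_info_register({"op_name": "BroadcastTo", "kernel_name": "broadcast_to_d"})
def _broadcast_to_tbe():
    """BroadcastTo TBE register"""
    return

assert "BroadcastTo" in _OP_INFO_REGISTRY
```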
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 996df7c2856..2176b1e38e2 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Squeeze, StridedSlice, Tile, Transpose, TruncatedNormal, TupleToArray,
                         UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace,
                         SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
@@ -289,7 +289,8 @@ __all__ = [
     "Atan",
     "Atanh",
     "BasicLSTMCell",
-    "ConfusionMatrix"
+    "ConfusionMatrix",
+    "BroadcastTo"
 ]
 
 __all__.extend(_quant_ops.__all__)
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 694c4f1d74b..6b1794c2439 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2738,3 +2738,40 @@
                              f'block_shape_prod {block_shape_prod}')
         out_shape[0] = out_shape[0] // block_shape_prod
         return out_shape
+
+
+class BroadcastTo(PrimitiveWithInfer):
+    """
+    Broadcasts the input tensor to a given shape.
+
+    Args:
+        shape (tuple): The target shape to broadcast to.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the given `shape` and the same data type as `input_x`.
+
+    Examples:
+        >>> shape = (2, 3)
+        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> broadcast_to = P.BroadcastTo(shape)
+        >>> broadcast_to(input_x)
+        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
+    """
+
+    @prim_attr_register
+    def __init__(self, shape):
+        """Init BroadcastTo"""
+        validator.check_value_type("shape", shape, (tuple), self.name)
+        for i in shape:
+            validator.check_integer("shape element", i, 0, Rel.GT, self.name)
+        self.shape = shape
+
+    def infer_shape(self, x_shape):
+        return self.shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
+        return x_dtype
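A short end-to-end usage sketch of the new primitive, following the docstring example above (hedged: this assumes a build with the patch applied and an Ascend backend, since the TBE kernel added here is Ascend-only):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class BroadcastNet(nn.Cell):
    """Wraps P.BroadcastTo so it can run inside a graph."""
    def __init__(self, shape):
        super(BroadcastNet, self).__init__()
        self.broadcast_to = P.BroadcastTo(shape)

    def construct(self, x):
        return self.broadcast_to(x)

net = BroadcastNet((2, 3))
x = Tensor(np.array([1, 2, 3]).astype(np.float32))
out = net(x)  # [[1. 2. 3.], [1. 2. 3.]]
# Per the bprop registered earlier, an all-ones output gradient
# would reduce back to [2. 2. 2.] for x.
```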
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index f9b7ee64831..bd95cef5305 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -1396,6 +1396,10 @@ test_case_array_ops = [
         'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
                         Tensor(np.array([0, 1, 1]).astype(np.int32))],
         'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
+    ('BroadcastTo', {
+        'block': P.BroadcastTo((2, 3)),
+        'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
 ]
 
 test_case_other_ops = [
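For reference, the same case as a standalone numeric check (a hypothetical pytest function, not part of this patch): in this harness `desc_bprop` supplies the incoming output gradient, and summing it over the broadcast axis gives the expected input gradient.

```python
import numpy as np

def test_broadcast_to_grad_reduction():
    x = np.array([1, 2, 3], dtype=np.float32)
    out = np.broadcast_to(x, (2, 3))     # forward reference result
    assert out.shape == (2, 3)
    dout = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)  # desc_bprop
    dx = dout.sum(axis=0)                # reduce over broadcast axis 0
    np.testing.assert_allclose(dx, np.array([2., 4., 6.], dtype=np.float32))
```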