!1863 add op broadcast_to

Merge pull request !1863 from zhaozhenlong/op/broadcast-to-d-vm
This commit is contained in:
mindspore-ci-bot 2020-06-05 19:05:24 +08:00 committed by Gitee
commit 53df649737
8 changed files with 108 additions and 9 deletions

View File

@ -105,7 +105,8 @@ static std::map<string, string> tbe_func_adapter_map = {
{"unsorted_segment_min", "unsorted_segment_min_d"}, {"unsorted_segment_min", "unsorted_segment_min_d"},
{"reduce_prod", "reduce_prod_d"}, {"reduce_prod", "reduce_prod_d"},
{"a_cos", "acos"}, {"a_cos", "acos"},
{"a_cos_grad", "acos_grad"}}; {"a_cos_grad", "acos_grad"},
{"broadcast_to", "broadcast_to_d"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) { void TbeAdapter::NormalizeFuncName(std::string *func_name) {
if (func_name == nullptr) { if (func_name == nullptr) {
@ -139,7 +140,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
*func_name = name_tmp; *func_name = name_tmp;
auto iter = tbe_func_adapter_map.find(*func_name); auto iter = tbe_func_adapter_map.find(*func_name);
if (iter != tbe_func_adapter_map.end()) { if (iter != tbe_func_adapter_map.end()) {
MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second; MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
*func_name = iter->second; *func_name = iter->second;
} }
} }

View File

@ -18,6 +18,7 @@
from .. import operations as P from .. import operations as P
from ..operations import _grad_ops as G from ..operations import _grad_ops as G
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args
from .. import functional as F from .. import functional as F
from .grad_base import bprop_getters from .grad_base import bprop_getters
from ..primitive import constexpr from ..primitive import constexpr
@ -580,3 +581,17 @@ def get_bprop_batch_to_space_nd(self):
dx = batch_to_space_nd_grad(dout) dx = batch_to_space_nd_grad(dout)
return (dx,) return (dx,)
return bprop return bprop
@bprop_getters.register(P.BroadcastTo)
def get_bprop_broadcast_to(self):
"""Generate bprop for BroadcastTo"""
reduce_keep_dim = P.ReduceSum(keep_dims=True)
broadcast_shape = self.shape
def bprop(x, out, dout):
x_shape = shape_op(x)
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
reduced_grad = reduce_keep_dim(dout, reduction_axes)
dx = reshape(reduced_grad, x_shape)
return (dx,)
return bprop

View File

@ -217,9 +217,9 @@ from .bessel_i0e import _bessel_i0e_tbe
from .bessel_i1e import _bessel_i1e_tbe from .bessel_i1e import _bessel_i1e_tbe
from .batch_to_space_nd import _batch_to_space_nd_tbe from .batch_to_space_nd import _batch_to_space_nd_tbe
from .space_to_batch_nd import _space_to_batch_nd_tbe from .space_to_batch_nd import _space_to_batch_nd_tbe
from .bitwise_and import bitwise_and_op_info from .bitwise_and import _bitwise_and_tbe
from .bitwise_or import bitwise_or_op_info from .bitwise_or import _bitwise_or_tbe
from .bitwise_xor import bitwise_xor_op_info from .bitwise_xor import _bitwise_xor_tbe
from .reduce_all import _reduce_all_tbe from .reduce_all import _reduce_all_tbe
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
from .unsorted_segment_min import _unsorted_segment_min_tbe from .unsorted_segment_min import _unsorted_segment_min_tbe
@ -238,3 +238,4 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe from .confusion_matrix import _confusion_matrix_tbe
from .broadcast_to import _broadcast_to_tbe

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BroadcastTo op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
broadcast_to_op_info = TBERegOp("BroadcastTo") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("broadcast_to_d.so") \
.compute_cost(10) \
.kernel_name("broadcast_to_d") \
.partial_flag(True) \
.attr("shape", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
.get_op_info()
@op_info_register(broadcast_to_op_info)
def _broadcast_to_tbe():
"""BroadcastTo TBE register"""
return

View File

@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Squeeze, StridedSlice, Tile, Squeeze, StridedSlice, Tile,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND) SpaceToBatchND, BatchToSpaceND, BroadcastTo)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset, _MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, _VirtualDiv, _GetTensorSlice,
@ -289,7 +289,8 @@ __all__ = [
"Atan", "Atan",
"Atanh", "Atanh",
"BasicLSTMCell", "BasicLSTMCell",
"ConfusionMatrix" "ConfusionMatrix",
"BroadcastTo"
] ]
__all__.extend(_quant_ops.__all__) __all__.extend(_quant_ops.__all__)

View File

@ -2738,3 +2738,40 @@ class BatchToSpaceND(PrimitiveWithInfer):
f'block_shape_prod {block_shape_prod}') f'block_shape_prod {block_shape_prod}')
out_shape[0] = out_shape[0] // block_shape_prod out_shape[0] = out_shape[0] // block_shape_prod
return out_shape return out_shape
class BroadcastTo(PrimitiveWithInfer):
"""
Broadcasts input tensor to a given shape.
Args:
shape (tuple): The target shape to broadcast.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, with the given `shape` and the same data type as `input_x`.
Examples:
>>> shape = (2, 3)
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> broadcast_to = P.BroadcastTo(shape)
>>> broadcast_to(input_x)
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
"""
@prim_attr_register
def __init__(self, shape):
"""Init BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name)
for i in shape:
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
self.shape = shape
def infer_shape(self, x_shape):
return self.shape
def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
return x_dtype

View File

@ -1396,6 +1396,10 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)), 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
Tensor(np.array([0, 1, 1]).astype(np.int32))], Tensor(np.array([0, 1, 1]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
('BroadcastTo', {
'block': P.BroadcastTo((2,3)),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
] ]
test_case_other_ops = [ test_case_other_ops = [