!1863 add op broadcast_to

Merge pull request !1863 from zhaozhenlong/op/broadcast-to-d-vm
mindspore-ci-bot 2020-06-05 19:05:24 +08:00 committed by Gitee
commit 53df649737
8 changed files with 108 additions and 9 deletions


@@ -105,7 +105,8 @@ static std::map<string, string> tbe_func_adapter_map = {
  {"unsorted_segment_min", "unsorted_segment_min_d"},
  {"reduce_prod", "reduce_prod_d"},
  {"a_cos", "acos"},
-  {"a_cos_grad", "acos_grad"}};
+  {"a_cos_grad", "acos_grad"},
+  {"broadcast_to", "broadcast_to_d"}};

void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  if (func_name == nullptr) {

@@ -139,7 +139,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  *func_name = name_tmp;
  auto iter = tbe_func_adapter_map.find(*func_name);
  if (iter != tbe_func_adapter_map.end()) {
-    MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
+    MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
    *func_name = iter->second;
  }
}
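
Note: NormalizeFuncName lowers the CamelCase primitive name to snake_case before consulting tbe_func_adapter_map, which is why the new entry is needed: the TBE kernel is named broadcast_to_d, not broadcast_to. A minimal Python sketch of the lookup (an illustration of the logic, not the actual C++ implementation):

    import re

    tbe_func_adapter_map = {
        "a_cos": "acos",
        "a_cos_grad": "acos_grad",
        "broadcast_to": "broadcast_to_d",  # the entry this commit adds
    }

    def normalize_func_name(name):
        # CamelCase -> snake_case, e.g. "BroadcastTo" -> "broadcast_to"
        snake = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
        # Remap when the TBE kernel name differs from the framework op name
        return tbe_func_adapter_map.get(snake, snake)

    assert normalize_func_name("BroadcastTo") == "broadcast_to_d"
    assert normalize_func_name("ACosGrad") == "acos_grad"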


@@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
        else:
            quant_fun = P.FakeQuantPerLayer
            ema_fun = P.FakeQuantMinMaxPerLayerUpdate
        self.fake_quant = quant_fun(num_bits=self.num_bits,
                                    ema=self.ema,
                                    ema_decay=self.ema_decay,

@@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
                                  0, self.out_channels)]).astype(np.float32)
        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
        if per_channel:
            quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
        else:


@@ -18,6 +18,7 @@
from .. import operations as P
from ..operations import _grad_ops as G
from ..composite.multitype_ops.zeros_like_impl import zeros_like
+from ..functional import broadcast_gradient_args
from .. import functional as F
from .grad_base import bprop_getters
from ..primitive import constexpr

@@ -580,3 +581,17 @@ def get_bprop_batch_to_space_nd(self):
        dx = batch_to_space_nd_grad(dout)
        return (dx,)
    return bprop
+
+
+@bprop_getters.register(P.BroadcastTo)
+def get_bprop_broadcast_to(self):
+    """Generate bprop for BroadcastTo"""
+    reduce_keep_dim = P.ReduceSum(keep_dims=True)
+    broadcast_shape = self.shape
+    def bprop(x, out, dout):
+        x_shape = shape_op(x)
+        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
+        reduced_grad = reduce_keep_dim(dout, reduction_axes)
+        dx = reshape(reduced_grad, x_shape)
+        return (dx,)
+    return bprop
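
The bprop applies the standard rule for broadcasting gradients: sum dout over the axes along which x was replicated, then reshape back to x's shape. A standalone NumPy sketch of the same computation (the reduction_axes logic approximates what broadcast_gradient_args returns for this case):

    import numpy as np

    def broadcast_to_bprop(x, dout, broadcast_shape):
        """Gradient of a broadcast: reduce dout over the broadcast axes."""
        x_shape = x.shape
        ndiff = len(broadcast_shape) - len(x_shape)
        padded = (1,) * ndiff + tuple(x_shape)
        # Axes where x had size 1 (or no dimension) but the output did not
        reduction_axes = tuple(i for i, (b, s) in enumerate(zip(broadcast_shape, padded))
                               if s == 1 and b > 1)
        reduced = dout.sum(axis=reduction_axes, keepdims=True)
        return reduced.reshape(x_shape)

    x = np.array([1.0, 2.0, 3.0], np.float32)
    dout = np.ones((2, 3), np.float32)
    print(broadcast_to_bprop(x, dout, (2, 3)))  # [2. 2. 2.]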


@@ -217,9 +217,9 @@ from .bessel_i0e import _bessel_i0e_tbe
from .bessel_i1e import _bessel_i1e_tbe
from .batch_to_space_nd import _batch_to_space_nd_tbe
from .space_to_batch_nd import _space_to_batch_nd_tbe
-from .bitwise_and import bitwise_and_op_info
-from .bitwise_or import bitwise_or_op_info
-from .bitwise_xor import bitwise_xor_op_info
+from .bitwise_and import _bitwise_and_tbe
+from .bitwise_or import _bitwise_or_tbe
+from .bitwise_xor import _bitwise_xor_tbe
from .reduce_all import _reduce_all_tbe
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
from .unsorted_segment_min import _unsorted_segment_min_tbe

@@ -238,3 +238,4 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe
+from .broadcast_to import _broadcast_to_tbe


@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BroadcastTo op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

broadcast_to_op_info = TBERegOp("BroadcastTo") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("broadcast_to_d.so") \
    .compute_cost(10) \
    .kernel_name("broadcast_to_d") \
    .partial_flag(True) \
    .attr("shape", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .get_op_info()


@op_info_register(broadcast_to_op_info)
def _broadcast_to_tbe():
    """BroadcastTo TBE register"""
    return
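
Importing _broadcast_to_tbe from the package __init__.py (the change above) is what makes the kernel visible: op_info_register records the op info as a side effect of import, letting the Ascend backend map BroadcastTo onto the broadcast_to_d kernel. A toy sketch of this decorator pattern (hypothetical registry names, not MindSpore's actual implementation):

    _op_registry = {}

    def toy_op_info_register(op_info):
        """Hypothetical stand-in for op_info_register."""
        def decorator(func):
            # Registration happens as a side effect of importing the module
            _op_registry[op_info["kernel_name"]] = op_info
            return func
        return decorator

    @toy_op_info_register({"op_name": "BroadcastTo", "kernel_name": "broadcast_to_d"})
    def _broadcast_to_stub():
        return

    assert "broadcast_to_d" in _op_registry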


@@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Squeeze, StridedSlice, Tile,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,

@@ -289,7 +289,8 @@ __all__ = [
    "Atan",
    "Atanh",
    "BasicLSTMCell",
-    "ConfusionMatrix"
+    "ConfusionMatrix",
+    "BroadcastTo"
]

__all__.extend(_quant_ops.__all__)


@@ -2738,3 +2738,40 @@ class BatchToSpaceND(PrimitiveWithInfer):
                             f'block_shape_prod {block_shape_prod}')
        out_shape[0] = out_shape[0] // block_shape_prod
        return out_shape
+
+
+class BroadcastTo(PrimitiveWithInfer):
+    """
+    Broadcasts input tensor to a given shape.
+
+    Args:
+        shape (tuple): The target shape to broadcast.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the given `shape` and the same data type as `input_x`.
+
+    Examples:
+        >>> shape = (2, 3)
+        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> broadcast_to = P.BroadcastTo(shape)
+        >>> broadcast_to(input_x)
+        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
+    """
+
+    @prim_attr_register
+    def __init__(self, shape):
+        """Init BroadcastTo"""
+        validator.check_value_type("shape", shape, (tuple), self.name)
+        for i in shape:
+            validator.check_integer("shape element", i, 0, Rel.GT, self.name)
+        self.shape = shape
+
+    def infer_shape(self, x_shape):
+        return self.shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
+        return x_dtype
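
The primitive follows NumPy broadcasting rules, so its output can be sanity-checked against np.broadcast_to. A usage sketch (assumes a build with the TBE kernel registered above, e.g. an Ascend backend):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = np.array([1, 2, 3], np.float32)
    out = P.BroadcastTo((2, 3))(Tensor(x))
    # The length-3 row is repeated twice to fill the (2, 3) target shape
    np.testing.assert_array_equal(out.asnumpy(), np.broadcast_to(x, (2, 3)))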


@@ -1396,6 +1396,10 @@ test_case_array_ops = [
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
                        Tensor(np.array([0, 1, 1]).astype(np.int32))],
        'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
+    ('BroadcastTo', {
+        'block': P.BroadcastTo((2, 3)),
+        'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
]
test_case_other_ops = [