forked from mindspore-Ecosystem/mindspore
!1863 add op broadcast_to
Merge pull request !1863 from zhaozhenlong/op/broadcast-to-d-vm
This commit is contained in:
commit
53df649737
|
@ -105,7 +105,8 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"unsorted_segment_min", "unsorted_segment_min_d"},
|
{"unsorted_segment_min", "unsorted_segment_min_d"},
|
||||||
{"reduce_prod", "reduce_prod_d"},
|
{"reduce_prod", "reduce_prod_d"},
|
||||||
{"a_cos", "acos"},
|
{"a_cos", "acos"},
|
||||||
{"a_cos_grad", "acos_grad"}};
|
{"a_cos_grad", "acos_grad"},
|
||||||
|
{"broadcast_to", "broadcast_to_d"}};
|
||||||
|
|
||||||
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
||||||
if (func_name == nullptr) {
|
if (func_name == nullptr) {
|
||||||
|
@ -139,7 +140,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
||||||
*func_name = name_tmp;
|
*func_name = name_tmp;
|
||||||
auto iter = tbe_func_adapter_map.find(*func_name);
|
auto iter = tbe_func_adapter_map.find(*func_name);
|
||||||
if (iter != tbe_func_adapter_map.end()) {
|
if (iter != tbe_func_adapter_map.end()) {
|
||||||
MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
|
MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
|
||||||
*func_name = iter->second;
|
*func_name = iter->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
from ..functional import broadcast_gradient_args
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
|
@ -580,3 +581,17 @@ def get_bprop_batch_to_space_nd(self):
|
||||||
dx = batch_to_space_nd_grad(dout)
|
dx = batch_to_space_nd_grad(dout)
|
||||||
return (dx,)
|
return (dx,)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BroadcastTo)
|
||||||
|
def get_bprop_broadcast_to(self):
|
||||||
|
"""Generate bprop for BroadcastTo"""
|
||||||
|
reduce_keep_dim = P.ReduceSum(keep_dims=True)
|
||||||
|
broadcast_shape = self.shape
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
x_shape = shape_op(x)
|
||||||
|
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
|
||||||
|
reduced_grad = reduce_keep_dim(dout, reduction_axes)
|
||||||
|
dx = reshape(reduced_grad, x_shape)
|
||||||
|
return (dx,)
|
||||||
|
return bprop
|
||||||
|
|
|
@ -217,9 +217,9 @@ from .bessel_i0e import _bessel_i0e_tbe
|
||||||
from .bessel_i1e import _bessel_i1e_tbe
|
from .bessel_i1e import _bessel_i1e_tbe
|
||||||
from .batch_to_space_nd import _batch_to_space_nd_tbe
|
from .batch_to_space_nd import _batch_to_space_nd_tbe
|
||||||
from .space_to_batch_nd import _space_to_batch_nd_tbe
|
from .space_to_batch_nd import _space_to_batch_nd_tbe
|
||||||
from .bitwise_and import bitwise_and_op_info
|
from .bitwise_and import _bitwise_and_tbe
|
||||||
from .bitwise_or import bitwise_or_op_info
|
from .bitwise_or import _bitwise_or_tbe
|
||||||
from .bitwise_xor import bitwise_xor_op_info
|
from .bitwise_xor import _bitwise_xor_tbe
|
||||||
from .reduce_all import _reduce_all_tbe
|
from .reduce_all import _reduce_all_tbe
|
||||||
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
|
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
|
||||||
from .unsorted_segment_min import _unsorted_segment_min_tbe
|
from .unsorted_segment_min import _unsorted_segment_min_tbe
|
||||||
|
@ -238,3 +238,4 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
|
||||||
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
|
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
|
||||||
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
|
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
|
||||||
from .confusion_matrix import _confusion_matrix_tbe
|
from .confusion_matrix import _confusion_matrix_tbe
|
||||||
|
from .broadcast_to import _broadcast_to_tbe
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BroadcastTo op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
broadcast_to_op_info = TBERegOp("BroadcastTo") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("broadcast_to_d.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("broadcast_to_d") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("shape", "required", "listInt", "all") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(broadcast_to_op_info)
|
||||||
|
def _broadcast_to_tbe():
|
||||||
|
"""BroadcastTo TBE register"""
|
||||||
|
return
|
|
@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Squeeze, StridedSlice, Tile,
|
Squeeze, StridedSlice, Tile,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
SpaceToBatchND, BatchToSpaceND)
|
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
|
||||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||||
_VirtualDiv, _GetTensorSlice,
|
_VirtualDiv, _GetTensorSlice,
|
||||||
|
@ -289,7 +289,8 @@ __all__ = [
|
||||||
"Atan",
|
"Atan",
|
||||||
"Atanh",
|
"Atanh",
|
||||||
"BasicLSTMCell",
|
"BasicLSTMCell",
|
||||||
"ConfusionMatrix"
|
"ConfusionMatrix",
|
||||||
|
"BroadcastTo"
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.extend(_quant_ops.__all__)
|
__all__.extend(_quant_ops.__all__)
|
||||||
|
|
|
@ -2738,3 +2738,40 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
||||||
f'block_shape_prod {block_shape_prod}')
|
f'block_shape_prod {block_shape_prod}')
|
||||||
out_shape[0] = out_shape[0] // block_shape_prod
|
out_shape[0] = out_shape[0] // block_shape_prod
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastTo(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Broadcasts input tensor to a given shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape (tuple): The target shape to broadcast.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the given `shape` and the same data type as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (2, 3)
|
||||||
|
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
|
||||||
|
>>> broadcast_to = P.BroadcastTo(shape)
|
||||||
|
>>> broadcast_to(input_x)
|
||||||
|
[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, shape):
|
||||||
|
"""Init BroadcastTo"""
|
||||||
|
validator.check_value_type("shape", shape, (tuple), self.name)
|
||||||
|
for i in shape:
|
||||||
|
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
return self.shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
|
@ -1396,6 +1396,10 @@ test_case_array_ops = [
|
||||||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
|
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
|
||||||
Tensor(np.array([0, 1, 1]).astype(np.int32))],
|
Tensor(np.array([0, 1, 1]).astype(np.int32))],
|
||||||
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
|
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
|
||||||
|
('BroadcastTo', {
|
||||||
|
'block': P.BroadcastTo((2,3)),
|
||||||
|
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
|
||||||
|
'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_other_ops = [
|
test_case_other_ops = [
|
||||||
|
|
Loading…
Reference in New Issue