!8396 Adapt UnsortedSegmentMax for Ascend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-11-10 19:09:35 +08:00 committed by Gitee
commit e62285e8e7
4 changed files with 92 additions and 21 deletions

View File

@ -675,10 +675,10 @@ def get_bprop_diag_part(self):
return bprop
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
def _gather_drop_negatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
"""Helper function for unsorted segment ops."""
maximum = P.Maximum()
gather = P.GatherV2()
@ -703,12 +703,32 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
"""Gradient for UnsortedSegmentMin or UnsortedSegmentMax"""
equal = P.Equal()
cast = P.Cast()
divide = P.RealDiv()
get_dtype = P.DType()
select = P.Select()
gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None)
is_selected = equal(x, gathered_outputs)
is_selected = logical_and(is_selected, is_positive)
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
segment_ids, num_segments)
weighted_grads = divide(dout, num_selected)
gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None,
zero_clipped_indices, is_positive)
zeros = zeros_like(gathered_grads)
return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@ -716,23 +736,20 @@ def get_bprop_unsorted_segment_sum(self):
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""
equal = P.Equal()
cast = P.Cast()
divide = P.RealDiv()
get_dtype = P.DType()
select = P.Select()
def bprop(x, segment_ids, num_segments, out, dout):
gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None)
is_selected = equal(x, gathered_outputs)
is_selected = logical_and(is_selected, is_positive)
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
segment_ids, num_segments)
weighted_grads = divide(dout, num_selected)
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
zero_clipped_indices, is_positive)
zeros = zeros_like(gathered_grads)
return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
return bprop
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
"""Generate bprop for UnsortedSegmentMax"""
def bprop(x, segment_ids, num_segments, out, dout):
return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
return bprop
@ -759,7 +776,7 @@ def get_bprop_unsorted_segment_prod(self):
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
prod_divided_by_x = gathered_prod / x
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices)
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices)
dx = gathered_grad * partial_derivative
return dx, zeros_like(segment_ids), zeros_like(num_segments)

View File

@ -272,6 +272,7 @@ from .reduce_all import _reduce_all_tbe
from .reduce_any import _reduce_any_tbe
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
from .unsorted_segment_min import _unsorted_segment_min_tbe
from .unsorted_segment_max import _unsorted_segment_max_tbe
from .asin import _asin_tbe
from .asin_grad import _asin_grad_tbe
from .asinh import _asinh_tbe

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UnsortedSegmentMax op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
unsorted_segment_max_op_info = TBERegOp("UnsortedSegmentMax") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("unsorted_segment_max_d.so") \
.compute_cost(10) \
.kernel_name("unsorted_segment_max_d") \
.partial_flag(True) \
.attr("num_segments", "required", "int", "all") \
.input(0, "data", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(unsorted_segment_max_op_info)
def _unsorted_segment_max_tbe():
"""UnsortedSegmentMax TBE register"""
return

View File

@ -1648,6 +1648,11 @@ test_case_nn_ops = [
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('UnsortedSegmentMax', {
'block': P.UnsortedSegmentMax(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('UnsortedSegmentProd', {
'block': P.UnsortedSegmentProd(),
'desc_const': [4],