forked from mindspore-Ecosystem/mindspore
!8396 Adapt UnsortedSegmentMax for Ascend.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@c_34 Signed-off-by: @liangchenghui
This commit is contained in:
commit
e62285e8e7
|
@ -675,10 +675,10 @@ def get_bprop_diag_part(self):
|
|||
return bprop
|
||||
|
||||
|
||||
def _GatherDropNegatives(params,
|
||||
ids,
|
||||
zero_clipped_indices=None,
|
||||
is_positive=None):
|
||||
def _gather_drop_negatives(params,
|
||||
ids,
|
||||
zero_clipped_indices=None,
|
||||
is_positive=None):
|
||||
"""Helper function for unsorted segment ops."""
|
||||
maximum = P.Maximum()
|
||||
gather = P.GatherV2()
|
||||
|
@ -703,12 +703,32 @@ def _GatherDropNegatives(params,
|
|||
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
|
||||
|
||||
|
||||
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Shared backward rule of UnsortedSegmentMin and UnsortedSegmentMax.

    Args:
        x: Forward input of the segment op.
        segment_ids: Segment index tensor that was passed to the forward op.
        num_segments: Number of segments that was passed to the forward op.
        out: Forward output of the segment op.
        dout: Gradient flowing into `out`.

    Returns:
        A 3-tuple: the gradient for `x`, followed by `zeros_like` placeholders
        for `segment_ids` and `num_segments`.
    """
    # Gather the per-segment results with `segment_ids`; the helper also hands
    # back the clipped indices and the mask reused by the second gather below.
    out_per_elem, clipped_ids, valid_mask = _gather_drop_negatives(out, segment_ids, None, None)

    # An element receives gradient only if it equals the gathered result of its
    # segment and the helper's mask is set for it.
    eq_op = P.Equal()
    hit_mask = eq_op(x, out_per_elem)
    hit_mask = logical_and(hit_mask, valid_mask)

    # Count the selected elements of every segment (in the dtype of `dout`) so
    # that the incoming gradient is divided evenly among ties.
    cast_op = P.Cast()
    dtype_op = P.DType()
    hit_count = unsorted_segment_sum(cast_op(hit_mask, dtype_op(dout)),
                                     segment_ids, num_segments)
    div_op = P.RealDiv()
    grad_per_segment = div_op(dout, hit_count)

    # Gather the divided gradient with the indices/mask computed above.
    grad_per_elem, _, _ = _gather_drop_negatives(grad_per_segment, None,
                                                 clipped_ids, valid_mask)

    # Non-selected elements get a zero gradient.
    select_op = P.Select()
    zero_grad = zeros_like(grad_per_elem)
    dx = select_op(hit_mask, grad_per_elem, zero_grad)
    return dx, zeros_like(segment_ids), zeros_like(num_segments)
|
||||
|
||||
|
||||
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Generate bprop for UnsortedSegmentSum.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentSum`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        # Only the first element of the helper's result (the gathered gradient)
        # is needed; `segment_ids` and `num_segments` get zero placeholders.
        # The single return uses the snake_case helper name, replacing the
        # stale `_GatherDropNegatives` call that made this line unreachable.
        return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
|
||||
|
||||
|
@ -716,23 +736,20 @@ def get_bprop_unsorted_segment_sum(self):
|
|||
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Generate bprop for UnsortedSegmentMin.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentMin`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        # The gradient rule is shared with UnsortedSegmentMax, so delegate to
        # the common helper instead of keeping a second inline copy that still
        # referenced the old `_GatherDropNegatives` name.
        return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Build the backward function of UnsortedSegmentMax.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentMax`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        # Forward every argument unchanged to the helper shared with
        # UnsortedSegmentMin and hand its result back as-is.
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
|
||||
|
||||
|
||||
|
@ -759,7 +776,7 @@ def get_bprop_unsorted_segment_prod(self):
|
|||
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
|
||||
prod_divided_by_x = gathered_prod / x
|
||||
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
|
||||
gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices)
|
||||
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices)
|
||||
dx = gathered_grad * partial_derivative
|
||||
return dx, zeros_like(segment_ids), zeros_like(num_segments)
|
||||
|
||||
|
|
|
@ -272,6 +272,7 @@ from .reduce_all import _reduce_all_tbe
|
|||
from .reduce_any import _reduce_any_tbe
|
||||
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
|
||||
from .unsorted_segment_min import _unsorted_segment_min_tbe
|
||||
from .unsorted_segment_max import _unsorted_segment_max_tbe
|
||||
from .asin import _asin_tbe
|
||||
from .asin_grad import _asin_grad_tbe
|
||||
from .asinh import _asinh_tbe
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""UnsortedSegmentMax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
# TBE registration record for UnsortedSegmentMax: two inputs ("data",
# "segment_ids"), one output ("y") and a required int attribute "num_segments".
# Every dtype_format row carries three entries, one per declared input/output;
# "segment_ids" is I32_Default in all rows.
unsorted_segment_max_op_info = (
    TBERegOp("UnsortedSegmentMax")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("unsorted_segment_max_d.so")
    .compute_cost(10)
    .kernel_name("unsorted_segment_max_d")
    .partial_flag(True)
    .attr("num_segments", "required", "int", "all")
    .input(0, "data", False, "required", "all")
    .input(1, "segment_ids", False, "required", "all")
    .output(0, "y", False, "required", "all")
    # float16 data/output
    .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD)
    .dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ)
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0)
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default)
    # float32 data/output
    .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD)
    .dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ)
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0)
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default)
    # int32 data/output
    .dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD)
    .dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ)
    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .get_op_info()
)
|
||||
|
||||
|
||||
@op_info_register(unsorted_segment_max_op_info)
def _unsorted_segment_max_tbe():
    """Register the UnsortedSegmentMax TBE op info.

    The function body is intentionally empty: it exists only so the
    `op_info_register` decorator is applied with the op info above.
    """
    return None
|
|
@ -1648,6 +1648,11 @@ test_case_nn_ops = [
|
|||
'desc_const': [4],
|
||||
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[4, 2, 1, 3]]}),
|
||||
('UnsortedSegmentMax', {
|
||||
'block': P.UnsortedSegmentMax(),
|
||||
'desc_const': [4],
|
||||
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[4, 2, 1, 3]]}),
|
||||
('UnsortedSegmentProd', {
|
||||
'block': P.UnsortedSegmentProd(),
|
||||
'desc_const': [4],
|
||||
|
|
Loading…
Reference in New Issue