forked from mindspore-Ecosystem/mindspore

!2787 Add op ScatterMul ScatterDiv vm

Merge pull request !2787 from zhaozhenlong/op/scatter-div-mul-vm

Commit: 622b97f3b6
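For orientation, a minimal usage sketch of the two new primitives, following the docstrings added later in this change (a sketch assuming a MindSpore build that includes this commit; graph/device setup omitted):

    import numpy as np
    import mindspore
    from mindspore import Parameter, Tensor
    from mindspore.ops import operations as P

    # ScatterMul: input_x[indices[i]] *= updates[i], in place on the Parameter
    input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
    indices = Tensor(np.array([0, 1]), mindspore.int32)
    updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
    output = P.ScatterMul()(input_x, indices, updates)  # [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]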
@ -369,7 +369,7 @@ class CosineEmbeddingLoss(_Loss):
         >>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
         >>> y = Tensor(np.array([1,-1]), mindspore.int32)
         >>> cosine_embedding_loss = P.CosineEmbeddingLoss()
-        >>> cosine_embedding_loss(x1, x2, target)
+        >>> cosine_embedding_loss(x1, x2, y)
         [0.0003426671]
     """
     def __init__(self, margin=0.0, reduction="mean"):
@ -276,4 +276,6 @@ from .lrn_grad import _lrn_grad_tbe
 from .scatter_max import _scatter_max_tbe
 from .scatter_min import _scatter_min_tbe
 from .scatter_sub import _scatter_sub_tbe
+from .scatter_mul import _scatter_mul_tbe
+from .scatter_div import _scatter_div_tbe
 from .mod import _mod_tbe
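These imports matter for their side effects: importing each scatter_* module runs its @op_info_register(...) decorator, which records the op's TBE metadata in a global registry. A minimal sketch of that registration-on-import pattern (illustrative only; the registry and decorator names below are hypothetical stand-ins for mindspore.ops.op_info_register):

    # Hypothetical stand-in registry, for illustration only.
    _OP_INFO_REGISTRY = {}

    def op_info_register(op_info):
        """Decorator that records op metadata as a side effect of import."""
        def decorator(func):
            _OP_INFO_REGISTRY[func.__name__] = op_info  # runs at import time
            return func
        return decorator

    @op_info_register({"op_name": "ScatterDiv", "binfile": "scatter_div.so"})
    def _scatter_div_tbe():
        """Placeholder register function, mirroring the real one below."""
        return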
@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ScatterDiv op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+scatter_div_op_info = TBERegOp("ScatterDiv") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("scatter_div.so") \
+    .compute_cost(10) \
+    .kernel_name("scatter_div") \
+    .partial_flag(True) \
+    .attr("use_locking", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(scatter_div_op_info)
+def _scatter_div_tbe():
+    """ScatterDiv TBE register"""
+    return
@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ScatterMul op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+scatter_mul_op_info = TBERegOp("ScatterMul") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("scatter_mul.so") \
+    .compute_cost(10) \
+    .kernel_name("scatter_mul") \
+    .partial_flag(True) \
+    .attr("use_locking", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(scatter_mul_op_info)
+def _scatter_mul_tbe():
+    """ScatterMul TBE register"""
+    return
@ -25,8 +25,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
                         IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
-                        SameTypeShape, ScatterAdd, ScatterSub, ScatterMax, ScatterMin, ScatterUpdate,
-                        ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
+                        SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
+                        ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                         Shape, Size, Slice, Split,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
@ -215,6 +215,8 @@ __all__ = [
     'L2Normalize',
     'ScatterAdd',
     'ScatterSub',
+    'ScatterMul',
+    'ScatterDiv',
     'ScatterNd',
     'ScatterMax',
     'ScatterMin',
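With both the array_ops import and the __all__ entries in place, the new primitives are reachable through the usual operations namespace. A quick sanity check (a sketch, assuming a build containing this change):

    from mindspore.ops import operations as P

    assert hasattr(P, "ScatterMul") and hasattr(P, "ScatterDiv")
    print(P.ScatterMul(), P.ScatterDiv())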
@ -38,6 +38,39 @@ from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import typing


+class _ScatterOp(PrimitiveWithInfer):
+    """
+    Define Scatter operators
+    """
+    __mindspore_signature__ = (
+        ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
+        ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
+    )
+    @staticmethod
+    def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
+        if updates_shape and updates_shape != indices_shape + x_shape[1:]:
+            raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
+                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
+                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
+
+    @prim_attr_register
+    def __init__(self, use_locking=False):
+        """Init _ScatterOp"""
+        validator.check_value_type('use_locking', use_locking, [bool], self.name)
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
+
+    def infer_shape(self, x_shape, indices_shape, updates_shape):
+        _ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
+        validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
+        args = {"x": x_dtype, "updates": updates_dtype}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x_dtype
+
+
 def _check_infer_attr_reduce(axis, keep_dims, prim_name):
     validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
     validator.check_value_type('axis', axis, [int, tuple], prim_name)
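The rule _check_scatter_shape enforces is the standard scatter contract: updates must look like indices, with each index position carrying one row of x. A plain-Python sketch of the same check, using lists as shapes the way the infer functions do:

    def check_scatter_shape(x_shape, indices_shape, updates_shape):
        """updates_shape must be [] or indices_shape + x_shape[1:]."""
        return not updates_shape or updates_shape == indices_shape + x_shape[1:]

    assert check_scatter_shape([6], [3], [3])              # 1-D var, 1-D indices
    assert check_scatter_shape([3, 4], [2, 2], [2, 2, 4])  # each index carries a row of shape [4]
    assert not check_scatter_shape([3, 4], [2], [2, 3])    # wrong trailing shape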
@ -2221,7 +2254,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
         return x_dtype


-class ScatterUpdate(PrimitiveWithInfer):
+class ScatterUpdate(_ScatterOp):
     """
     Update tensor value by using input indices and value.

@ -2233,8 +2266,8 @@ class ScatterUpdate(PrimitiveWithInfer):
     Inputs:
         - **input_x** (Parameter) - The target tensor, with data type of Parameter.
         - **indices** (Tensor) - The index of input tensor. With int32 data type.
-        - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
-          and update.shape = indices.shape + input_x.shape[1:].
+        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
+          and updates.shape = indices.shape + input_x.shape[1:].

     Outputs:
         Tensor, has the same shape and type as `input_x`.
@ -2243,27 +2276,17 @@ class ScatterUpdate(PrimitiveWithInfer):
         >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
         >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
         >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
-        >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
-        >>> update = Tensor(np_update, mindspore.float32)
+        >>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
+        >>> updates = Tensor(np_updates, mindspore.float32)
         >>> op = P.ScatterUpdate()
-        >>> output = op(input_x, indices, update)
+        >>> output = op(input_x, indices, updates)
     """
-    __mindspore_signature__ = (
-        ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
-        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
-        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
-    )

     @prim_attr_register
     def __init__(self, use_locking=True):
         """Init ScatterUpdate"""
         validator.check_value_type('use_locking', use_locking, [bool], self.name)
-        self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
-
-    def infer_shape(self, x_shape, indices_shape, value_shape):
-        if indices_shape + x_shape[1:] != value_shape:
-            raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
-        return x_shape
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

     def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
         validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
@ -2323,14 +2346,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         return x_dtype


-def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
-    if updates_shape and updates_shape != indices_shape + x_shape[1:]:
-        raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
-                         f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
-                         f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
-
-
-class ScatterMax(PrimitiveWithInfer):
+class ScatterMax(_ScatterOp):
     """
     Update the value of the input tensor through the max operation.

@ -2364,18 +2380,8 @@ class ScatterMax(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
         validator.check_value_type('use_locking', use_locking, (bool,), self.name)

-    def infer_shape(self, x_shape, indices_shape, updates_shape):
-        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
-        return x_shape
-
-    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
-        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
-        args = {"x": x_dtype, "updates": updates_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
-        return x_dtype
-
-
-class ScatterMin(PrimitiveWithInfer):
+
+class ScatterMin(_ScatterOp):
     """
     Update the value of the input tensor through the min operation.

@ -2403,24 +2409,8 @@ class ScatterMin(PrimitiveWithInfer):
         [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
     """

-    @prim_attr_register
-    def __init__(self, use_locking=False):
-        """Init ScatterMin"""
-        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
-        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
-
-    def infer_shape(self, x_shape, indices_shape, updates_shape):
-        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
-        return x_shape
-
-    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
-        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
-        args = {"x": x_dtype, "updates": updates_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
-        return x_dtype
-
-
-class ScatterAdd(PrimitiveWithInfer):
+
+class ScatterAdd(_ScatterOp):
     """
     Update the value of the input tensor through the add operation.

@ -2448,23 +2438,8 @@ class ScatterAdd(PrimitiveWithInfer):
         [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
     """

-    @prim_attr_register
-    def __init__(self, use_locking=False):
-        """Init ScatterAdd"""
-        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
-
-    def infer_shape(self, x_shape, indices_shape, updates_shape):
-        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
-        return x_shape
-
-    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
-        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
-        args = {'x': x_dtype, 'updates': updates_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
-        return x_dtype
-
-
-class ScatterSub(PrimitiveWithInfer):
+
+class ScatterSub(_ScatterOp):
     """
     Update the value of the input tensor through the sub operation.

@ -2492,20 +2467,63 @@ class ScatterSub(PrimitiveWithInfer):
         [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
     """

-    @prim_attr_register
-    def __init__(self, use_locking=False):
-        """Init ScatterSub"""
-        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
-
-    def infer_shape(self, x_shape, indices_shape, updates_shape):
-        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
-        return x_shape
-
-    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
-        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
-        args = {'x': x_dtype, 'updates': updates_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
-        return x_dtype
+
+class ScatterMul(_ScatterOp):
+    """
+    Update the value of the input tensor through the mul operation.
+
+    Using given values to update tensor value through the mul operation, along with the input indices.
+    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+
+    Args:
+        use_locking (bool): Whether to protect the assignment by a lock. Default: False.
+
+    Inputs:
+        - **input_x** (Parameter) - The target parameter.
+        - **indices** (Tensor) - The index to do mul operation whose data type should be mindspore.int32.
+        - **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
+          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
+
+    Outputs:
+        Parameter, the updated `input_x`.
+
+    Examples:
+        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
+        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
+        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
+        >>> scatter_mul = P.ScatterMul()
+        >>> output = scatter_mul(input_x, indices, updates)
+        [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
+    """
+
+
+class ScatterDiv(_ScatterOp):
+    """
+    Update the value of the input tensor through the div operation.
+
+    Using given values to update tensor value through the div operation, along with the input indices.
+    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+
+    Args:
+        use_locking (bool): Whether to protect the assignment by a lock. Default: False.
+
+    Inputs:
+        - **input_x** (Parameter) - The target parameter.
+        - **indices** (Tensor) - The index to do div operation whose data type should be mindspore.int32.
+        - **updates** (Tensor) - The tensor doing the div operation with `input_x`,
+          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
+
+    Outputs:
+        Parameter, the updated `input_x`.
+
+    Examples:
+        >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
+        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
+        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
+        >>> scatter_div = P.ScatterDiv()
+        >>> output = scatter_div(input_x, indices, updates)
+        [[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
+    """


 class SpaceToDepth(PrimitiveWithInfer):
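The semantics of the two new classes can be pinned down with a NumPy reference (a sketch, not the actual kernel): each index selects a row of the variable and multiplies (or, for ScatterDiv, divides) it by the matching slice of updates; duplicate indices apply repeatedly.

    import numpy as np

    def scatter_mul_ref(var, indices, updates):
        """NumPy reference for ScatterMul: var[indices[i]] *= updates[i]."""
        out = var.copy()
        flat_idx = np.asarray(indices).reshape(-1)
        flat_upd = np.asarray(updates).reshape((-1,) + var.shape[1:])
        for i, j in enumerate(flat_idx):
            out[j] = out[j] * flat_upd[i]  # duplicate indices multiply repeatedly
        return out

    x = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
    print(scatter_mul_ref(x, [0, 1], np.full((2, 3), 2.0)))
    # [[2. 2. 2.]
    #  [4. 4. 4.]]   (matches the ScatterMul docstring example)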
@ -185,6 +185,19 @@ class HistogramSummaryNet(nn.Cell):
         return out


+class ScatterUpdate(nn.Cell):
+    """ScatterUpdate net definition"""
+
+    def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
+        super(ScatterUpdate, self).__init__()
+        self.scatter_update = P.ScatterUpdate(use_locking)
+        self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
+
+    def construct(self, indices, updates):
+        out = self.scatter_update(self.ref, indices, updates)
+        return out
+
+
 class ScatterMax(nn.Cell):
     """ScatterMax net definition"""

@ -237,6 +250,32 @@ class ScatterSub(nn.Cell):
         return out


+class ScatterMul(nn.Cell):
+    """ScatterMul net definition"""
+
+    def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
+        super(ScatterMul, self).__init__()
+        self.scatter_mul = P.ScatterMul(use_locking)
+        self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
+
+    def construct(self, indices, updates):
+        out = self.scatter_mul(self.ref, indices, updates)
+        return out
+
+
+class ScatterDiv(nn.Cell):
+    """ScatterDiv net definition"""
+
+    def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
+        super(ScatterDiv, self).__init__()
+        self.scatter_div = P.ScatterDiv(use_locking)
+        self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)*10), name="ref")
+
+    def construct(self, indices, updates):
+        out = self.scatter_div(self.ref, indices, updates)
+        return out
+
+
 class ApplyFtrlNet(nn.Cell):
     def __init__(self):
         super(ApplyFtrlNet, self).__init__()
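A minimal sketch of exercising one of these test nets directly (values follow the ScatterDiv definition above, whose ref starts as all tens; context/device setup omitted):

    indices = Tensor(np.array([2, 0, 5], np.int32))
    updates = Tensor(np.array([2.0, 5.0, 10.0], np.float32))
    net = ScatterDiv((6,))
    out = net(indices, updates)  # ref [10.]*6 -> [2., 10., 5., 10., 10., 1.]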
@ -1861,6 +1900,11 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
                         Tensor(np.ones([2, 2, 3], np.int32))),
         'skip': ['backward']}),
+    ('ScatterUpdate', {
+        'block': ScatterUpdate((6,)),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
+        'skip': ['backward']}),
     ('ScatterAddUseLocking', {
         'block': ScatterAdd((6,), use_locking=True),
         'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
@ -1902,6 +1946,73 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
                         Tensor(np.array([2, 3, 4], np.uint8))),
         'skip': ['backward']}),
+    ('ScatterMulUseLocking', {
+        'block': ScatterMul((6,), use_locking=True),
+        'desc_inputs': (Tensor(np.array([2], np.int32)),
+                        Tensor(np.array([2.0], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterMulScalar', {
+        'block': ScatterMul((6,)),
+        'desc_inputs': (Tensor(np.array([2], np.int32)),
+                        Tensor(np.array([2.0], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterMul2d', {
+        'block': ScatterMul((3, 4)),
+        'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
+                        Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
+                                         [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterMulF16', {
+        'block': ScatterMul((6,), np.float16),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
+        'skip': ['backward']}),
+    ('ScatterMulI8', {
+        'block': ScatterMul((6,), np.int8),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2, 3, 4], np.int8))),
+        'skip': ['backward']}),
+    ('ScatterMulI32', {
+        'block': ScatterMul((6,), np.int32),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2, 3, 4], np.int32))),
+        'skip': ['backward']}),
+    ('ScatterMulU8', {
+        'block': ScatterMul((6,), np.uint8),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2, 3, 4], np.uint8))),
+        'skip': ['backward']}),
+    ('ScatterDivUseLocking', {
+        'block': ScatterDiv((6,), use_locking=True),
+        'desc_inputs': (Tensor(np.array([2], np.int32)),
+                        Tensor(np.array([2.0], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterDivScalar', {
+        'block': ScatterDiv((6,)),
+        'desc_inputs': (Tensor(np.array([2], np.int32)),
+                        Tensor(np.array([2.0], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterDiv2d', {
+        'block': ScatterDiv((3, 4)),
+        'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
+                        Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
+                                         [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
+        'skip': ['backward']}),
+    ('ScatterDivF16', {
+        'block': ScatterDiv((6,), np.float16),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
+        'skip': ['backward']}),
+    ('ScatterDivI8', {
+        'block': ScatterDiv((6,), np.int8),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2, 3, 4], np.int8))),
+        'skip': ['backward']}),
+    ('ScatterDivU8', {
+        'block': ScatterDiv((6,), np.uint8),
+        'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
+                        Tensor(np.array([2, 3, 4], np.uint8))),
+        'skip': ['backward']}),
     ('ScatterSubUseLocking', {
         'block': ScatterSub((6,), use_locking=True),
         'desc_inputs': (Tensor(np.array([2], np.int32)),