forked from mindspore-Ecosystem/mindspore
!236 add aicpu embeddinglookup
Merge pull request !236 from wuxuejian/incu_embedding
This commit is contained in:
commit
d463c3f388
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
|
from ..operations import _inner_ops as inner
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters
|
||||||
|
@ -188,6 +189,31 @@ def get_bprop_tile(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(inner.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
    """Generate bprop for EmbeddingLookup.

    Returns:
        function, ``bprop_sparse``. It returns one gradient entry per forward
        input (`x`, `indices`, `offset`, `reduce_scatter_flag`, `split_num`).
        The entry for `x` is the tuple ``(new_indices, actual_dout, x_shp)``;
        the remaining entries are zero placeholders.
    """
    # Both helper primitives are tagged with primitive_target='CPU'.
    host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
    host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')

    def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
        x_shp = shape_op(x)
        if reduce_scatter_flag is True:
            elu_grad = G.EmbeddingLookupCommGrad()
            actual_dout = elu_grad(dout, split_num)
        else:
            actual_dout = dout
        # Sub is a binary primitive: hand it both operands instead of a
        # pre-computed `indices - offset` as a single argument.
        new_indices = host_sub(indices, offset)
        # Reshape the 'new_indices' into a 1-D tensor.
        new_indices_shape_changed = (size_op(new_indices),)
        new_indices = host_reshape(new_indices, new_indices_shape_changed)
        # Reshape the 'actual_dout' to (number of indices,) + x.shape[1:].
        x_shp_tail = x_shp[1:]
        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
        actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
        # The placeholder for `indices` is built from the original (unflattened)
        # input, and the third placeholder belongs to `offset` -- this op has
        # no `axis` input.
        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
               zeros_like(reduce_scatter_flag), zeros_like(split_num)

    return bprop_sparse
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Transpose)
|
@bprop_getters.register(P.Transpose)
|
||||||
def get_bprop_transpose(self):
|
def get_bprop_transpose(self):
|
||||||
"""Generate bprop for Transpose"""
|
"""Generate bprop for Transpose"""
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
"""aicpu ops"""
|
"""aicpu ops"""
|
||||||
from .init_data_set_queue import _init_data_set_queue_aicpu
|
from .init_data_set_queue import _init_data_set_queue_aicpu
|
||||||
|
from .embedding_lookup import _embedding_lookup_aicpu
|
||||||
from .dropout_genmask import _dropout_genmask_aicpu
|
from .dropout_genmask import _dropout_genmask_aicpu
|
||||||
from .get_next import _get_next_aicpu
|
from .get_next import _get_next_aicpu
|
||||||
from .print_tensor import _print_aicpu
|
from .print_tensor import _print_aicpu
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""EmbeddingLookup op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
# Operator info for EmbeddingLookup. Each dtype_format entry lists the dtypes
# in the order (params, indices, offset, output); the output dtype always
# matches the params dtype.
embeddingLookup_op_info = (
    AiCPURegOp("EmbeddingLookup")
    .fusion_type("OPAQUE")
    .input(0, "params", "required")
    .input(1, "indices", "required")
    .input(2, "offset", "required")
    .output(0, "output", "required")
    # indices: int32, offset: int32
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.F64_Default)
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default)
    # indices: int64, offset: int64
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F64_Default)
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default)
    # indices: int64, offset: int32
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I32_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I32_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I32_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I32_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I32_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I32_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I32_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.F64_Default)
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.I32_Default, DataType.BOOL_Default)
    .get_op_info()
)
|
||||||
|
|
||||||
|
@op_info_register(embeddingLookup_op_info)
def _embedding_lookup_aicpu():
    """Register the EmbeddingLookup AiCPU operator info.

    The body is intentionally empty: the function only serves as the target
    of the ``op_info_register`` decorator.
    """
    return
|
|
@ -96,3 +96,73 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
||||||
"""infer dtype"""
|
"""infer dtype"""
|
||||||
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
|
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
|
||||||
return input_x
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingLookup(PrimitiveWithInfer):
    """
    Returns a slice of input tensor based on the specified indices.

    This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
    `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The Tensor slice, instead of the entire Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
          and the exceeding part will be filled with 0 in the output.
        - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
          are equal to `input_indices` minus `offset`.
        - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
          Only constant value is allowed.
        - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
          is used only if `reduce_scatter_flag` is True. Only constant value is allowed.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        ValueError: If `reduce_scatter_flag` or `split_num` is None, if `split_num` is less than 1, or if
            `reduce_scatter_flag` is True and dimension 0 of the output shape is not divisible by
            `split_num * 8`.

    Examples:
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
        >>> offset = 4
        >>> reduce_scatter_flag = False
        >>> split_num = 1
        >>> out = inner.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
        [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
    """
    @prim_attr_register
    def __init__(self):
        """init EmbeddingLookup"""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
                                outputs=['output'])
        self.add_prim_attr('primitive_target', 'CPU')

    def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
        # Reject missing arguments before any of them is subscripted; otherwise
        # these guards could never fire (a TypeError would be raised first).
        if reduce_scatter_flag is None:
            raise ValueError("The value of 'reduce_scatter_flag' is None.")
        if split_num is None:
            raise ValueError("The value of 'split_num_value' is None.")

        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
        validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)

        split_num_value = split_num['value']
        if split_num_value < 1:
            # Format the integer value, not the whole argument dict (which
            # would raise TypeError instead of the intended ValueError).
            raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num_value)
        reduce_scatter_flag_value = reduce_scatter_flag['value']

        # Lookup along dimension 0: indices' shape followed by the trailing
        # dimensions of params.
        params_shp = params['shape']
        out_shape = indices['shape'] + params_shp[1:]

        if reduce_scatter_flag_value is True:
            # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
            # (split_num * 8)
            if out_shape[0] % (split_num_value * 8) != 0:
                raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
                                 (out_shape[0], (split_num_value * 8)))
            # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
            out_shape[0] = out_shape[0] // 8

        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out
|
||||||
|
|
|
@ -577,64 +577,43 @@ class Range(PrimitiveWithInfer):
|
||||||
class EmbeddingLookup(PrimitiveWithInfer):
|
class EmbeddingLookup(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
||||||
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
|
functionality as GatherV2, but has one more input: `offset`.
|
||||||
|
This primitive runs on the aicpu devices.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
- **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
The Tensor slice, instead of the entire Tensor.
|
The Tensor slice, instead of the entire Tensor.
|
||||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
- **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||||
Specifies the indices of elements of the original Tensor. Must be in the range
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `params`,
|
||||||
`[0, input_param.shape()[axis])`.
|
and the exceeding part will be filled with 0 in the output.
|
||||||
- **axis** (int) - Specifies the dimension index to gather indices.
|
The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64.
|
||||||
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
- **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices
|
||||||
are equal to `input_indices` minus `offset`.
|
are equal to `indices` minus `offset`.
|
||||||
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
|
||||||
- **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
|
|
||||||
is used only if `reduce_scatter_flag` is True.
|
|
||||||
|
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
>>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||||
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
>>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
||||||
>>> axis = 0
|
|
||||||
>>> offset = 4
|
>>> offset = 4
|
||||||
>>> reduce_scatter_flag = False
|
>>> out = P.EmbeddingLookup()(params, indices, offset)
|
||||||
>>> split_num = 1
|
|
||||||
>>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num)
|
|
||||||
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
||||||
"""
|
"""
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init index_select"""
|
"""init index_select"""
|
||||||
self.__setattr_flag__ = True
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
||||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
|
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
self.add_prim_attr('target', 'CPU')
|
|
||||||
|
|
||||||
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
|
def __infer__(self, params, indices, offset):
|
||||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||||
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
valid_types = (mstype.int32, mstype.int64)
|
||||||
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name)
|
||||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||||
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
|
|
||||||
if split_num['value'] < 1:
|
|
||||||
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
|
|
||||||
axis_v = axis['value']
|
|
||||||
params_shp = params['shape']
|
params_shp = params['shape']
|
||||||
rank = len(params_shp)
|
out_shape = indices['shape'] + params_shp[1:]
|
||||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
|
||||||
if axis_v < 0:
|
|
||||||
axis_v += rank
|
|
||||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
|
||||||
if reduce_scatter_flag:
|
|
||||||
# partition the tensor along the dimension 0.
|
|
||||||
if out_shape[0] % split_num['value'] != 0:
|
|
||||||
raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
|
|
||||||
(out_shape[0], split_num['value']))
|
|
||||||
out_shape[0] = out_shape[0] // split_num['value']
|
|
||||||
out = {'shape': out_shape,
|
out = {'shape': out_shape,
|
||||||
'dtype': params['dtype'],
|
'dtype': params['dtype'],
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np

import mindspore.context as context
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
# Run this test module in graph mode with "Ascend" as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
    """Cell that applies P.EmbeddingLookup with a fixed offset.

    Args:
        offset (int): Offset passed as the third input of EmbeddingLookup on
            every call.
    """

    def __init__(self, offset):
        super().__init__()
        self.embedding = P.EmbeddingLookup()
        self.offset = offset

    def construct(self, param, index):
        """Look up `index` in `param` using the stored offset."""
        return self.embedding(param, index, self.offset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_lookup_sparse():
    """Check EmbeddingLookup with offset 4 on a 4x2 int32 table.

    Index 5 maps to row 1 ([10, 11]) after the offset is subtracted; indices
    2 and 8 fall outside the table, and the expected output holds zeros there.
    """
    offset = 4
    table = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32)
    lookup_ids = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32)
    net = Net(offset)
    result = net(table, lookup_ids)
    expected = [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
    assert (result.asnumpy() == expected).all()
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,29 +34,27 @@ class NetWithLoss(nn.Cell):
|
||||||
return self.loss(predict)
|
return self.loss(predict)
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, shape, axis, offset, reduce_scatter_flag, split_num):
|
def __init__(self, shape, offset, reduce_scatter_flag, split_num):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||||
self.axis = axis
|
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.reduce_scatter_flag = reduce_scatter_flag
|
self.reduce_scatter_flag = reduce_scatter_flag
|
||||||
self.split_num = split_num
|
self.split_num = split_num
|
||||||
self.elu = P.EmbeddingLookup()
|
self.elu = inner.EmbeddingLookup()
|
||||||
self.mm = P.BatchMatMul()
|
self.mm = P.BatchMatMul()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.elu(x, self.index, self.axis, self.offset, self.reduce_scatter_flag, self.split_num)
|
out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num)
|
||||||
out = self.mm(out, y)
|
out = self.mm(out, y)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def test_embeddinglookup_reducescatter_false():
|
def test_embeddinglookup_reducescatter_false():
|
||||||
shape = [8, 8]
|
shape = [8, 8]
|
||||||
axis = 0
|
|
||||||
offset = 8
|
offset = 8
|
||||||
reduce_scatter_flag = False
|
reduce_scatter_flag = False
|
||||||
split_num = 1
|
split_num = 1
|
||||||
net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num))
|
net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
@ -64,14 +63,13 @@ def test_embeddinglookup_reducescatter_false():
|
||||||
|
|
||||||
|
|
||||||
def test_embeddinglookup_reducescatter_true():
|
def test_embeddinglookup_reducescatter_true():
|
||||||
shape = [8, 8]
|
shape = [64, 8]
|
||||||
axis = 0
|
|
||||||
offset = 8
|
offset = 8
|
||||||
reduce_scatter_flag = True
|
reduce_scatter_flag = True
|
||||||
split_num = 8
|
split_num = 8
|
||||||
net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num))
|
net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
|
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
Loading…
Reference in New Issue