From cde5cc2bd2c2b384fee1caeb068fe0f74722ddf7 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Fri, 10 Jul 2020 15:39:49 +0800
Subject: [PATCH] add_embedding_look_up

---
 mindspore/ccsrc/parallel/dynamic_creator.h    |  1 +
 .../parallel/ops_info/gather_v2_p_info.cc     | 81 ++++++-------------
 .../parallel/ops_info/gather_v2_p_info.h      | 13 ++-
 mindspore/ccsrc/parallel/ops_info/ops_utils.h |  1 +
 mindspore/ccsrc/parallel/step_parallel.cc     |  2 +-
 mindspore/nn/layer/embedding.py               | 46 +++++++++++
 model_zoo/wide_and_deep/src/wide_and_deep.py  |  6 +-
 .../python/parallel/test_embeddinglookup.py   | 34 +++++++-
 tests/ut/python/parallel/test_gather_v2.py    | 41 ----------
 9 files changed, 115 insertions(+), 110 deletions(-)

diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index f8e1d62d0ab..352c7449a5b 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
 REGISTER(SigmoidCrossEntropyWithLogitsInfo);
 REGISTER(SquareInfo);
 REGISTER(GatherV2PInfo);
+REGISTER(EmbeddingLookupInfo);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index dfecb29e889..d62111c0107 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -28,24 +28,25 @@
 namespace mindspore {
 namespace parallel {
 Status GatherV2PInfo::GetAttrs() {
-  // get axis, the third input is the axis, is a ValueNode
-  if (input_value_.at(2) == nullptr) {
-    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
-    return FAILED;
+  // get axis: the third input is the axis and must be a ValueNode; EmbeddingLookup has no axis input.
+  if (target_ != CPU) {
+    if (input_value_.at(2) == nullptr) {
+      MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
+      return FAILED;
+    }
+    auto axis = GetValue<int>(input_value_.at(2));
+    // if axis is negative then convert it to positive
+    auto params_shape = inputs_shape_.at(0);
+    if (params_shape.size() == 0) {
+      MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
+      return FAILED;
+    }
+    if (axis < 0) {
+      axis += SizeToInt(inputs_shape_[0].size());
+    }
+    axis_ = axis;
   }
-  auto axis = GetValue<int>(input_value_.at(2));
-  // if axis is negative then convert it to positive
-  auto params_shape = inputs_shape_.at(0);
-  if (params_shape.size() == 0) {
-    MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
-    return FAILED;
-  }
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  axis_ = axis;
-
-  // get target
   auto target_iter = attrs_.find(TARGET);
   if (target_iter != attrs_.end()) {
     MS_EXCEPTION_IF_NULL(target_iter->second);
@@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
     if (target_iter->second->isa<StringImm>()) {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
       MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
-      return FAILED;
     }
   }
-
-  // target=CPU, axis must be 0
-  if (target_ == "CPU" && axis_ != 0) {
-    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
-    return FAILED;
-  }
-
   auto manual_split_iter = attrs_.find("manual_split");
   if (manual_split_iter != attrs_.end()) {
     param_split_shapes_.clear();
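
Note on the hunks above: axis parsing is now skipped when the primitive is targeted at CPU, since a CPU EmbeddingLookup carries no axis input. A minimal sketch of how the Python side attaches the target that GetAttrs() reads back as `target_`, mirroring the nn/layer/embedding.py change further down in this patch:

    from mindspore.ops import operations as P

    # Attaching the attribute routes the primitive to CPU, so GetAttrs()
    # sees target_ == CPU and skips the axis input entirely.
    lookup_cpu = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
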
@@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
     MS_LOG(ERROR) << name_ << ": Infer Group failed.";
     return FAILED;
   }
-  auto group_size = group_.GetDevNum();
   Attr attr_group;
-  if (host_reduce_scatter_) {
-    // group size <= 8
-    std::vector<int32_t> rank_list;
-    if (group_size <= 8) {
-      reduce_scatter_flag_ = false;
-      operator_name = HOST_REDUCE_SCATTER;
-      rank_list = GetRankFromGroup(group_);
-      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
-    } else {
-      // group size > 8, don't support host reduce_scatter
-      reduce_scatter_flag_ = true;
-      split_num_ = SizeToInt(group_size / 8);
-      CheckGlobalDeviceManager();
-      operator_name = REDUCE_SCATTER;
-      int32_t rank = g_device_manager->global_rank();
-      size_t repeat = group_size / 8;
-      for (size_t i = 0; i < repeat; ++i) {
-        rank_list.push_back(rank + SizeToInt(i * 8));
-      }
-      Group g = g_device_manager->CreateGroup(rank_list);
-      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
-    }
-  } else {
-    operator_name = REDUCE_SCATTER;
-    if (InferGroup() != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
-      return FAILED;
-    }
-    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
+  operator_name = REDUCE_SCATTER;
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+    return FAILED;
   }
+  attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
   OperatorParams params;
@@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   OperatorName op_name = EMBEDDING_LOOKUP;
   OperatorAttrs attrs;
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
-  Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
-  Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
-                           std::make_pair(param_split_num, 5)};
+  OperatorParams params = {std::make_pair(param_offset, 3)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);
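
ComputeReplaceOp() above now passes only the shard's row offset (`bias_`) to the replacement EmbeddingLookup, dropping the host reduce-scatter flags. A NumPy sketch of the per-shard semantics this relies on, assuming (as the nn.EmbeddingLookup docstring below states) that out-of-range indices yield zero rows, which the forward ReduceScatter then combines:

    import numpy as np

    def shard_lookup(table_shard, indices, offset):
        # indices hold global row ids; this shard owns rows [offset, offset + n)
        local = indices - offset
        valid = (local >= 0) & (local < table_shard.shape[0])
        out = np.zeros(indices.shape + (table_shard.shape[1],), table_shard.dtype)
        out[valid] = table_shard[local[valid]]  # rows owned elsewhere stay zero
        return out
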
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index acdecb49a3d..16d5c856229 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
   Status InferGroup();
 
   int32_t axis_;
-  std::string target_;
+  std::string target_ = DEVICE;
   std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
   int32_t index_offset_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
-  bool reduce_scatter_flag_ = false;
-  int32_t split_num_ = 1;
-  bool host_reduce_scatter_ = false;
   bool manual_split_ = false;
   std::vector<int32_t> param_split_shapes_;
   std::vector<int32_t> index_offsets_;
@@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
  private:
   std::string replace_op_name_ = SPARSE_GATHERV2;
 };
+
+class EmbeddingLookupInfo : public GatherV2PInfo {
+ public:
+  EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~EmbeddingLookupInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 93e14d7f348..79dfb56693b 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
 constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
+constexpr char DEVICE[] = "Device";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index cea82bc180d..c22b6ed5520 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
+  if (prim->name() == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
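
The ReplaceOpInput() change above forwards node->input(2) for EMBEDDING_LOOKUP because the replacement op keeps both original operands and gains the offset constant that ComputeReplaceOp() appends at input position 3. Roughly, the rewritten node evaluates as this sketch (names hypothetical):

    # sharded_table and indices are the original GatherV2 operands;
    # offset is the per-shard constant (bias_) appended above
    out = P.EmbeddingLookup()(sharded_table, indices, offset)
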
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index c8873039ab7..a0887886a08 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -105,3 +105,49 @@ class Embedding(Cell):
                       self.embedding_table,
                       self.dtype)
         return s
+
+class EmbeddingLookup(Cell):
+    r"""
+    Returns a slice of the input tensor based on the specified indices.
+
+    Note:
+        When 'target' is set to 'CPU', this module uses
+        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which
+        looks up the table with 'offset' fixed to 0.
+        When 'target' is set to 'DEVICE', this module uses P.GatherV2(), which
+        looks up the table with 'axis' fixed to 0.
+
+    Args:
+        target (str): Specifies the target where the op is executed. Default: 'CPU'.
+
+    Inputs:
+        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+          The Tensor slice to look up, instead of the entire Tensor.
+        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of
+          `input_params`, and the exceeding part will be filled with 0 in the output.
+
+    Outputs:
+        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
+        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
+        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
+        [[[10, 11], [8, 9]], [[14, 15], [12, 13]]]
+    """
+    def __init__(self, target='CPU'):
+        super(EmbeddingLookup, self).__init__()
+        self.target = target
+        if target not in ('CPU', 'DEVICE'):
+            raise ValueError("Attr 'target' of 'EmbeddingLookup' Op passed "
+                             + str(target) + ", should be one of values in 'CPU', 'DEVICE'.")
+        self.gatherv2 = P.GatherV2()
+        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+
+    def construct(self, params, indices):
+        if self.target == "CPU":
+            out = self.embeddinglookup(params, indices, 0)
+        else:
+            out = self.gatherv2(params, indices, 0)
+        return out
diff --git a/model_zoo/wide_and_deep/src/wide_and_deep.py b/model_zoo/wide_and_deep/src/wide_and_deep.py
index 16102039a88..5c04687fdce 100644
--- a/model_zoo/wide_and_deep/src/wide_and_deep.py
+++ b/model_zoo/wide_and_deep/src/wide_and_deep.py
@@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
                                            self.deep_layer_act,
                                            use_activation=False, convert_dtype=True,
                                            drop_out=config.dropout_flag)
-        self.gather_v2 = P.GatherV2()
+        self.embeddinglookup = nn.EmbeddingLookup()
         self.mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
@@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
+        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
         wx = self.mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
+        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
         vx = self.mul(deep_id_embs, mask)
         deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
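
The wide_and_deep call sites become one argument shorter: nn.EmbeddingLookup.construct takes only the table and the ids, with axis 0 implied by the cell itself:

    # before: self.gather_v2(self.embedding_table, id_hldr, 0)
    # after:
    deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
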
diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py
index f52010987ef..db84ab26eb3 100644
--- a/tests/ut/python/parallel/test_embeddinglookup.py
+++ b/tests/ut/python/parallel/test_embeddinglookup.py
@@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
         return self.loss(predict)
 
 class Net(nn.Cell):
-    def __init__(self, shape, offset):
+    def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
         super().__init__()
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.offset = offset
-        self.elu = P.EmbeddingLookup()
-        self.mm = P.BatchMatMul()
+        self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
+        self.mm = P.BatchMatMul().set_strategy(strategy2)
 
     def construct(self, x, y):
         out = self.elu(x, self.index, self.offset)
@@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index 1467cd1e40d..2e853875bf6 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest
-
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -183,42 +181,3 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu0():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu1():
-    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((16, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu2():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
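
For the row-split strategy ((8, 1), (1, 1)) exercised by test_embeddinglookup_semi_auto1 above, a small sketch of how the per-rank offset fed to the replacement EmbeddingLookup could be derived, under the assumption that `bias_` is the first global row owned by the local shard:

    vocab_size, device_num = 64, 8
    slice_rows = vocab_size // device_num   # 8 rows per device
    for rank in range(device_num):
        offset = rank * slice_rows          # this rank owns rows [offset, offset + slice_rows)
        print(rank, offset)
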