!3025 [AutoParallel]Add embedding look up op

Merge pull request !3025 from lichen/add_embedding_look_up_op
mindspore-ci-bot 2020-07-14 09:59:01 +08:00 committed by Gitee
commit bfc3065fc7
9 changed files with 115 additions and 110 deletions

View File

@@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
 REGISTER(SigmoidCrossEntropyWithLogitsInfo);
 REGISTER(SquareInfo);
 REGISTER(GatherV2PInfo);
+REGISTER(EmbeddingLookupInfo);
 }  // namespace parallel
 }  // namespace mindspore
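Registering EmbeddingLookupInfo in the operator-info factory is what makes the auto-parallel pass aware of the new primitive, so a shard strategy can be attached to P.EmbeddingLookup the same way as for GatherV2. A minimal sketch of what this enables on the Python side, assuming a semi_auto_parallel context like the one used in the new tests further below:

```python
# Sketch only: with EmbeddingLookupInfo registered, P.EmbeddingLookup accepts a shard
# strategy in semi-auto-parallel mode (strategy values taken from the new tests below).
from mindspore import context
from mindspore.ops import operations as P

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel")
# ((params sharding), (indices sharding))
embedding_lookup = P.EmbeddingLookup().set_strategy(((8, 1), (1, 1)))
```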

View File

@@ -28,24 +28,25 @@
 namespace mindspore {
 namespace parallel {
 Status GatherV2PInfo::GetAttrs() {
-  // get axis, the third input is the axis, is a ValueNode
-  if (input_value_.at(2) == nullptr) {
-    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
-    return FAILED;
+  // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
+  if (target_ != CPU) {
+    if (input_value_.at(2) == nullptr) {
+      MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
+      return FAILED;
+    }
+    auto axis = GetValue<int>(input_value_.at(2));
+    // if axis is negative then convert it to positive
+    auto params_shape = inputs_shape_.at(0);
+    if (params_shape.size() == 0) {
+      MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
+      return FAILED;
+    }
+    if (axis < 0) {
+      axis += SizeToInt(inputs_shape_[0].size());
+    }
+    axis_ = axis;
   }
-  auto axis = GetValue<int>(input_value_.at(2));
-  // if axis is negative then convert it to positive
-  auto params_shape = inputs_shape_.at(0);
-  if (params_shape.size() == 0) {
-    MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
-    return FAILED;
-  }
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  axis_ = axis;
+  // get target
   auto target_iter = attrs_.find(TARGET);
   if (target_iter != attrs_.end()) {
     MS_EXCEPTION_IF_NULL(target_iter->second);
@@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
       MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
-      return FAILED;
     }
   }
-  // target=CPU, axis must be 0
-  if (target_ == "CPU" && axis_ != 0) {
-    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
-    return FAILED;
-  }
   auto manual_split_iter = attrs_.find("manual_split");
   if (manual_split_iter != attrs_.end()) {
     param_split_shapes_.clear();
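The axis handling is now skipped when the target is CPU because, unlike GatherV2, the EmbeddingLookup primitive does not carry an axis: its third input is an offset into the parameter table. A hedged illustration of the two call shapes (tensor values are placeholders):

```python
# Sketch of the two primitive signatures this branch distinguishes between.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

params = Tensor(np.ones((4, 2)), ms.float32)
indices = Tensor(np.array([1, 3]), ms.int32)

gathered = P.GatherV2()(params, indices, 0)          # third input: axis
looked_up = P.EmbeddingLookup()(params, indices, 0)  # third input: offset, no axis attribute
```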
@@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
     MS_LOG(ERROR) << name_ << ": Infer Group failed.";
     return FAILED;
   }
-  auto group_size = group_.GetDevNum();
   Attr attr_group;
-  if (host_reduce_scatter_) {
-    // group size <= 8
-    std::vector<int32_t> rank_list;
-    if (group_size <= 8) {
-      reduce_scatter_flag_ = false;
-      operator_name = HOST_REDUCE_SCATTER;
-      rank_list = GetRankFromGroup(group_);
-      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
-    } else {
-      // group size > 8, don't support host reduce_scatter
-      reduce_scatter_flag_ = true;
-      split_num_ = SizeToInt(group_size / 8);
-      CheckGlobalDeviceManager();
-      operator_name = REDUCE_SCATTER;
-      int32_t rank = g_device_manager->global_rank();
-      size_t repeat = group_size / 8;
-      for (size_t i = 0; i < repeat; ++i) {
-        rank_list.push_back(rank + SizeToInt(i * 8));
-      }
-      Group g = g_device_manager->CreateGroup(rank_list);
-      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
-    }
-  } else {
-    operator_name = REDUCE_SCATTER;
-    if (InferGroup() != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
-      return FAILED;
-    }
-    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
+  operator_name = REDUCE_SCATTER;
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+    return FAILED;
   }
+  attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
   OperatorParams params;
@@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   OperatorName op_name = EMBEDDING_LOOKUP;
   OperatorAttrs attrs;
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
-  Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
-  Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
-                           std::make_pair(param_split_num, 5)};
+  OperatorParams params = {std::make_pair(param_offset, 3)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);
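With the host reduce-scatter branch removed, the replacement operator built here is simply EmbeddingLookup with the rank's parameter offset appended as the third input (param position 3). Roughly, the rewritten lookup behaves like the sketch below, where `offset` stands for this rank's bias_; this is an illustration, not the generated graph:

```python
from mindspore.ops import operations as P

def replaced_lookup(param_slice, indices, offset):
    # offset is the first row held by this rank's parameter slice (bias_ in ComputeReplaceOp);
    # indices outside the local slice yield zero rows, which the ReduceScatter with
    # REDUCE_OP_SUM inserted by InferForwardCommunication then sums away.
    return P.EmbeddingLookup()(param_slice, indices, offset)
```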

View File

@@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
   Status InferGroup();
   int32_t axis_;
-  std::string target_;
+  std::string target_ = DEVICE;
   std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
   int32_t index_offset_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
-  bool reduce_scatter_flag_ = false;
-  int32_t split_num_ = 1;
-  bool host_reduce_scatter_ = false;
   bool manual_split_ = false;
   std::vector<int32_t> param_split_shapes_;
   std::vector<int32_t> index_offsets_;
@@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
  private:
   std::string replace_op_name_ = SPARSE_GATHERV2;
 };
+
+class EmbeddingLookupInfo : public GatherV2PInfo {
+ public:
+  EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~EmbeddingLookupInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_

View File

@@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
 constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
+constexpr char DEVICE[] = "Device";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

View File

@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
-  if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
+  if (prim->name() == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {

View File

@@ -105,3 +105,49 @@ class Embedding(Cell):
                                              self.embedding_table,
                                              self.dtype)
         return s
+
+
+class EmbeddingLookup(Cell):
+    r"""
+    Returns a slice of input tensor based on the specified indices.
+
+    Note:
+        When 'target' is set to 'CPU', this module will use
+        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
+        specified 'offset = 0' to lookup table.
+        When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+        specified 'axis = 0' to lookup table.
+
+    Args:
+        target (str): Specify the target where the op is executed. Default: 'CPU'.
+
+    Inputs:
+        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+          The Tensor slice, instead of the entire Tensor.
+        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of
+          `input_params`, and the exceeding part will be filled with 0 in the output.
+
+    Outputs:
+        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
+        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
+        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
+        [[[10, 11], [8, 9]], [[14, 15], [12, 13]]]
+    """
+    def __init__(self, target='CPU'):
+        super(EmbeddingLookup, self).__init__()
+        self.target = target
+        if target not in ('CPU', 'DEVICE'):
+            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
+        self.gatherv2 = P.GatherV2()
+        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+
+    def construct(self, params, indices):
+        if self.target == "CPU":
+            out = self.embeddinglookup(params, indices, 0)
+        else:
+            out = self.gatherv2(params, indices, 0)
+        return out
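A self-contained version of the docstring example, with imports spelled out; target='DEVICE' routes the lookup through P.GatherV2, while the default 'CPU' routes it through P.EmbeddingLookup with offset 0:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), ms.float32)
input_indices = Tensor(np.array([[1, 0], [3, 2]]), ms.int32)

out = nn.EmbeddingLookup(target='DEVICE')(input_params, input_indices)
# Expected (per the docstring): [[[10, 11], [8, 9]], [[14, 15], [12, 13]]]
```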

View File

@@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
                                           self.deep_layer_act,
                                           use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
-        self.gather_v2 = P.GatherV2()
+        self.embeddinglookup = nn.EmbeddingLookup()
         self.mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
@@ -206,11 +206,11 @@
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
+        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0)
         wx = self.mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
+        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0)
         vx = self.mul(deep_id_embs, mask)
         deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
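Note that nn.EmbeddingLookup.construct takes only (params, indices), so the trailing axis argument of the old P.GatherV2 calls is not required when migrating a model to the new cell. A hypothetical minimal cell illustrating the swap:

```python
import mindspore.nn as nn

class WideEmbedding(nn.Cell):
    """Hypothetical cell illustrating the P.GatherV2 -> nn.EmbeddingLookup migration."""
    def __init__(self, target='DEVICE'):
        super(WideEmbedding, self).__init__()
        # was: self.gather_v2 = P.GatherV2()
        self.embeddinglookup = nn.EmbeddingLookup(target=target)

    def construct(self, table, ids):
        # was: self.gather_v2(table, ids, 0); axis/offset handling lives inside the cell now
        return self.embeddinglookup(table, ids)
```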

View File

@@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
         return self.loss(predict)
 
 
 class Net(nn.Cell):
-    def __init__(self, shape, offset):
+    def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
         super().__init__()
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.offset = offset
-        self.elu = P.EmbeddingLookup()
-        self.mm = P.BatchMatMul()
+        self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
+        self.mm = P.BatchMatMul().set_strategy(strategy2)
 
     def construct(self, x, y):
         out = self.elu(x, self.index, self.offset)
@@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_embeddinglookup_semi_auto2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    shape = [64, 32]
+    offset = 0
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 1, 2), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
+
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)

View File

@@ -13,8 +13,6 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-import pytest
-
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
@@ -183,42 +181,3 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu0():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu1():
-    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((16, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)
-
-
-@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
-def test_gatherv2_cpu2():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y)