!3025 [AutoParallel]Add embedding look up op

Merge pull request !3025 from lichen/add_embedding_look_up_op
This commit is contained in:
mindspore-ci-bot 2020-07-14 09:59:01 +08:00 committed by Gitee
commit bfc3065fc7
9 changed files with 115 additions and 110 deletions

View File

@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
REGISTER(SquareInfo);
REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -28,24 +28,25 @@
namespace mindspore {
namespace parallel {
Status GatherV2PInfo::GetAttrs() {
// get axis, the third input is the axis, is a ValueNode
if (input_value_.at(2) == nullptr) {
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
return FAILED;
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
if (target_ != CPU) {
if (input_value_.at(2) == nullptr) {
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
return FAILED;
}
auto axis = GetValue<int>(input_value_.at(2));
// if axis is negative then convert it to positive
auto params_shape = inputs_shape_.at(0);
if (params_shape.size() == 0) {
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
return FAILED;
}
if (axis < 0) {
axis += SizeToInt(inputs_shape_[0].size());
}
axis_ = axis;
}
auto axis = GetValue<int>(input_value_.at(2));
// if axis is negative then convert it to positive
auto params_shape = inputs_shape_.at(0);
if (params_shape.size() == 0) {
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
return FAILED;
}
if (axis < 0) {
axis += SizeToInt(inputs_shape_[0].size());
}
axis_ = axis;
// get target
auto target_iter = attrs_.find(TARGET);
if (target_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(target_iter->second);
@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
target_ = target_iter->second->cast<StringImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
return FAILED;
}
}
// target=CPU, axis must be 0
if (target_ == "CPU" && axis_ != 0) {
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
return FAILED;
}
auto manual_split_iter = attrs_.find("manual_split");
if (manual_split_iter != attrs_.end()) {
param_split_shapes_.clear();
@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED;
}
auto group_size = group_.GetDevNum();
Attr attr_group;
if (host_reduce_scatter_) {
// group size <= 8
std::vector<int32_t> rank_list;
if (group_size <= 8) {
reduce_scatter_flag_ = false;
operator_name = HOST_REDUCE_SCATTER;
rank_list = GetRankFromGroup(group_);
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
} else {
// group size > 8, don't support host reduce_scatter
reduce_scatter_flag_ = true;
split_num_ = SizeToInt(group_size / 8);
CheckGlobalDeviceManager();
operator_name = REDUCE_SCATTER;
int32_t rank = g_device_manager->global_rank();
size_t repeat = group_size / 8;
for (size_t i = 0; i < repeat; ++i) {
rank_list.push_back(rank + SizeToInt(i * 8));
}
Group g = g_device_manager->CreateGroup(rank_list);
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
}
} else {
operator_name = REDUCE_SCATTER;
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED;
}
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
operator_name = REDUCE_SCATTER;
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED;
}
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
OperatorAttrs attrs = {attr_op, attr_group};
OperatorParams params;
@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
OperatorName op_name = EMBEDDING_LOOKUP;
OperatorAttrs attrs;
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
std::make_pair(param_split_num, 5)};
OperatorParams params = {std::make_pair(param_offset, 3)};
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(op_name, args);
replace_op_.push_back(op);

View File

@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
Status InferGroup();
int32_t axis_;
std::string target_;
std::string target_ = DEVICE;
std::string replace_op_name_ = GATHERV2;
int32_t bias_;
int32_t index_offset_;
int32_t slice_size_;
Shape out_dev_matrix_shape_;
Group group_;
bool reduce_scatter_flag_ = false;
int32_t split_num_ = 1;
bool host_reduce_scatter_ = false;
bool manual_split_ = false;
std::vector<int32_t> param_split_shapes_;
std::vector<int32_t> index_offsets_;
@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
private:
std::string replace_op_name_ = SPARSE_GATHERV2;
};
class EmbeddingLookupInfo : public GatherV2PInfo {
public:
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
~EmbeddingLookupInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_

View File

@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
constexpr char DARA_PARALLEL[] = "data_parallel";
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
constexpr char DEVICE[] = "Device";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

View File

@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
if (prim->name() == EMBEDDING_LOOKUP) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
}
if (!params.empty()) {

View File

@ -105,3 +105,49 @@ class Embedding(Cell):
self.embedding_table,
self.dtype)
return s
class EmbeddingLookup(Cell):
r"""
Returns a slice of input tensor based on the specified indices.
Note:
When 'target' is set to 'CPU', this module will use
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
specified 'offset = 0' to lookup table.
when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
specified 'axis = 0' to lookup table.
Args:
target (str): Specify the target where the op is executed. Default: 'CPU'.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The Tensor slice, instead of the entire Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
and the exceeding part will be filled with 0 in the output.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Examples:
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup()(input_params, input_indices)
[[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
"""
def __init__(self, target='CPU'):
super(EmbeddingLookup, self).__init__()
self.target = target
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
def construct(self, params, indices):
if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0)
else:
out = self.gatherv2(param, ids, 0)
return out

View File

@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
self.deep_layer_act,
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
self.gather_v2 = P.GatherV2()
self.embeddinglookup = nn.EmbeddingLookup()
self.mul = P.Mul()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reshape = P.Reshape()
@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0)
wx = self.mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer
deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0)
vx = self.mul(deep_id_embs, mask)
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)

View File

@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
return self.loss(predict)
class Net(nn.Cell):
def __init__(self, shape, offset):
def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
super().__init__()
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.offset = offset
self.elu = P.EmbeddingLookup()
self.mm = P.BatchMatMul()
self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
self.mm = P.BatchMatMul().set_strategy(strategy2)
def construct(self, x, y):
out = self.elu(x, self.index, self.offset)
@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_embeddinglookup_semi_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]
offset = 0
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 1, 2), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_embeddinglookup_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]
offset = 0
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 1, 2), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)

View File

@ -13,8 +13,6 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
@ -183,42 +181,3 @@ def test_gatherv2_auto1():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)