!3025 [AutoParallel]Add embedding look up op
Merge pull request !3025 from lichen/add_embedding_look_up_op
This commit is contained in:
commit
bfc3065fc7
|
@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
|
|||
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
||||
REGISTER(SquareInfo);
|
||||
REGISTER(GatherV2PInfo);
|
||||
REGISTER(EmbeddingLookupInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -28,24 +28,25 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherV2PInfo::GetAttrs() {
|
||||
// get axis, the third input is the axis, is a ValueNode
|
||||
if (input_value_.at(2) == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
||||
return FAILED;
|
||||
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||
if (target_ != CPU) {
|
||||
if (input_value_.at(2) == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
||||
return FAILED;
|
||||
}
|
||||
auto axis = GetValue<int>(input_value_.at(2));
|
||||
// if axis is negative then convert it to positive
|
||||
auto params_shape = inputs_shape_.at(0);
|
||||
if (params_shape.size() == 0) {
|
||||
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
||||
return FAILED;
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += SizeToInt(inputs_shape_[0].size());
|
||||
}
|
||||
axis_ = axis;
|
||||
}
|
||||
auto axis = GetValue<int>(input_value_.at(2));
|
||||
// if axis is negative then convert it to positive
|
||||
auto params_shape = inputs_shape_.at(0);
|
||||
if (params_shape.size() == 0) {
|
||||
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
||||
return FAILED;
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += SizeToInt(inputs_shape_[0].size());
|
||||
}
|
||||
axis_ = axis;
|
||||
|
||||
// get target
|
||||
auto target_iter = attrs_.find(TARGET);
|
||||
if (target_iter != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(target_iter->second);
|
||||
|
@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
|
|||
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// target=CPU, axis must be 0
|
||||
if (target_ == "CPU" && axis_ != 0) {
|
||||
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto manual_split_iter = attrs_.find("manual_split");
|
||||
if (manual_split_iter != attrs_.end()) {
|
||||
param_split_shapes_.clear();
|
||||
|
@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||
return FAILED;
|
||||
}
|
||||
auto group_size = group_.GetDevNum();
|
||||
Attr attr_group;
|
||||
if (host_reduce_scatter_) {
|
||||
// group size <= 8
|
||||
std::vector<int32_t> rank_list;
|
||||
if (group_size <= 8) {
|
||||
reduce_scatter_flag_ = false;
|
||||
operator_name = HOST_REDUCE_SCATTER;
|
||||
rank_list = GetRankFromGroup(group_);
|
||||
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
|
||||
} else {
|
||||
// group size > 8, don't support host reduce_scatter
|
||||
reduce_scatter_flag_ = true;
|
||||
split_num_ = SizeToInt(group_size / 8);
|
||||
CheckGlobalDeviceManager();
|
||||
operator_name = REDUCE_SCATTER;
|
||||
int32_t rank = g_device_manager->global_rank();
|
||||
size_t repeat = group_size / 8;
|
||||
for (size_t i = 0; i < repeat; ++i) {
|
||||
rank_list.push_back(rank + SizeToInt(i * 8));
|
||||
}
|
||||
Group g = g_device_manager->CreateGroup(rank_list);
|
||||
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
|
||||
}
|
||||
} else {
|
||||
operator_name = REDUCE_SCATTER;
|
||||
if (InferGroup() != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||
return FAILED;
|
||||
}
|
||||
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
||||
operator_name = REDUCE_SCATTER;
|
||||
if (InferGroup() != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||
return FAILED;
|
||||
}
|
||||
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
||||
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
||||
OperatorAttrs attrs = {attr_op, attr_group};
|
||||
OperatorParams params;
|
||||
|
@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
|||
OperatorName op_name = EMBEDDING_LOOKUP;
|
||||
OperatorAttrs attrs;
|
||||
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
||||
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
|
||||
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
|
||||
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
|
||||
std::make_pair(param_split_num, 5)};
|
||||
OperatorParams params = {std::make_pair(param_offset, 3)};
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
Operator op = std::make_pair(op_name, args);
|
||||
replace_op_.push_back(op);
|
||||
|
|
|
@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
|
|||
Status InferGroup();
|
||||
|
||||
int32_t axis_;
|
||||
std::string target_;
|
||||
std::string target_ = DEVICE;
|
||||
std::string replace_op_name_ = GATHERV2;
|
||||
int32_t bias_;
|
||||
int32_t index_offset_;
|
||||
int32_t slice_size_;
|
||||
Shape out_dev_matrix_shape_;
|
||||
Group group_;
|
||||
bool reduce_scatter_flag_ = false;
|
||||
int32_t split_num_ = 1;
|
||||
bool host_reduce_scatter_ = false;
|
||||
bool manual_split_ = false;
|
||||
std::vector<int32_t> param_split_shapes_;
|
||||
std::vector<int32_t> index_offsets_;
|
||||
|
@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
|
|||
private:
|
||||
std::string replace_op_name_ = SPARSE_GATHERV2;
|
||||
};
|
||||
|
||||
class EmbeddingLookupInfo : public GatherV2PInfo {
|
||||
public:
|
||||
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~EmbeddingLookupInfo() override = default;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||
|
|
|
@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
|||
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
||||
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
||||
constexpr char DEVICE[] = "Device";
|
||||
|
||||
// Operator
|
||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||
|
|
|
@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
|
|||
}
|
||||
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
|
||||
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
|
||||
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
|
||||
if (prim->name() == EMBEDDING_LOOKUP) {
|
||||
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
|
||||
}
|
||||
if (!params.empty()) {
|
||||
|
|
|
@ -105,3 +105,49 @@ class Embedding(Cell):
|
|||
self.embedding_table,
|
||||
self.dtype)
|
||||
return s
|
||||
|
||||
class EmbeddingLookup(Cell):
|
||||
r"""
|
||||
Returns a slice of input tensor based on the specified indices.
|
||||
|
||||
Note:
|
||||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
|
||||
Args:
|
||||
target (str): Specify the target where the op is executed. Default: 'CPU'.
|
||||
|
||||
Inputs:
|
||||
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
The Tensor slice, instead of the entire Tensor.
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
||||
and the exceeding part will be filled with 0 in the output.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
||||
Examples:
|
||||
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
|
||||
>>> out = nn.EmbeddingLookup()(input_params, input_indices)
|
||||
[[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
|
||||
"""
|
||||
def __init__(self, target='CPU'):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.target = target
|
||||
if target not in ('CPU', 'DEVICE'):
|
||||
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
||||
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def construct(self, params, indices):
|
||||
if self.target == "CPU":
|
||||
out = self.embeddinglookup(params, ids, 0)
|
||||
else:
|
||||
out = self.gatherv2(param, ids, 0)
|
||||
return out
|
||||
|
|
|
@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
|
|||
self.deep_layer_act,
|
||||
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
||||
|
||||
self.gather_v2 = P.GatherV2()
|
||||
self.embeddinglookup = nn.EmbeddingLookup()
|
||||
self.mul = P.Mul()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.reshape = P.Reshape()
|
||||
|
@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
|
|||
"""
|
||||
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||
# Wide layer
|
||||
wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
|
||||
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0)
|
||||
wx = self.mul(wide_id_weight, mask)
|
||||
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
||||
# Deep layer
|
||||
deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
|
||||
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0)
|
||||
vx = self.mul(deep_id_embs, mask)
|
||||
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||
deep_in = self.dense_layer_1(deep_in)
|
||||
|
|
|
@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
|
|||
return self.loss(predict)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, shape, offset):
|
||||
def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
|
||||
super().__init__()
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.offset = offset
|
||||
self.elu = P.EmbeddingLookup()
|
||||
self.mm = P.BatchMatMul()
|
||||
self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
||||
self.mm = P.BatchMatMul().set_strategy(strategy2)
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.elu(x, self.index, self.offset)
|
||||
|
@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_embeddinglookup_semi_auto1():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
shape = [64, 32]
|
||||
offset = 0
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 1, 2), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
|
||||
|
||||
net.set_auto_parallel()
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_embeddinglookup_semi_auto2():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
shape = [64, 32]
|
||||
offset = 0
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 1, 2), (4, 2, 1))
|
||||
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
|
||||
|
||||
net.set_auto_parallel()
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
|
|
@ -13,8 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
@ -183,42 +181,3 @@ def test_gatherv2_auto1():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
||||
def test_gatherv2_cpu0():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
||||
def test_gatherv2_cpu1():
|
||||
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((16, 1), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
||||
def test_gatherv2_cpu2():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((1, 8), (1, 1))
|
||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
|
Loading…
Reference in New Issue