forked from mindspore-Ecosystem/mindspore
!3025 [AutoParallel]Add embedding look up op
Merge pull request !3025 from lichen/add_embedding_look_up_op
This commit is contained in:
commit
bfc3065fc7
|
@ -132,6 +132,7 @@ REGISTER(SqueezeInfo);
|
||||||
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
||||||
REGISTER(SquareInfo);
|
REGISTER(SquareInfo);
|
||||||
REGISTER(GatherV2PInfo);
|
REGISTER(GatherV2PInfo);
|
||||||
|
REGISTER(EmbeddingLookupInfo);
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -28,24 +28,25 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
Status GatherV2PInfo::GetAttrs() {
|
Status GatherV2PInfo::GetAttrs() {
|
||||||
// get axis, the third input is the axis, is a ValueNode
|
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||||
if (input_value_.at(2) == nullptr) {
|
if (target_ != CPU) {
|
||||||
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
if (input_value_.at(2) == nullptr) {
|
||||||
return FAILED;
|
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto axis = GetValue<int>(input_value_.at(2));
|
||||||
|
// if axis is negative then convert it to positive
|
||||||
|
auto params_shape = inputs_shape_.at(0);
|
||||||
|
if (params_shape.size() == 0) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += SizeToInt(inputs_shape_[0].size());
|
||||||
|
}
|
||||||
|
axis_ = axis;
|
||||||
}
|
}
|
||||||
auto axis = GetValue<int>(input_value_.at(2));
|
|
||||||
// if axis is negative then convert it to positive
|
|
||||||
auto params_shape = inputs_shape_.at(0);
|
|
||||||
if (params_shape.size() == 0) {
|
|
||||||
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (axis < 0) {
|
|
||||||
axis += SizeToInt(inputs_shape_[0].size());
|
|
||||||
}
|
|
||||||
axis_ = axis;
|
|
||||||
|
|
||||||
// get target
|
|
||||||
auto target_iter = attrs_.find(TARGET);
|
auto target_iter = attrs_.find(TARGET);
|
||||||
if (target_iter != attrs_.end()) {
|
if (target_iter != attrs_.end()) {
|
||||||
MS_EXCEPTION_IF_NULL(target_iter->second);
|
MS_EXCEPTION_IF_NULL(target_iter->second);
|
||||||
|
@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
|
||||||
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
||||||
return FAILED;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// target=CPU, axis must be 0
|
|
||||||
if (target_ == "CPU" && axis_ != 0) {
|
|
||||||
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto manual_split_iter = attrs_.find("manual_split");
|
auto manual_split_iter = attrs_.find("manual_split");
|
||||||
if (manual_split_iter != attrs_.end()) {
|
if (manual_split_iter != attrs_.end()) {
|
||||||
param_split_shapes_.clear();
|
param_split_shapes_.clear();
|
||||||
|
@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
||||||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
auto group_size = group_.GetDevNum();
|
|
||||||
Attr attr_group;
|
Attr attr_group;
|
||||||
if (host_reduce_scatter_) {
|
operator_name = REDUCE_SCATTER;
|
||||||
// group size <= 8
|
if (InferGroup() != SUCCESS) {
|
||||||
std::vector<int32_t> rank_list;
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||||
if (group_size <= 8) {
|
return FAILED;
|
||||||
reduce_scatter_flag_ = false;
|
|
||||||
operator_name = HOST_REDUCE_SCATTER;
|
|
||||||
rank_list = GetRankFromGroup(group_);
|
|
||||||
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
|
|
||||||
} else {
|
|
||||||
// group size > 8, don't support host reduce_scatter
|
|
||||||
reduce_scatter_flag_ = true;
|
|
||||||
split_num_ = SizeToInt(group_size / 8);
|
|
||||||
CheckGlobalDeviceManager();
|
|
||||||
operator_name = REDUCE_SCATTER;
|
|
||||||
int32_t rank = g_device_manager->global_rank();
|
|
||||||
size_t repeat = group_size / 8;
|
|
||||||
for (size_t i = 0; i < repeat; ++i) {
|
|
||||||
rank_list.push_back(rank + SizeToInt(i * 8));
|
|
||||||
}
|
|
||||||
Group g = g_device_manager->CreateGroup(rank_list);
|
|
||||||
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
operator_name = REDUCE_SCATTER;
|
|
||||||
if (InferGroup() != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
|
||||||
}
|
}
|
||||||
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
||||||
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
||||||
OperatorAttrs attrs = {attr_op, attr_group};
|
OperatorAttrs attrs = {attr_op, attr_group};
|
||||||
OperatorParams params;
|
OperatorParams params;
|
||||||
|
@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
||||||
OperatorName op_name = EMBEDDING_LOOKUP;
|
OperatorName op_name = EMBEDDING_LOOKUP;
|
||||||
OperatorAttrs attrs;
|
OperatorAttrs attrs;
|
||||||
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
||||||
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
|
OperatorParams params = {std::make_pair(param_offset, 3)};
|
||||||
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
|
|
||||||
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
|
|
||||||
std::make_pair(param_split_num, 5)};
|
|
||||||
OperatorArgs args = std::make_pair(attrs, params);
|
OperatorArgs args = std::make_pair(attrs, params);
|
||||||
Operator op = std::make_pair(op_name, args);
|
Operator op = std::make_pair(op_name, args);
|
||||||
replace_op_.push_back(op);
|
replace_op_.push_back(op);
|
||||||
|
|
|
@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo {
|
||||||
Status InferGroup();
|
Status InferGroup();
|
||||||
|
|
||||||
int32_t axis_;
|
int32_t axis_;
|
||||||
std::string target_;
|
std::string target_ = DEVICE;
|
||||||
std::string replace_op_name_ = GATHERV2;
|
std::string replace_op_name_ = GATHERV2;
|
||||||
int32_t bias_;
|
int32_t bias_;
|
||||||
int32_t index_offset_;
|
int32_t index_offset_;
|
||||||
int32_t slice_size_;
|
int32_t slice_size_;
|
||||||
Shape out_dev_matrix_shape_;
|
Shape out_dev_matrix_shape_;
|
||||||
Group group_;
|
Group group_;
|
||||||
bool reduce_scatter_flag_ = false;
|
|
||||||
int32_t split_num_ = 1;
|
|
||||||
bool host_reduce_scatter_ = false;
|
|
||||||
bool manual_split_ = false;
|
bool manual_split_ = false;
|
||||||
std::vector<int32_t> param_split_shapes_;
|
std::vector<int32_t> param_split_shapes_;
|
||||||
std::vector<int32_t> index_offsets_;
|
std::vector<int32_t> index_offsets_;
|
||||||
|
@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo {
|
||||||
private:
|
private:
|
||||||
std::string replace_op_name_ = SPARSE_GATHERV2;
|
std::string replace_op_name_ = SPARSE_GATHERV2;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class EmbeddingLookupInfo : public GatherV2PInfo {
|
||||||
|
public:
|
||||||
|
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
const PrimitiveAttrs &attrs)
|
||||||
|
: GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||||
|
~EmbeddingLookupInfo() override = default;
|
||||||
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||||
|
|
|
@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
||||||
constexpr char DARA_PARALLEL[] = "data_parallel";
|
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||||
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
||||||
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
||||||
|
constexpr char DEVICE[] = "Device";
|
||||||
|
|
||||||
// Operator
|
// Operator
|
||||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||||
|
|
|
@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
|
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
|
||||||
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
|
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
|
||||||
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
|
if (prim->name() == EMBEDDING_LOOKUP) {
|
||||||
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
|
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
|
||||||
}
|
}
|
||||||
if (!params.empty()) {
|
if (!params.empty()) {
|
||||||
|
|
|
@ -105,3 +105,49 @@ class Embedding(Cell):
|
||||||
self.embedding_table,
|
self.embedding_table,
|
||||||
self.dtype)
|
self.dtype)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
class EmbeddingLookup(Cell):
|
||||||
|
r"""
|
||||||
|
Returns a slice of input tensor based on the specified indices.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When 'target' is set to 'CPU', this module will use
|
||||||
|
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||||
|
specified 'offset = 0' to lookup table.
|
||||||
|
when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
|
||||||
|
specified 'axis = 0' to lookup table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (str): Specify the target where the op is executed. Default: 'CPU'.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
The Tensor slice, instead of the entire Tensor.
|
||||||
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||||
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
||||||
|
and the exceeding part will be filled with 0 in the output.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
||||||
|
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
|
||||||
|
>>> out = nn.EmbeddingLookup()(input_params, input_indices)
|
||||||
|
[[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
|
||||||
|
"""
|
||||||
|
def __init__(self, target='CPU'):
|
||||||
|
super(EmbeddingLookup, self).__init__()
|
||||||
|
self.target = target
|
||||||
|
if target not in ('CPU', 'DEVICE'):
|
||||||
|
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
||||||
|
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
|
||||||
|
self.gatherv2 = P.GatherV2()
|
||||||
|
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||||
|
|
||||||
|
def construct(self, params, indices):
|
||||||
|
if self.target == "CPU":
|
||||||
|
out = self.embeddinglookup(params, ids, 0)
|
||||||
|
else:
|
||||||
|
out = self.gatherv2(param, ids, 0)
|
||||||
|
return out
|
||||||
|
|
|
@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell):
|
||||||
self.deep_layer_act,
|
self.deep_layer_act,
|
||||||
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
||||||
|
|
||||||
self.gather_v2 = P.GatherV2()
|
self.embeddinglookup = nn.EmbeddingLookup()
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
|
@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell):
|
||||||
"""
|
"""
|
||||||
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||||
# Wide layer
|
# Wide layer
|
||||||
wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
|
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0)
|
||||||
wx = self.mul(wide_id_weight, mask)
|
wx = self.mul(wide_id_weight, mask)
|
||||||
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
||||||
# Deep layer
|
# Deep layer
|
||||||
deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
|
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0)
|
||||||
vx = self.mul(deep_id_embs, mask)
|
vx = self.mul(deep_id_embs, mask)
|
||||||
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
|
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||||
deep_in = self.dense_layer_1(deep_in)
|
deep_in = self.dense_layer_1(deep_in)
|
||||||
|
|
|
@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell):
|
||||||
return self.loss(predict)
|
return self.loss(predict)
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, shape, offset):
|
def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.elu = P.EmbeddingLookup()
|
self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target)
|
||||||
self.mm = P.BatchMatMul()
|
self.mm = P.BatchMatMul().set_strategy(strategy2)
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.elu(x, self.index, self.offset)
|
out = self.elu(x, self.index, self.offset)
|
||||||
|
@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad():
|
||||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddinglookup_semi_auto1():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
shape = [64, 32]
|
||||||
|
offset = 0
|
||||||
|
strategy1 = ((8, 1), (1, 1))
|
||||||
|
strategy2 = ((4, 1, 2), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
|
||||||
|
|
||||||
|
net.set_auto_parallel()
|
||||||
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddinglookup_semi_auto2():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
shape = [64, 32]
|
||||||
|
offset = 0
|
||||||
|
strategy1 = ((1, 8), (1, 1))
|
||||||
|
strategy2 = ((4, 1, 2), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
|
||||||
|
|
||||||
|
net.set_auto_parallel()
|
||||||
|
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
|
@ -13,8 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -183,42 +181,3 @@ def test_gatherv2_auto1():
|
||||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
|
||||||
def test_gatherv2_cpu0():
|
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
||||||
strategy1 = ((8, 1), (1, 1))
|
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
||||||
net.set_auto_parallel()
|
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
||||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
|
||||||
_executor.compile(net, x, y)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
|
||||||
def test_gatherv2_cpu1():
|
|
||||||
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
||||||
strategy1 = ((16, 1), (1, 1))
|
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
||||||
net.set_auto_parallel()
|
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
||||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
|
||||||
_executor.compile(net, x, y)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen")
|
|
||||||
def test_gatherv2_cpu2():
|
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
|
||||||
strategy1 = ((1, 8), (1, 1))
|
|
||||||
strategy2 = ((4, 2, 1), (4, 2, 1))
|
|
||||||
net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
|
|
||||||
net.set_auto_parallel()
|
|
||||||
|
|
||||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
|
||||||
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
|
||||||
_executor.compile(net, x, y)
|
|
||||||
|
|
Loading…
Reference in New Issue