forked from mindspore-Ecosystem/mindspore
embedding lookup auto parallel
This commit is contained in:
parent
28755b2f1a
commit
6f6a8ae9f0
|
@ -264,7 +264,8 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
|
||||
EMBEDDING_LOOKUP};
|
||||
// clang-format on
|
||||
|
||||
auto iter = splittable_op.find(op_name);
|
||||
|
|
|
@ -115,6 +115,13 @@ def test_auto_parallel_error():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_auto_parallel():
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
|
||||
net = Net(split_string="fake")
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_axis_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
|
|
Loading…
Reference in New Issue