fix grpah mode loop sink bug in auto parallel
This commit is contained in:
parent
976226f9ac
commit
f946aea10d
|
@ -34,8 +34,8 @@ namespace parallel {
|
|||
#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
|
||||
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
|
||||
#define DEFAULT_COST_MODEL_ALPHA 1.0
|
||||
#define DEFAULT_COST_MODEL_BETA 65.0
|
||||
#define DEFAULT_COST_MODEL_GAMMA 0.02
|
||||
#define DEFAULT_COST_MODEL_BETA 260.0
|
||||
#define DEFAULT_COST_MODEL_GAMMA 0.001
|
||||
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
|
||||
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
|
||||
#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0
|
||||
|
|
|
@ -375,6 +375,10 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
|
|||
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
||||
return false;
|
||||
}
|
||||
// get_next is not in the forward graph, we need mark the get_next as the forward node
|
||||
if (prim->name() == GET_NEXT) {
|
||||
return true;
|
||||
}
|
||||
if ((prim->name() == CAST)) {
|
||||
if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) {
|
||||
return false;
|
||||
|
|
|
@ -88,7 +88,7 @@ class _DatasetIter:
|
|||
# times the batch dimension of tensors for run
|
||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
device_num = _get_device_num()
|
||||
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
|
||||
def __iter__(self):
|
||||
self.ind = 0
|
||||
|
|
|
@ -80,9 +80,9 @@ def test_common_parameter():
|
|||
|
||||
_executor.compile(net, x, y, z, w, phase='train')
|
||||
strategies = _executor._get_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/MatMul-op9': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/Cast-op10': [[1, 8]],
|
||||
'Default/network-Net/MatMul-op0': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/Cast-op11': [[1, 8]]}
|
||||
assert strategies == expected_strategies
|
||||
expected_strategies = {'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op8': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/Cast-op7': [[1, 1]],
|
||||
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/Cast-op9': [[1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
|
@ -86,9 +86,9 @@ def test_two_matmul():
|
|||
costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
|
||||
assert costmodel_alpha == 1.0
|
||||
costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
|
||||
assert costmodel_beta == 65.0
|
||||
assert costmodel_beta == 260.0
|
||||
costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
|
||||
assert costmodel_gamma == 0.02
|
||||
assert costmodel_gamma == 0.001
|
||||
costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")
|
||||
assert costmodel_communi_threshold == 2048.0
|
||||
costmodel_communi_const = cost_model_context.get_cost_model_context("costmodel_communi_const")
|
||||
|
@ -137,4 +137,5 @@ def test_two_matmul():
|
|||
strategies = _executor._get_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
assert strategies == expected_strategies
|
||||
|
||||
|
|
Loading…
Reference in New Issue