refactor get cnode strategy

Ziyan 2020-10-22 20:09:33 +08:00
parent b802ce7f8f
commit 069318899a
11 changed files with 75 additions and 69 deletions

View File

@@ -54,31 +54,6 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
return dict;
}
py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
py::dict dict;
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto nodes = DeepScopedGraphSearch(ret);
for (auto node : nodes) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto distributed_operation_info = cnode->user_data<OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {
auto strategy = strategyPtr->GetInputDim();
auto name = cnode->fullname_with_scope();
dict[py::str(name)] = strategy;
}
}
}
}
return dict;
}
py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
py::dict dict;
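Note: the GetCNodeStrategy removed above re-walked the whole graph (DeepScopedGraphSearch from the return node) on every Python-side query. After this commit, strategies are recorded once, while step_parallel is already visiting every node, and cached per compilation phase inside ExecutorPy. A minimal Python sketch of that caching pattern, with illustrative names standing in for the C++ members:

class StrategyCache:
    """Sketch of ExecutorPy's new bookkeeping; names are illustrative."""

    def __init__(self):
        self._stra_dict = {}  # phase -> {cnode fullname -> input-dim strategy}
        self._phase = ""

    def set_cnode_strategy(self, name, strategy):
        # Called once per forward CNode while step_parallel walks the graph.
        self._stra_dict.setdefault(self._phase, {})[name] = strategy

    def get_cnode_strategy(self, phase):
        # A dict lookup instead of a fresh graph traversal per query.
        return self._stra_dict.get(phase, {})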

View File

@@ -25,7 +25,6 @@ namespace py = pybind11;
namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph);
py::dict GetCNodeStrategy(const FuncGraphPtr &graph);
py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore

View File

@@ -1524,7 +1524,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
// Get global rank after the checkpoint?
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@@ -2478,17 +2477,32 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
InsertNode(op, node, 2, pre_node, root, "shape");
}
void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
// If the root graph has a Reshape op, find the corresponding parameter;
// the Reshape's shape is the shape of that parameter.
auto executor = pipeline::ExecutorPy::GetInstance();
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) {
if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
if (cnode->in_forward_flag()) {
// Save strategy in executor
OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
if (op_info) {
auto stra_ptr = op_info->strategy();
if (stra_ptr) {
auto strategy = stra_ptr->GetInputDim();
// The fullname_with_scope recorded here matches the node names in the step_parallel end IR.
executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
}
}
continue;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->name() != RESHAPE) {
continue;
@@ -2844,7 +2858,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
ReshapeInit(all_nodes);
}
HandleRootReshape(all_nodes);
HandleRootReshapeAndSaveStrategy(all_nodes);
HandleForwardMakeTupleAndMakeList(all_nodes);
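Note: the in_forward_flag branch added above piggybacks on a traversal that StepParallel already performs, so saving the strategies costs no extra graph walk. A runnable Python mirror of that branch; the node accessors (in_forward_flag, operator_info, and so on) are sketch names, not MindSpore's Python API:

def save_forward_strategies(all_nodes, executor):
    # Mirror of the C++ branch: record each forward CNode's strategy.
    for node in all_nodes:
        if not getattr(node, "in_forward_flag", False):
            continue
        op_info = getattr(node, "operator_info", None)  # sketch accessor
        if op_info is None or op_info.strategy is None:
            continue
        executor.set_cnode_strategy(node.fullname_with_scope,
                                    op_info.strategy.input_dim)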

View File

@@ -29,6 +29,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/pipeline.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;

View File

@@ -243,9 +243,12 @@ py::dict ExecutorPy::GetParameterLayout(const std::string &phase) {
py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
MS_LOG(DEBUG) << "GetCNodeStrategy!";
std::string layout_graph = phase + kStepParallelGraph;
auto graph = GetFuncGraph(layout_graph);
return mindspore::parallel::GetCNodeStrategy(graph);
return stra_dict_[phase];
}
void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
MS_LOG(DEBUG) << "SetCNodeStrategy!";
stra_dict_[phase_][py::str(name)] = strategy;
}
py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) {
@@ -449,6 +452,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
#endif
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
auto phase_s = py::cast<std::string>(phase);
phase_ = phase_s;
MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
ResourcePtr resource = std::make_shared<Resource>(obj);
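Note: CompileInner now remembers the current phase in phase_, so the SetCNodeStrategy calls made during that compilation land in the right per-phase entry of stra_dict_, and GetCNodeStrategy(phase) can return it directly. A usage sketch, assuming the internal _executor helper that the tests below import:

from mindspore.common.api import _executor  # internal API used by the tests

def dump_strategies(net):
    # net: an already-compiled cell; keys are each CNode's fullname_with_scope.
    strategies = _executor._get_shard_strategy(net)
    for name, stra in strategies.items():
        print(name, stra)  # e.g. '.../MatMul-op0' -> [[2, 1], [1, 4]]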

View File

@@ -92,6 +92,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
void RunInitGraph(const py::dict &init_params, const std::string &phase);
py::dict GetParameterLayout(const std::string &phase);
py::dict GetCNodeStrategy(const std::string &phase);
void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
py::dict GetAllreduceFusion(const std::string &phase);
void DelNetRes(const std::string &id);
void ReleaseResource(const py::object &phase);
@@ -114,6 +115,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
static std::shared_ptr<ExecutorPy> executor_;
static std::mutex instance_lock_;
static bool debugger_terminate_;
std::map<std::string, py::dict> stra_dict_;
std::string phase_ = "";
};
using ExecutorPyPtr = std::shared_ptr<ExecutorPy>;

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import mindspore as ms
@@ -96,15 +97,15 @@ def test_all_to_all():
_reset_op_id()
strategys = all_to_all_common(strategy1)
print(strategys)
expect_dict = {'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits'
'/SoftmaxCrossEntropyWithLogits-op3': [[8, 1], [8, 1]],
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/'
'OneHot-op4': [[8, 1], [], []],
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/Transpose-op1': [
[8, 1]],
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/MatMul-op0': [
[1, 1], [1, 8]]}
assert strategys == expect_dict
for (k, v) in strategys.items():
if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
assert v == [[8, 1], [8, 1]]
elif re.search('OneHot-op', k) is not None:
assert v == [[8, 1], [], []]
elif re.search('Transpose-op', k) is not None:
assert v == [[8, 1]]
elif re.search('MatMul-op', k) is not None:
assert v == [[1, 1], [1, 8]]
context.set_context(save_graphs=False)
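Note: the tests switch from exact-dict comparison to re.search because the trailing -opN suffix comes from a global op-id counter and shifts whenever op-creation order changes, as the renumbered expectations in the hunks below show. An equivalent approach, sketched here with a helper name of our own rather than anything from the test utilities, is to normalize the keys instead:

import re

def strip_op_ids(strategies):
    # Drop the auto-generated '-opN' suffix so expected dicts do not
    # depend on op-creation order. (Helper name is ours, not MindSpore's.)
    return {re.sub(r'-op\d+$', '', k): v for k, v in strategies.items()}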

View File

@@ -77,8 +77,8 @@ def test_auto_parallel_arithmetic():
b = Tensor(np.ones([64, 128]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[2, 4], [2, 4]],
'Default/network-Net/MatMul-op0': [[2, 1], [1, 4]]}
assert strategies == expected_strategies
@@ -104,8 +104,8 @@ def test_auto_parallel_arithmetic_broadcast_both():
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]]}
assert strategies == expected_strategies
@@ -131,8 +131,8 @@ def test_auto_parallel_arithmetic_broadcast_right():
b = Tensor(np.ones([32]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [2]],
'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]}
assert strategies == expected_strategies
@@ -158,6 +158,6 @@ def test_auto_parallel_arithmetic_broadcast_left():
b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
compile_net(net, x, y, b, phase="train")
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [1, 4, 2]],
'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]}
assert strategies == expected_strategies

View File

@@ -86,6 +86,6 @@ def test_double_star_graph():
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]],
'Default/network-Net/Cast-op1': [[1, 8]],
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]}
'Default/network-Net/MatMul-op2': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]}
assert strategies == expected_strategies

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import re
import numpy as np
import mindspore as ms
@@ -114,12 +115,16 @@ def test_double_subgraphs():
reset_op_id()
_executor.compile(net, x, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('ReduceMean-op', k) is not None:
assert v == [[8, 1, 1, 1]]
elif re.search('ReLU-op', k) is not None:
assert v == [[8, 1, 1, 1]]
elif re.search('Mul-op', k) is not None:
assert v == [[8, 1, 1, 1], [8, 1, 1, 1]]
elif re.search('ReduceSum-op', k) is not None:
assert v == [[8, 1, 1, 1]]
class DatasetLenet():
def __init__(self, predict, label, length=3):
@@ -160,10 +165,14 @@ def test_double_subgraphs_train():
model = Model(net)
model.train(1, ds_train, dataset_sink_mode=False)
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('ReduceMean-op', k) is not None:
assert v == [[1, 1, 1, 1]]
elif re.search('ReLU-op', k) is not None:
assert v == [[1, 1, 1, 1]]
elif re.search('Mul-op', k) is not None:
assert v == [[1, 1, 1, 1], [1, 1, 1, 1]]
elif re.search('Cast-op', k) is not None:
assert v == [[1, 1, 1, 1]]
elif re.search('ReduceSum-op', k) is not None:
assert v == [[1, 1, 1, 1]]

View File

@@ -78,8 +78,8 @@ def test_two_matmul_transpose():
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Transpose-op0': [[1, 16]],
'Default/network-Net/Transpose-op1': [[16, 1]],
'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
expected_strategies = {'Default/network-Net/Transpose-op3': [[1, 16]],
'Default/network-Net/Transpose-op2': [[16, 1]],
'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]}
assert strategies == expected_strategies