forked from mindspore-Ecosystem/mindspore
refactor get cnode strategy
This commit is contained in:
parent
b802ce7f8f
commit
069318899a
|
@ -54,31 +54,6 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
|||
return dict;
|
||||
}
|
||||
|
||||
py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
py::dict dict;
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto nodes = DeepScopedGraphSearch(ret);
|
||||
|
||||
for (auto node : nodes) {
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto distributed_operation_info = cnode->user_data<OperatorInfo>();
|
||||
if (distributed_operation_info != nullptr) {
|
||||
auto strategyPtr = distributed_operation_info->strategy();
|
||||
if (strategyPtr != nullptr) {
|
||||
auto strategy = strategyPtr->GetInputDim();
|
||||
auto name = cnode->fullname_with_scope();
|
||||
dict[py::str(name)] = strategy;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return dict;
|
||||
}
|
||||
|
||||
py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
py::dict dict;
|
||||
|
|
|
@ -25,7 +25,6 @@ namespace py = pybind11;
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
py::dict GetParameterLayout(const FuncGraphPtr &graph);
|
||||
py::dict GetCNodeStrategy(const FuncGraphPtr &graph);
|
||||
py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1524,7 +1524,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
// Get global rank after the checkpoint?
|
||||
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
|
||||
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
|
||||
|
||||
for (auto &node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
|
@ -2478,17 +2477,32 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
|
|||
InsertNode(op, node, 2, pre_node, root, "shape");
|
||||
}
|
||||
|
||||
void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
// If root graph has reshape op. Find the corresponding parameter.
|
||||
// Reshape's shape is the shape of the parameter.
|
||||
auto executor = pipeline::ExecutorPy::GetInstance();
|
||||
for (auto &node : all_nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) {
|
||||
if (!IsValueNode<Primitive>(cnode->input(0)) || cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (cnode->in_forward_flag()) {
|
||||
// Save strategy in executor
|
||||
OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
|
||||
if (op_info) {
|
||||
auto stra_ptr = op_info->strategy();
|
||||
if (stra_ptr) {
|
||||
auto strategy = stra_ptr->GetInputDim();
|
||||
// fullname with scope should be found in step parallel end ir
|
||||
executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->name() != RESHAPE) {
|
||||
continue;
|
||||
|
@ -2844,7 +2858,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
ReshapeInit(all_nodes);
|
||||
}
|
||||
|
||||
HandleRootReshape(all_nodes);
|
||||
HandleRootReshapeAndSaveStrategy(all_nodes);
|
||||
|
||||
HandleForwardMakeTupleAndMakeList(all_nodes);
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
|
||||
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
|
||||
|
||||
|
|
|
@ -243,9 +243,12 @@ py::dict ExecutorPy::GetParameterLayout(const std::string &phase) {
|
|||
|
||||
py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
|
||||
MS_LOG(DEBUG) << "GetCNodeStrategy!";
|
||||
std::string layout_graph = phase + kStepParallelGraph;
|
||||
auto graph = GetFuncGraph(layout_graph);
|
||||
return mindspore::parallel::GetCNodeStrategy(graph);
|
||||
return stra_dict_[phase];
|
||||
}
|
||||
|
||||
void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
|
||||
MS_LOG(DEBUG) << "SetCNodeStrategy!";
|
||||
stra_dict_[phase_][py::str(name)] = strategy;
|
||||
}
|
||||
|
||||
py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) {
|
||||
|
@ -449,6 +452,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
|
|||
#endif
|
||||
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
|
||||
auto phase_s = py::cast<std::string>(phase);
|
||||
phase_ = phase_s;
|
||||
MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
|
||||
ResourcePtr resource = std::make_shared<Resource>(obj);
|
||||
|
||||
|
|
|
@ -92,6 +92,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
|||
void RunInitGraph(const py::dict &init_params, const std::string &phase);
|
||||
py::dict GetParameterLayout(const std::string &phase);
|
||||
py::dict GetCNodeStrategy(const std::string &phase);
|
||||
void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
|
||||
py::dict GetAllreduceFusion(const std::string &phase);
|
||||
void DelNetRes(const std::string &id);
|
||||
void ReleaseResource(const py::object &phase);
|
||||
|
@ -114,6 +115,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
|||
static std::shared_ptr<ExecutorPy> executor_;
|
||||
static std::mutex instance_lock_;
|
||||
static bool debugger_terminate_;
|
||||
std::map<std::string, py::dict> stra_dict_;
|
||||
std::string phase_ = "";
|
||||
};
|
||||
using ExecutorPyPtr = std::shared_ptr<ExecutorPy>;
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -96,15 +97,15 @@ def test_all_to_all():
|
|||
_reset_op_id()
|
||||
strategys = all_to_all_common(strategy1)
|
||||
print(strategys)
|
||||
expect_dict = {'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits'
|
||||
'/SoftmaxCrossEntropyWithLogits-op3': [[8, 1], [8, 1]],
|
||||
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/'
|
||||
'OneHot-op4': [[8, 1], [], []],
|
||||
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/Transpose-op1': [
|
||||
[8, 1]],
|
||||
'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/MatMul-op0': [
|
||||
[1, 1], [1, 8]]}
|
||||
assert strategys == expect_dict
|
||||
for (k, v) in strategys.items():
|
||||
if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None:
|
||||
assert v == [[8, 1], [8, 1]]
|
||||
elif re.search('OneHot-op', k) is not None:
|
||||
assert v == [[8, 1], [], []]
|
||||
elif re.search('Transpose-op', k) is not None:
|
||||
assert v == [[8, 1]]
|
||||
elif re.search('MatMul-op', k) is not None:
|
||||
assert v == [[1, 1], [1, 8]]
|
||||
context.set_context(save_graphs=False)
|
||||
|
||||
|
||||
|
|
|
@ -77,8 +77,8 @@ def test_auto_parallel_arithmetic():
|
|||
b = Tensor(np.ones([64, 128]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
|
||||
'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[2, 4], [2, 4]],
|
||||
'Default/network-Net/MatMul-op0': [[2, 1], [1, 4]]}
|
||||
assert strategies == expected_strategies
|
||||
|
||||
|
||||
|
@ -104,8 +104,8 @@ def test_auto_parallel_arithmetic_broadcast_both():
|
|||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
||||
|
||||
|
@ -131,8 +131,8 @@ def test_auto_parallel_arithmetic_broadcast_right():
|
|||
b = Tensor(np.ones([32]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
|
||||
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [2]],
|
||||
'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]}
|
||||
assert strategies == expected_strategies
|
||||
|
||||
|
||||
|
@ -158,6 +158,6 @@ def test_auto_parallel_arithmetic_broadcast_left():
|
|||
b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b, phase="train")
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
|
||||
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
|
||||
expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [1, 4, 2]],
|
||||
'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
|
@ -86,6 +86,6 @@ def test_double_star_graph():
|
|||
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]],
|
||||
'Default/network-Net/Cast-op1': [[1, 8]],
|
||||
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]}
|
||||
'Default/network-Net/MatMul-op2': [[1, 1], [1, 8]],
|
||||
'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -114,12 +115,16 @@ def test_double_subgraphs():
|
|||
reset_op_id()
|
||||
_executor.compile(net, x, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
for (k, v) in strategies.items():
|
||||
if re.search('ReduceMean-op', k) is not None:
|
||||
assert v == [[8, 1, 1, 1]]
|
||||
elif re.search('ReLU-op', k) is not None:
|
||||
assert v == [[8, 1, 1, 1]]
|
||||
elif re.search('Mul-op', k) is not None:
|
||||
assert v == [[8, 1, 1, 1], [8, 1, 1, 1]]
|
||||
elif re.search('ReduceSum-op', k) is not None:
|
||||
assert v == [[8, 1, 1, 1]]
|
||||
|
||||
|
||||
class DatasetLenet():
|
||||
def __init__(self, predict, label, length=3):
|
||||
|
@ -160,10 +165,14 @@ def test_double_subgraphs_train():
|
|||
model = Model(net)
|
||||
model.train(1, ds_train, dataset_sink_mode=False)
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
for (k, v) in strategies.items():
|
||||
if re.search('ReduceMean-op', k) is not None:
|
||||
assert v == [[1, 1, 1, 1]]
|
||||
elif re.search('ReLU-op', k) is not None:
|
||||
assert v == [[1, 1, 1, 1]]
|
||||
elif re.search('Mul-op', k) is not None:
|
||||
assert v == [[1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
elif re.search('Cast-op', k) is not None:
|
||||
assert v == [[1, 1, 1, 1]]
|
||||
elif re.search('ReduceSum-op', k) is not None:
|
||||
assert v == [[1, 1, 1, 1]]
|
||||
|
|
|
@ -78,8 +78,8 @@ def test_two_matmul_transpose():
|
|||
|
||||
_executor.compile(net, x, y, b, phase='train')
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
expected_strategies = {'Default/network-Net/Transpose-op0': [[1, 16]],
|
||||
'Default/network-Net/Transpose-op1': [[16, 1]],
|
||||
'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
|
||||
expected_strategies = {'Default/network-Net/Transpose-op3': [[1, 16]],
|
||||
'Default/network-Net/Transpose-op2': [[16, 1]],
|
||||
'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]],
|
||||
'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
Loading…
Reference in New Issue