forked from mindspore-Ecosystem/mindspore
pipeline_adapt_interleaved
This commit is contained in:
parent
8abc711298
commit
8207242ebf
|
@ -428,6 +428,7 @@ constexpr char ACCU_GRAD[] = "accu_grad";
|
|||
constexpr char PARAMETER_START[] = "parameter_start";
|
||||
constexpr char PARAM_INDEX[] = "param_index";
|
||||
constexpr char PARAMETER[] = "parameter";
|
||||
constexpr char FUNC_GRAPH_FLAG_STRIDED_SLICE[] = "strided_slice_flag";
|
||||
|
||||
// Parallel don't care
|
||||
constexpr char STRING_EQUAL[] = "string_equal";
|
||||
|
|
|
@ -1754,29 +1754,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
|
|||
g_RefMap.clear();
|
||||
}
|
||||
|
||||
void LableBatchSizeSplit(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto node_user_map = manager->node_users();
|
||||
auto node_users = node_user_map[node];
|
||||
for (auto &node_user : node_users) {
|
||||
if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
|
||||
auto data_users = manager->node_users()[node_user.first];
|
||||
auto node_first = data_users.front().first;
|
||||
for (auto &data_user : data_users) {
|
||||
PrimitivePtr prim = GetCNodePrimitive(data_user.first);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
|
||||
SetStridedSliceStrategy(data_user.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
|
@ -1923,15 +1900,6 @@ static bool CheckExtractInfomation(const CNodePtr &cnode) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
|
||||
LableBatchSizeSplit(cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
|
||||
StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
|
||||
auto attrs = prim->attrs();
|
||||
|
|
|
@ -35,7 +35,6 @@
|
|||
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
|
||||
|
||||
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
|
||||
const char FUNC_GRAPH_FLAG_STRIDED_SLICE[] = "strided_slice_flag";
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#include "frontend/parallel/graph_util/graph_info.h"
|
||||
#include "frontend/parallel/graph_util/node_info.h"
|
||||
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
|
||||
#include "frontend/parallel/node_check.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -226,5 +227,23 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
|
|||
SetCommunicationOpGroupLabel(replace_input);
|
||||
return replace_input;
|
||||
}
|
||||
|
||||
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
|
||||
continue;
|
||||
}
|
||||
auto slice_prim = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(slice_prim);
|
||||
if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
|
||||
SetStridedSliceStrategy(cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@ std::string CreateInstanceName(const CNodePtr &node, size_t index);
|
|||
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
|
||||
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
|
||||
const CNodePtr &node);
|
||||
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue