pipeline_adapt_interleaved

This commit is contained in:
lichenever 2021-11-29 10:42:57 +08:00
parent 8abc711298
commit 8207242ebf
5 changed files with 21 additions and 33 deletions

View File

@ -428,6 +428,7 @@ constexpr char ACCU_GRAD[] = "accu_grad";
constexpr char PARAMETER_START[] = "parameter_start";
constexpr char PARAM_INDEX[] = "param_index";
constexpr char PARAMETER[] = "parameter";
constexpr char FUNC_GRAPH_FLAG_STRIDED_SLICE[] = "strided_slice_flag";
// Parallel don't care
constexpr char STRING_EQUAL[] = "string_equal";

View File

@ -1754,29 +1754,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
g_RefMap.clear();
}
void LableBatchSizeSplit(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_user_map = manager->node_users();
auto node_users = node_user_map[node];
for (auto &node_user : node_users) {
if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
auto data_users = manager->node_users()[node_user.first];
auto node_first = data_users.front().first;
for (auto &data_user : data_users) {
PrimitivePtr prim = GetCNodePrimitive(data_user.first);
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
SetStridedSliceStrategy(data_user.first);
}
}
}
}
}
void SetVirtualDatasetStrategy(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
@ -1923,15 +1900,6 @@ static bool CheckExtractInfomation(const CNodePtr &cnode) {
return true;
}
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
LableBatchSizeSplit(cnode);
}
}
}
static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
auto attrs = prim->attrs();

View File

@ -35,7 +35,6 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
const char FUNC_GRAPH_FLAG_STRIDED_SLICE[] = "strided_slice_flag";
namespace mindspore {
namespace parallel {

View File

@ -34,6 +34,7 @@
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
@ -226,5 +227,23 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
SetCommunicationOpGroupLabel(replace_input);
return replace_input;
}
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
continue;
}
auto slice_prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(slice_prim);
if (slice_prim->HasAttr(FUNC_GRAPH_FLAG_STRIDED_SLICE)) {
SetStridedSliceStrategy(cnode);
}
}
}
} // namespace parallel
} // namespace mindspore

View File

@ -32,6 +32,7 @@ std::string CreateInstanceName(const CNodePtr &node, size_t index);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
} // namespace parallel
} // namespace mindspore