!15396 add parallel virtual output in eval/predict

From: @yao_yf
Reviewed-by: @yangzhenzhang,@stsuteng
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-04-25 17:43:13 +08:00 committed by Gitee
commit 7dc9eb3e0b
19 changed files with 504 additions and 88 deletions

View File

@ -201,6 +201,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Virtual Dataset
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Virtual Output
virtual_output_eliminate_ =
MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
// Receive
receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);

View File

@ -118,6 +118,8 @@ class OptimizeIRPassLib {
// virtual dataset
SubstitutionPtr virtual_dataset_eliminate_;
// virtual output
SubstitutionPtr virtual_output_eliminate_;
// Receive
SubstitutionPtr receive_eliminate_;

View File

@ -99,6 +99,24 @@ class VirtualDatasetEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// Rewrite rule: {prim::kPrimVirtualOutput, X} -> X
// Replaces a _VirtualOutput call by its first real input. Returning nullptr tells the caller that the
// node is left untouched.
class VirtualOutputEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    // Only _VirtualOutput primitive calls are candidates.
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput)) {
      return nullptr;
    }
    // Nodes that do not belong to a graph are skipped.
    if (node->func_graph() == nullptr) {
      return nullptr;
    }

    auto virtual_output = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(virtual_output);

    // inputs()[0] is the primitive itself; a data input must exist at position 1.
    const bool has_data_input = virtual_output->inputs().size() > 1;
    return has_data_input ? virtual_output->input(1) : nullptr;
  }

  // No sub-node visiting is required for this rule.
  void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
public:

View File

@ -127,7 +127,7 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, co
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t) const {
// In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
// In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
double result = 0.0;
TensorInfo output0 = outputs[0];
Shape input0_slice_shape = inputs[0].slice_shape();
@ -368,7 +368,7 @@ void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_outpu
// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
// When calulating 'dx', taking account of 'y'
// When calculating 'dx', taking account of 'y'
if (is_parameter_[0]) {
is_inputs_should_in_memory_[0] = true;
if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {

View File

@ -195,6 +195,7 @@ REGISTER(SelectInfo);
REGISTER(GatherNdInfo);
REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -54,5 +54,6 @@
#include "frontend/parallel/ops_info/gathernd_info.h"
#include "frontend/parallel/ops_info/topk_info.h"
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/virtual_output_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -218,6 +218,7 @@ constexpr char ASSIGN_SUB[] = "AssignSub";
constexpr char GREATER[] = "Greater";
constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler";
constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
constexpr char VIRTUAL_OUTPUT[] = "_VirtualOutput";
constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";
constexpr char RELU[] = "ReLU";

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/virtual_output_info.h"
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/context.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parallel {
// Validate the strategy configured for the virtual output operator.
//
// The strategy is accepted (SUCCESS) when:
//   1. CheckStrategyValue accepts it against inputs_shape_,
//   2. it describes exactly one input, and
//   3. every dimension of that input except the first one is 1.
// Otherwise an error is logged and FAILED is returned.
Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
    return FAILED;
  }

  const Strategys stra = strategy->GetInputDim();
  if (stra.size() != 1) {
    MS_LOG(ERROR) << name_ << ": Strategys size must be 1.";
    return FAILED;
  }

  // Index-based scan starting at 1: unlike `begin() + 1`, this is well defined when the strategy of the
  // single input is empty (nothing to check in that case).
  const Dimensions &strategy_first = stra.at(0);
  for (size_t i = 1; i < strategy_first.size(); ++i) {
    if (strategy_first[i] != 1) {
      MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of the strategy must be 1.";
      return FAILED;
    }
  }

  return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Include guard named after this header (virtual_output_info.h); the #ifndef/#define pair and the
// closing comment now use the same macro.
#ifndef PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
#define PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
// Operator info for the virtual output operator.
// It reuses everything from VirtualDatasetInfo and only replaces the strategy check.
class VirtualOutputInfo : public VirtualDatasetInfo {
 public:
  // All arguments are forwarded unchanged to the VirtualDatasetInfo constructor.
  VirtualOutputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
      : VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {}
  ~VirtualOutputInfo() override = default;

 protected:
  // Strategy validation specific to the virtual output operator (defined in the .cc file).
  Status CheckStrategy(const StrategyPtr &strategy) override;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_

View File

@ -69,6 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
return changes;
}
// check whether strategy_search_mode is valid
std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
@ -87,14 +88,17 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
TOTAL_OPS = 0;
AnfNodePtr ret = root->get_return();
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
if (ParallelInit() != SUCCESS) {
MS_LOG(EXCEPTION) << "Parallel init failed";
}
// mark the forward cnodes, parallel only care these nodes
MarkForwardCNode(root);
if (!root->has_flag(TRAINING)) {
InsertVirtualOutput(root, all_nodes);
AnfNodePtr ret_after = root->get_return();
MS_EXCEPTION_IF_NULL(ret_after);
all_nodes = DeepScopedGraphSearch(ret_after);
}
if (FindCommunicationOp(all_nodes)) {
MS_LOG(EXCEPTION) << "The graph contain communication op";
}
@ -163,7 +167,7 @@ bool IsSplittableOperator(const std::string &op_name) {
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE};
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
// clang-format on
auto iter = splittable_op.find(op_name);
@ -239,13 +243,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
StrategyMap *stra_map, const std::string &strategy_key_name) {
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr;
if (is_last_nodes) {
bool full_batch = ParallelContext::GetInstance()->full_batch();
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
if (full_batch) {
SetLastNodeStrategy(strategyPtr);
}
} else if (StrategyFound(attrs)) {
if (StrategyFound(attrs)) {
strategyPtr = parallel::ExtractStrategy(attrs);
} else {
strategyPtr = (*stra_map)[strategy_key_name];
@ -332,10 +330,9 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy;
// if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint.
bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes);
if (is_gen_stra) {
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
// if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint.
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator
operator_info->ComputeBatchSplitFlagList();
@ -371,11 +368,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
@ -421,8 +413,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
continue;
}
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
if (operator_info == nullptr) {
return FAILED;
@ -496,11 +487,6 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
std::vector<std::string> last_forward_node_ids;
if (!root->has_flag(TRAINING)) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
@ -546,8 +532,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
continue;
}
// In this case, the corresponding OperatorInfo is not created, create the new one.
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
MS_EXCEPTION_IF_NULL(operator_info);

View File

@ -625,7 +625,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
return false;
}
// get_next is not in the forward graph, we need mark the get_next as the forward node
if (prim->name() == GET_NEXT) {
if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
return true;
}
if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
@ -1004,6 +1004,55 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
}
}
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
vector<std::string> last_forward_node_ids;
vector<size_t> last_indexs;
FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
for (auto &node : all_nodes) {
// here insert virtualoutput node
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
if (last_node_iter == last_forward_node_ids.end()) {
continue;
}
for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
continue;
}
MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
<< cnode->input(last_indexs[last_node_index])->fullname_with_scope();
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_pair = manager->node_users()[cnode].front();
if (!node_pair.first->isa<CNode>()) {
MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
}
cnode = node_pair.first->cast<CNodePtr>();
last_indexs[last_node_index] = size_t(node_pair.second);
}
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
OperatorParams params;
OperatorAttrs attrs;
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
auto pre_node = cnode->input(last_indexs[last_node_index]);
Shapes shape_outputs = GetNodeShape(pre_node);
InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
virtual_output_abstract->set_shape(virtual_output_shape);
virtual_output_node->set_abstract(virtual_output_abstract);
}
}
}
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
if (IsValueNode<RefKey>(node)) {
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
@ -1826,7 +1875,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == VIRTUAL_DATA_SET) {
if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
CheckGlobalDeviceManager();
int64_t dev_num;
if (full_batch) {
@ -1856,32 +1905,36 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
}
}
// find previous parallel care node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
// find previous parallel care node's next node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) {
MS_EXCEPTION_IF_NULL(unique_ids);
// if the previous node is a parameter, it is handled outside this function.
if (node->isa<Parameter>()) {
return false;
}
MS_EXCEPTION_IF_NULL(indexes);
if (!node->isa<CNode>()) {
return false;
}
CNodePtr cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
CNodePtr pre_cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
unique_ids->push_back(cnode->UniqueId());
return true;
}
bool find = false;
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (prim->name() == DEPEND && index != 1) {
for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
auto next_node = pre_cnode->inputs()[index];
if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
return false;
}
CNodePtr cnode = next_node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
unique_ids->push_back(pre_cnode->UniqueId());
indexes->push_back(index);
find = true;
continue;
}
if (FindPreNodes(cnode->inputs()[index], unique_ids)) {
if (FindPreNodes(cnode, unique_ids, indexes)) {
find = true;
continue;
}
@ -1889,20 +1942,12 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
return find;
}
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
std::vector<size_t> *indexes) {
MS_EXCEPTION_IF_NULL(unique_ids);
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == RETURN) {
if (!FindPreNodes(cnode, unique_ids)) {
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
}
}
CNodePtr cnode = root->get_return();
if (!FindPreNodes(cnode, unique_ids, indexes)) {
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
}
}
@ -1926,16 +1971,6 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
return strategyPtr;
}
// Overwrite every dimension of every input strategy held by strategyPtr with 1 and store the result
// back through ResetInputs.
void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
  auto input_dims = strategyPtr->GetInputDim();
  for (auto &dims_of_input : input_dims) {
    for (auto &split_num : dims_of_input) {
      split_num = 1;
    }
  }
  strategyPtr->ResetInputs(input_dims);
}
static bool CheckExtractInfomation(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
@ -1960,11 +1995,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
(StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
vector<std::string> last_forward_node_ids;
if (!is_training) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
@ -2012,10 +2042,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool full_batch = ParallelContext::GetInstance()->full_batch();
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
if ((!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
@ -2026,9 +2053,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
}
MS_EXCEPTION_IF_NULL(strategyPtr);
if (is_last_nodes && full_batch) {
SetLastNodeStrategy(strategyPtr);
}
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
@ -3537,6 +3561,14 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_LOG(EXCEPTION) << "The graph contain communication op";
}
if (!root->has_flag(TRAINING)) {
InsertVirtualOutput(root, all_nodes);
AnfNodePtr ret_after = root->get_return();
MS_EXCEPTION_IF_NULL(ret_after);
all_nodes = DeepScopedGraphSearch(ret_after);
std::reverse(all_nodes.begin(), all_nodes.end());
}
// extract shape and strategy, set operator_info
ExtractInformation(all_nodes, root->has_flag(TRAINING));
ReshapeInit(all_nodes);

View File

@ -172,7 +172,10 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr);
bool CreateGroupsByCkptFile(const std::string &file);
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
std::vector<size_t> *indexes);
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
} // namespace parallel
} // namespace mindspore

View File

@ -195,6 +195,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"parallel", opt::OptPassConfig(parallel::StepParallel)},
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
{"virtual_dataset", virtual_dataset},
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
{"resolve", resolve_pass},
{"a_after_grad", a_after_grad},

View File

@ -317,6 +317,7 @@ inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_VirtualOutput");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");

View File

@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, _VirtualAdd,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)

View File

@ -670,7 +670,7 @@ class _VirtualDataset(PrimitiveWithInfer):
"""
Auto parallel virtual dataset operator.
It would insert Broadcast operator in forward computation and be deleted before backward computation.
It would insert VirtualDataset operator in forward computation and be deleted before backward computation.
"""
@prim_attr_register
@ -686,6 +686,22 @@ class _VirtualDataset(PrimitiveWithInfer):
virtual_dataset = _VirtualDataset()
class _VirtualOutput(PrimitiveWithInfer):
    """
    Auto parallel virtual output operator.

    It would insert VirtualOutput operator in forward computation and be deleted before backward computation.
    Shape and dtype inference return the input's shape and dtype unchanged.
    """
    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualOutput; the primitive takes no attributes."""
    def infer_shape(self, x_shape):
        # The output shape is the input shape.
        return x_shape
    def infer_dtype(self, x_dtype):
        # The output dtype is the input dtype.
        return x_dtype
class _GetTensorSlice(PrimitiveWithInfer):
"""

View File

@ -74,6 +74,7 @@ def test_two_bn():
net = NetWithLoss(Net())
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
net.set_auto_parallel()
net.set_train()
set_algo_parameters(elementwise_op_strategy_follow=True)
reset_op_id()

View File

@ -158,4 +158,5 @@ def test_only_one_get_next():
context.set_auto_parallel_context(device_num=4, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = Net()
net.set_train()
compile_net(net)

View File

@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import re
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE)
class DenseMutMulNet(nn.Cell):
    """Three Dense projections of the same input combined by two MatMul+ReLU stages and a final Dense.

    The MatMul inside ``fc4`` (a Dense layer without bias) is given the shard strategy ((1, 1), (8, 1)).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(128, 768)
        self.fc2 = nn.Dense(128, 768)
        self.fc3 = nn.Dense(128, 768)
        self.fc4 = nn.Dense(768, 768, has_bias=False)
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()
        self.fc4.matmul.shard(((1, 1), (8, 1)))

    def construct(self, x):
        proj_q = self.fc1(x)
        proj_k = self.fc2(x)
        proj_v = self.fc3(x)
        proj_k_t = self.transpose(proj_k, (1, 0))
        scores = self.relu4(self.matmul1(proj_q, proj_k_t))
        mixed = self.relu5(self.matmul2(scores, proj_v))
        return self.fc4(mixed)
class MulNegTwoOutputNet(nn.Cell):
    """Network with two outputs: ``x * weight`` and its negation.

    Both Mul and Neg are sharded with the (2, 4) layout.
    """

    def __init__(self):
        super().__init__()
        self.mul = P.Mul().shard(((2, 4), (2, 4)))
        self.neg = P.Neg().shard(((2, 4),))
        self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight")

    def construct(self, x):
        product = self.mul(x, self.mul_weight)
        negated = self.neg(product)
        return product, negated
class ReshapeMatMulNet(nn.Cell):
    """Reshape followed by MatMul.

    Args:
        strategy1: accepted by the constructor but not used.
        strategy2: shard strategy applied to the MatMul.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy2)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (64, 4, 7)
    def construct(self, x):
        flattened = self.reshape(x, (64, 28))
        return self.matmul(flattened, self.matmul_weight)
class MatMulReshapeNet(nn.Cell):
    """MatMul followed by Reshape.

    Args:
        strategy1: shard strategy applied to the MatMul.
        strategy2: accepted by the constructor but not used.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (128, 28)
    def construct(self, x):
        product = self.matmul(x, self.matmul_weight)
        return self.reshape(product, (64, -1))
class ReshapeMulNet(nn.Cell):
    """Multiply a reshaped copy of the weight by the weight itself.

    The network input ``x`` is not used by ``construct``.
    """

    def __init__(self):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
        self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

    def construct(self, x):
        expanded_weight = self.reshape(self.mul_weight, (1, 128, 96))
        return self.mul(expanded_weight, self.mul_weight)
def compile_graph(x, net):
    """Compile ``net`` for input ``x`` in auto-parallel, non-training mode and return its shard strategies."""
    net.set_auto_parallel()
    net.set_train(False)
    _executor.compile(net, x, auto_parallel_mode=True)
    return _executor._get_shard_strategy(net)
def _check_virtual_output_strategy(strategies, expected_first_dim, expected_count=None):
    """Assert the first strategy dimension of every VirtualOutput operator.

    Args:
        strategies: mapping returned by ``compile_graph`` (operator name -> strategy).
        expected_first_dim: value every VirtualOutput strategy must have at ``[0][0]``.
        expected_count: when given, the exact number of VirtualOutput entries that must be present.
    """
    matched = 0
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            matched += 1
            assert v[0][0] == expected_first_dim
    if expected_count is not None:
        assert matched == expected_count


def test_dense_relu_semi_auto():
    """DenseMutMulNet, semi_auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_dense_relu_semi_auto_full_batch():
    """DenseMutMulNet, semi_auto_parallel, full_batch on: VirtualOutput first dimension is 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1)


def test_dense_relu_auto():
    """DenseMutMulNet, auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_dense_relu_auto_full_batch():
    """DenseMutMulNet, auto_parallel, full_batch on: VirtualOutput first dimension is 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1)


def test_mul_neg_two_output_semi_auto():
    """MulNegTwoOutputNet, semi_auto_parallel, full_batch off: two VirtualOutput ops, first dimension 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8, expected_count=2)


def test_mul_neg_two_output_semi_auto_full_batch():
    """MulNegTwoOutputNet, semi_auto_parallel, full_batch on: two VirtualOutput ops, first dimension 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1, expected_count=2)


def test_mul_neg_two_output_auto():
    """MulNegTwoOutputNet, auto_parallel, full_batch off: two VirtualOutput ops, first dimension 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8, expected_count=2)


def test_mul_neg_two_output_full_batch():
    """MulNegTwoOutputNet, auto_parallel, full_batch on: two VirtualOutput ops, first dimension 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1, expected_count=2)


def test_reshape_matmul_semi_auto():
    """ReshapeMatMulNet, semi_auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_reshape_matmul_auto():
    """ReshapeMatMulNet, auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_matmul_reshape_semi_auto():
    """MatMulReshapeNet, semi_auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_matmul_reshape_auto():
    """MatMulReshapeNet, auto_parallel, full_batch off: VirtualOutput first dimension is 8."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 8)


def test_reshape_mul_semi_auto():
    """ReshapeMulNet, semi_auto_parallel, full_batch on: VirtualOutput first dimension is 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1)


def test_reshape_mul_auto():
    """ReshapeMulNet, auto_parallel, full_batch on: VirtualOutput first dimension is 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    _check_virtual_output_strategy(strategies, 1)