forked from mindspore-Ecosystem/mindspore
!15396 add parallel virtual output in eval/predict
From: @yao_yf Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
commit 7dc9eb3e0b
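The change wraps the outputs of a non-training (eval/predict) graph in _VirtualOutput nodes so the last forward operators keep a data-parallel strategy and the user-visible output stays unsliced (or fully replicated under full_batch). A minimal sketch of how the feature is exercised, modelled on the test file added at the end of this diff; the toy SmallNet cell here is only illustrative:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _executor
from mindspore.ops import operations as P


class SmallNet(nn.Cell):  # hypothetical toy cell, just for illustration
    def __init__(self):
        super().__init__()
        self.mul = P.Mul().shard(((8, 1), (8, 1)))

    def construct(self, x):
        return self.mul(x, x)


context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel", full_batch=False)
net = SmallNet()
net.set_auto_parallel()
net.set_train(False)  # eval/predict graph: the TRAINING flag is off
x = Tensor(np.ones([32, 128]).astype(np.float32))
_executor.compile(net, x, auto_parallel_mode=True)
# Each graph output now gets a _VirtualOutput node; its strategy is the
# data-parallel one (8 on the batch dimension here, all 1s when full_batch=True).
strategies = _executor._get_shard_strategy(net)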
@@ -201,6 +201,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // Virtual Dataset
  virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
                                                "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
  // Virtual Output
  virtual_output_eliminate_ =
    MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);

  // Receive
  receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
@@ -118,6 +118,8 @@ class OptimizeIRPassLib {
  // virtual dataset
  SubstitutionPtr virtual_dataset_eliminate_;

  // virtual output
  SubstitutionPtr virtual_output_eliminate_;
  // Receive
  SubstitutionPtr receive_eliminate_;
@@ -99,6 +99,24 @@ class VirtualDatasetEliminater : public AnfVisitor {
  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualOutput, X} -> X
class VirtualOutputEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() <= 1) {
      return nullptr;
    }
    return cnode->input(1);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
 public:
@@ -127,7 +127,7 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, co
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t) const {
  // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
  // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
  double result = 0.0;
  TensorInfo output0 = outputs[0];
  Shape input0_slice_shape = inputs[0].slice_shape();
@@ -368,7 +368,7 @@ void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_outpu

// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calulating 'dx', taking account of 'y'
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
@@ -195,6 +195,7 @@ REGISTER(SelectInfo);
REGISTER(GatherNdInfo);
REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
}  // namespace parallel
}  // namespace mindspore
@@ -54,5 +54,6 @@
#include "frontend/parallel/ops_info/gathernd_info.h"
#include "frontend/parallel/ops_info/topk_info.h"
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/virtual_output_info.h"

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
@@ -218,6 +218,7 @@ constexpr char ASSIGN_SUB[] = "AssignSub";
constexpr char GREATER[] = "Greater";
constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler";
constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
constexpr char VIRTUAL_OUTPUT[] = "_VirtualOutput";
constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";
constexpr char RELU[] = "ReLU";
@@ -0,0 +1,53 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/virtual_output_info.h"

#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/context.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
    return FAILED;
  }

  Strategys stra = strategy->GetInputDim();
  if (stra.size() != 1) {
    MS_LOG(ERROR) << name_ << ": Strategys size must be 1.";
    return FAILED;
  }
  Dimensions strategy_first = stra.at(0);
  for (auto dim = strategy_first.begin() + 1; dim != strategy_first.end(); ++dim) {
    if (*dim != 1) {
      MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of the strategy must be 1.";
      return FAILED;
    }
  }

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
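The check above only allows the first (batch) dimension of the wrapped output to be split; every other dimension of the single input strategy must be 1. A hypothetical Python rendering of the same rule, for illustration only:

def check_virtual_output_strategy(strategy):
    """Mirror of VirtualOutputInfo::CheckStrategy: exactly one input strategy,
    and every dimension after the first must be 1 (illustrative only)."""
    if len(strategy) != 1:
        return False
    return all(dim == 1 for dim in strategy[0][1:])

assert check_virtual_output_strategy([(8, 1, 1)])      # batch-dimension split is allowed
assert not check_virtual_output_strategy([(1, 2, 1)])  # splitting a non-batch dimension is rejected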
@@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
#define PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class VirtualOutputInfo : public VirtualDatasetInfo {
 public:
  VirtualOutputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
      : VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {}
  ~VirtualOutputInfo() override = default;

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
@@ -69,6 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
    return changes;
  }

  // check whether strategy_search_mode is valid
  std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
@@ -87,14 +88,17 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  TOTAL_OPS = 0;
  AnfNodePtr ret = root->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (ParallelInit() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  // mark the forward cnodes, parallel only care these nodes
  MarkForwardCNode(root);

  if (!root->has_flag(TRAINING)) {
    InsertVirtualOutput(root, all_nodes);
    AnfNodePtr ret_after = root->get_return();
    MS_EXCEPTION_IF_NULL(ret_after);
    all_nodes = DeepScopedGraphSearch(ret_after);
  }
  if (FindCommunicationOp(all_nodes)) {
    MS_LOG(EXCEPTION) << "The graph contain communication op";
  }
@@ -163,7 +167,7 @@ bool IsSplittableOperator(const std::string &op_name) {
    BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
    SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
    UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE};
    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
  // clang-format on

  auto iter = splittable_op.find(op_name);
@@ -239,13 +243,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
                           StrategyMap *stra_map, const std::string &strategy_key_name) {
  // In this case, the configured strategy should be extracted to help setting cost
  StrategyPtr strategyPtr;
  if (is_last_nodes) {
    bool full_batch = ParallelContext::GetInstance()->full_batch();
    strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
    if (full_batch) {
      SetLastNodeStrategy(strategyPtr);
    }
  } else if (StrategyFound(attrs)) {
  if (StrategyFound(attrs)) {
    strategyPtr = parallel::ExtractStrategy(attrs);
  } else {
    strategyPtr = (*stra_map)[strategy_key_name];
@@ -332,10 +330,9 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
  bool load_strategy_from_ckpt =
    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
  // If no strategy has been configured for this operator, then candidate strategies are generated for
  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy;
  // if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint.
  bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes);
  if (is_gen_stra) {
  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint .
  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
    // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
    // BatchParallelInfo operator
    operator_info->ComputeBatchSplitFlagList();
@@ -371,11 +368,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
    }
  }
  std::vector<std::string> last_forward_node_ids;
  if (!root->has_flag(TRAINING)) {
    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  }

  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
@@ -421,8 +413,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
      continue;
    }
    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                         last_forward_node_ids.end();
    bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
    auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
    if (operator_info == nullptr) {
      return FAILED;
@@ -496,11 +487,6 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
      StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  }
  std::vector<std::string> last_forward_node_ids;
  if (!root->has_flag(TRAINING)) {
    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  }
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
@@ -546,8 +532,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
      continue;
    }
    // In this case, the corresponding OperatorInfo is not created, create the new one.
    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                         last_forward_node_ids.end();
    bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
    auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
    MS_EXCEPTION_IF_NULL(operator_info);
@@ -625,7 +625,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
    return false;
  }
  // get_next is not in the forward graph, we need mark the get_next as the forward node
  if (prim->name() == GET_NEXT) {
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
@@ -1004,6 +1004,55 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
  }
}

void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  vector<std::string> last_forward_node_ids;
  vector<size_t> last_indexs;
  FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
  MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  for (auto &node : all_nodes) {
    // here insert virtualoutput node
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
    if (last_node_iter == last_forward_node_ids.end()) {
      continue;
    }
    for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
      if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
        continue;
      }
      MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
                   << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
      if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
        FuncGraphManagerPtr manager = cnode->func_graph()->manager();
        MS_EXCEPTION_IF_NULL(manager);
        auto node_pair = manager->node_users()[cnode].front();
        if (!node_pair.first->isa<CNode>()) {
          MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
        }
        cnode = node_pair.first->cast<CNodePtr>();
        last_indexs[last_node_index] = size_t(node_pair.second);
      }
      FuncGraphPtr func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      OperatorParams params;
      OperatorAttrs attrs;
      OperatorArgs args = std::make_pair(attrs, params);
      Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
      auto pre_node = cnode->input(last_indexs[last_node_index]);
      Shapes shape_outputs = GetNodeShape(pre_node);
      InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
      auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
      AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
      std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
      virtual_output_abstract->set_shape(virtual_output_shape);
      virtual_output_node->set_abstract(virtual_output_abstract);
    }
  }
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (IsValueNode<RefKey>(node)) {
    std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
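Conceptually, the pass above rewrites each last forward node of an eval/predict graph so that its output is fed through a _VirtualOutput node (for tuple outputs it attaches behind the corresponding TupleGetItem user), then copies the original output's shape and dtype onto the inserted node. A rough Python-level picture of the result, using the _VirtualOutput primitive defined later in this diff; the wrapper function here is only illustrative:

from mindspore.ops.operations.comm_ops import _VirtualOutput  # primitive added by this commit

virtual_output = _VirtualOutput()

def predict_graph(net, x):
    # before the pass: return net(x)
    out = net(x)
    # after InsertVirtualOutput: every graph output is wrapped; _VirtualOutput is an
    # identity for shape/dtype but carries the data-parallel (or full-batch) strategy
    return virtual_output(out)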
@@ -1826,7 +1875,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == VIRTUAL_DATA_SET) {
  if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
    CheckGlobalDeviceManager();
    int64_t dev_num;
    if (full_batch) {
@@ -1856,32 +1905,36 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
  }
}

// find previous parallel care node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
// find previous parallel care node's next node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) {
  MS_EXCEPTION_IF_NULL(unique_ids);
  // if previous node is a parameter, handle it in the outsize.
  if (node->isa<Parameter>()) {
    return false;
  }
  MS_EXCEPTION_IF_NULL(indexes);
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
  CNodePtr pre_cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
    unique_ids->push_back(cnode->UniqueId());
    return true;
  }
  bool find = false;
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
  for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
    auto next_node = pre_cnode->inputs()[index];
    if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
      return false;
    }
    CNodePtr cnode = next_node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      return false;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
    if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
      unique_ids->push_back(pre_cnode->UniqueId());
      indexes->push_back(index);
      find = true;
      continue;
    }
    if (FindPreNodes(cnode->inputs()[index], unique_ids)) {
    if (FindPreNodes(cnode, unique_ids, indexes)) {
      find = true;
      continue;
    }
@@ -1889,20 +1942,12 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
  return find;
}

void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
                           std::vector<size_t> *indexes) {
  MS_EXCEPTION_IF_NULL(unique_ids);
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    if (prim->name() == RETURN) {
      if (!FindPreNodes(cnode, unique_ids)) {
        MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
      }
    }
  CNodePtr cnode = root->get_return();
  if (!FindPreNodes(cnode, unique_ids, indexes)) {
    MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
  }
}

@@ -1926,16 +1971,6 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
  return strategyPtr;
}

void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
  auto strategys = strategyPtr->GetInputDim();
  for (size_t i = 0; i < strategys.size(); ++i) {
    for (size_t j = 0; j < strategys[i].size(); ++j) {
      strategys[i][j] = 1;
    }
  }
  strategyPtr->ResetInputs(strategys);
}

static bool CheckExtractInfomation(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
@@ -1960,11 +1995,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
      (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  }
  vector<std::string> last_forward_node_ids;
  if (!is_training) {
    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  }

  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
@@ -2012,10 +2042,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
    }
    bool load_strategy_from_ckpt =
      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                         last_forward_node_ids.end();
    bool full_batch = ParallelContext::GetInstance()->full_batch();
    if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
    if ((!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                   << " is empty, using batch parallel";
      strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
@@ -2026,9 +2053,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
    }

    MS_EXCEPTION_IF_NULL(strategyPtr);
    if (is_last_nodes && full_batch) {
      SetLastNodeStrategy(strategyPtr);
    }
    if (operator_->Init(strategyPtr) == FAILED) {
      MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
    }
@@ -3537,6 +3561,14 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
    MS_LOG(EXCEPTION) << "The graph contain communication op";
  }

  if (!root->has_flag(TRAINING)) {
    InsertVirtualOutput(root, all_nodes);
    AnfNodePtr ret_after = root->get_return();
    MS_EXCEPTION_IF_NULL(ret_after);
    all_nodes = DeepScopedGraphSearch(ret_after);
    std::reverse(all_nodes.begin(), all_nodes.end());
  }

  // extract shape and strategy, set operator_info
  ExtractInformation(all_nodes, root->has_flag(TRAINING));
  ReshapeInit(all_nodes);
@@ -172,7 +172,10 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr);

bool CreateGroupsByCkptFile(const std::string &file);

void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
                           std::vector<size_t> *indexes);

void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
}  // namespace parallel
}  // namespace mindspore
@@ -195,6 +195,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    {"parallel", opt::OptPassConfig(parallel::StepParallel)},
    {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
    {"virtual_dataset", virtual_dataset},
    {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
    {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
    {"resolve", resolve_pass},
    {"a_after_grad", a_after_grad},
@@ -317,6 +317,7 @@ inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_VirtualOutput");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                        Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice, _VirtualAdd,
                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
                       _HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                        TensorSummary, HistogramSummary, Print, Assert)
@@ -670,7 +670,7 @@ class _VirtualDataset(PrimitiveWithInfer):
    """
    Auto parallel virtual dataset operator.

    It would insert Broadcast operator in forward computation and be deleted before backward computation.
    It would insert VirtualDataset operator in forward computation and be deleted before backward computation.
    """

    @prim_attr_register
@@ -686,6 +686,22 @@ class _VirtualDataset(PrimitiveWithInfer):

virtual_dataset = _VirtualDataset()


class _VirtualOutput(PrimitiveWithInfer):
    """
    Auto parallel virtual out operator.

    It would insert VirtualOutput operator in forward computation and be deleted before backward computation.
    """

    @prim_attr_register
    def __init__(self):
        """init"""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


class _GetTensorSlice(PrimitiveWithInfer):
    """
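Because _VirtualOutput simply forwards its input's shape and dtype, it is transparent to the network's numerics. A quick illustrative check of that identity behaviour (not part of the commit):

import mindspore as ms
from mindspore.ops.operations.comm_ops import _VirtualOutput

op = _VirtualOutput()
assert op.infer_shape([32, 128]) == [32, 128]    # shape is passed through unchanged
assert op.infer_dtype(ms.float32) == ms.float32  # dtype is passed through unchanged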
@@ -74,6 +74,7 @@ def test_two_bn():
    net = NetWithLoss(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
@@ -158,4 +158,5 @@ def test_only_one_get_next():
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net = Net()
    net.set_train()
    compile_net(net)
@@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import re
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE)


class DenseMutMulNet(nn.Cell):
    def __init__(self):
        super(DenseMutMulNet, self).__init__()
        self.fc1 = nn.Dense(128, 768)
        self.fc2 = nn.Dense(128, 768)
        self.fc3 = nn.Dense(128, 768)
        self.fc4 = nn.Dense(768, 768, has_bias=False)
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()
        self.fc4.matmul.shard(((1, 1), (8, 1)))

    def construct(self, x):
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


class MulNegTwoOutputNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.mul = P.Mul().shard(((2, 4), (2, 4)))
        self.neg = P.Neg().shard(((2, 4),))
        self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out1 = self.mul(x, self.mul_weight)
        out2 = self.neg(out1)
        return out1, out2


class ReshapeMatMulNet(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy2)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (64, 4, 7)
    def construct(self, x):
        out = self.reshape(x, (64, 28))
        out = self.matmul(out, self.matmul_weight)
        return out


class MatMulReshapeNet(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

    # x (128, 28)
    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.reshape(out, (64, -1))
        return out


class ReshapeMulNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
        self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

    def construct(self, x):
        weight = self.reshape(self.mul_weight, (1, 128, 96))
        out = self.mul(weight, self.mul_weight)
        return out


def compile_graph(x, net):
    net.set_auto_parallel()
    net.set_train(False)
    _executor.compile(net, x, auto_parallel_mode=True)
    strategies = _executor._get_shard_strategy(net)
    return strategies


def test_dense_relu_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_dense_relu_semi_auto_full_batch():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 1


def test_dense_relu_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_dense_relu_auto_full_batch():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = DenseMutMulNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 1


def test_mul_neg_two_output_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = 0
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            count += 1
            assert v[0][0] == 8
    assert count == 2


def test_mul_neg_two_output_semi_auto_full_batch():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = 0
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            count += 1
            assert v[0][0] == 1
    assert count == 2


def test_mul_neg_two_output_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = 0
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            count += 1
            assert v[0][0] == 8
    assert count == 2


def test_mul_neg_two_output_full_batch():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = MulNegTwoOutputNet()
    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
    strategies = compile_graph(x, net)
    count = 0
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            count += 1
            assert v[0][0] == 1
    assert count == 2


def test_reshape_matmul_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_reshape_matmul_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    strategy1 = None
    strategy2 = ((1, 1), (1, 8))
    net = ReshapeMatMulNet(strategy1, strategy2)
    x = Tensor(np.ones([64, 4, 7]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_matmul_reshape_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_matmul_reshape_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    strategy2 = None
    strategy1 = ((1, 1), (1, 8))
    net = MatMulReshapeNet(strategy1, strategy2)
    x = Tensor(np.ones([128, 28]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 8


def test_reshape_mul_semi_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 1


def test_reshape_mul_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
    net = ReshapeMulNet()
    x = Tensor(np.ones([64, 4]), ms.float32)
    strategies = compile_graph(x, net)
    for (k, v) in strategies.items():
        if re.search('VirtualOutput-op', k) is not None:
            assert v[0][0] == 1