reshape strategy search

yao_yf 2020-05-05 18:00:40 +08:00
parent 08d86c483c
commit f0bf438a55
9 changed files with 520 additions and 61 deletions

View File

@@ -13,9 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/auto_parallel/graph_costmodel.h"
#include <algorithm>
#include <cstdlib>
#include <iterator>
@@ -24,6 +21,10 @@
#include <utility>
#include <vector>
#include "parallel/auto_parallel/graph_costmodel.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_auto_parallel.h"
namespace mindspore {
namespace parallel {
CostGraphPtr entire_costgraph = nullptr;
@@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
}
void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
curr_edges.push_back(edge);
edges_[{u_node, v_node}] = curr_edges;
std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
curr_out_edges.push_back(edge);
out_edges_[u_node] = curr_out_edges;
std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
curr_in_edges.push_back(edge);
in_edges_[v_node] = curr_in_edges;
}
bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
for (auto &edge_pair : edges_) {
auto edges = edge_pair.second;
@@ -1338,11 +1354,51 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
Status CostGraph::InitSelectedStrategy() {
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
if (op->name().find(RESHAPEINFO) != std::string::npos) {
continue;
}
auto result = op->InitSelectedStrategy(op->selected_strategy());
if (result != SUCCESS) {
return result;
}
}
// reshape init should be applied after the init of its previous node and next node.
for (size_t i = 0; i < ops_.size(); ++i) {
if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
auto in_edges = GetOriginalPrevEdges(ops_[i]);
auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
return edge->prev_operator()->name() == reshape_info->pre_operator_name();
});
auto out_edges = GetOriginalNextEdges(ops_[i]);
auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
return edge->next_operator()->name() == reshape_info->next_operator_name();
});
if (pre_iter != in_edges.end()) {
MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
int32_t pre_index = reshape_info->pre_operator_index();
Dimensions stra;
TensorInfo pre_info;
if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
} else {
pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
}
reshape_info->SetInputLayout(pre_info.tensor_layout());
InferStraByTensorInfo(pre_info, &stra);
std::vector<Dimensions> stra_inputs = {stra};
StrategyPtr reshape_stra =
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
reshape_info->set_strategy(reshape_stra);
}
if (next_iter != out_edges.end()) {
MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
int32_t next_index = reshape_info->next_operator_index();
reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
}
return reshape_info->Init(nullptr);
}
}
return SUCCESS;
}
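The strategy rebuilt for the reshape operator above is derived from the neighbouring operator's TensorInfo: each dimension's split count is the full shape divided by the per-device slice shape, which is what InferStraByTensorInfo computes. A minimal standalone sketch of that derivation, assuming simplified vector types and a hypothetical helper name in place of the framework classes:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int32_t>;
using Dimensions = std::vector<int32_t>;

// Hypothetical helper mirroring what InferStraByTensorInfo derives from a TensorInfo.
Dimensions InferStrategyFromShapes(const Shape &shape, const Shape &slice_shape) {
  Dimensions stra;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (slice_shape[i] == 0 || shape[i] % slice_shape[i] != 0) {
      throw std::runtime_error("slice shape does not evenly divide the full shape");
    }
    stra.push_back(shape[i] / slice_shape[i]);  // split count of dimension i
  }
  return stra;
}

int main() {
  // A [64, 28] tensor sliced to [8, 28] per device corresponds to strategy (8, 1).
  Dimensions stra = InferStrategyFromShapes({64, 28}, {8, 28});
  std::cout << "(" << stra[0] << ", " << stra[1] << ")" << std::endl;  // prints (8, 1)
  return 0;
}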

View File

@@ -87,11 +87,9 @@ class CostGraph {
void RemoveOperator(const OperatorInfoPtr &op);
bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
// the edge is in the form: u --> v
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
curr_edges.push_back(edge);
edges_[{u_node, v_node}] = curr_edges;
}
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
// An edge is uniquely identified by its name, and its output index and input index.
bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
@@ -219,6 +217,8 @@ class CostGraph {
std::vector<OperatorInfoPtr> ops_;
std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
};
} // namespace parallel
} // namespace mindspore
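The two new maps mirror the existing edges_ map but are keyed by a single endpoint, so the deferred reshape initialization above can ask for all original incoming or outgoing edges of one operator. A small standalone sketch of that adjacency bookkeeping, with placeholder Node/Edge types standing in for OperatorInfoPtr and EdgePtr:

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Placeholder types for illustration only.
using Node = std::string;
struct Edge { Node u, v; };
using EdgePtr = std::shared_ptr<Edge>;

struct TinyGraph {
  // Keyed by both endpoints, and additionally by each single endpoint.
  std::map<std::pair<Node, Node>, std::vector<EdgePtr>> edges;
  std::map<Node, std::vector<EdgePtr>> out_edges;
  std::map<Node, std::vector<EdgePtr>> in_edges;

  void AddEdge(const Node &u, const Node &v) {
    auto e = std::make_shared<Edge>(Edge{u, v});
    edges[{u, v}].push_back(e);   // u --> v
    out_edges[u].push_back(e);    // edges leaving u
    in_edges[v].push_back(e);     // edges entering v
  }
};

int main() {
  TinyGraph g;
  g.AddEdge("MatMul-op0", "Reshape-op1");
  g.AddEdge("Reshape-op1", "ReLU-op2");
  // All original incoming edges of the reshape node, analogous to GetOriginalPrevEdges.
  std::cout << "in-degree of Reshape-op1: " << g.in_edges["Reshape-op1"].size() << std::endl;
  return 0;
}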

View File

@@ -111,6 +111,7 @@ class OperatorInfo {
Shape dev_matrix_shape() const { return dev_matrix_shape_; }
std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
const std::string &name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
RankList global_device_list() const { return global_device_list_; }

View File

@@ -22,6 +22,7 @@
#include "parallel/device_manager.h"
#include "parallel/device_matrix.h"
#include "parallel/step_parallel.h"
#include "parallel/auto_parallel/graph_costmodel.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
@@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
}
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
for (size_t i = 0; i < strategy_size; ++i) {
Shape sub_strategy = stra.at(i);
size_t strategy_len = sub_strategy.size();
bool flag = false;
for (size_t j = 0; j < strategy_len; ++j) {
int32_t strategy_value = sub_strategy.at(j);
if (strategy_value > 1) {
if (flag) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
} else {
MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
}
return FAILED;
}
flag = true;
}
}
}
return SUCCESS;
}
@@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr
return SUCCESS;
}
void ReshapeInfo::SetCostForReshapeWithParameter() {
size_t success = 0;
for (auto &sp : sp_vector_) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
PrintStrategy(sp);
}
}
}
void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
int32_t stage_id = strategy->GetInputStage();
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
// Breaking ties for preferring data parallelization
BreakingTiesForPerferringDataParallel(strategy, result);
// refine communication cost calculation for practice
RefineForPracticalCost(result, false);
std::shared_ptr<StrategyWithCost> swc =
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
swc->cost_list.push_back(result);
strategy_cost_.emplace_back(swc);
}
Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
@@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
}
is_auto_parallel_ = true;
Shape input0_split;
input0_split.emplace_back(1);
(void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
(void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
// these generated strategies are used only when the input node is a parameter;
// otherwise, the input node's output_layout is used as the reshape's input_layout.
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
PrintStrategy(sp);
}
}
return SUCCESS;
}
} // namespace parallel
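SetCostForReshape above records one Cost per candidate strategy; the partial-parameter communication term is the parameter-free communication plus COST_MODEL_GAMMA times the remaining communication. A small worked sketch of that combination with made-up numbers (the real values come from operator_cost()):

#include <iostream>

int main() {
  // Made-up figures for a single reshape strategy (illustration only).
  double computation_cost = 120.0;                // forward computation cost
  double communication_cost = 40.0;               // total communication cost
  double communication_without_parameter = 25.0;  // forward-only communication cost
  double COST_MODEL_GAMMA = 0.5;                  // weighting factor from the cost model context

  // Mirrors the combination performed in ReshapeInfo::SetCostForReshape.
  double communication_with_partial_para =
      communication_without_parameter +
      COST_MODEL_GAMMA * (communication_cost - communication_without_parameter);

  std::cout << "computation: " << computation_cost
            << ", partial-parameter communication: " << communication_with_partial_para
            << std::endl;  // 25 + 0.5 * (40 - 25) = 32.5
  return 0;
}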

View File

@@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo {
output_layout_ = output_layout;
output_layout_set_flag_ = true;
}
void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
void SetCostForReshapeWithParameter();
void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::string pre_operator_name() const { return pre_operator_name_; }
std::string next_operator_name() const { return next_operator_name_; }
int32_t pre_operator_index() const { return pre_operator_index_; }
int32_t next_operator_index() const { return next_operator_index_; }
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo {
Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);
int32_t dev_num_;
int32_t pre_operator_index_;
int32_t next_operator_index_;
std::vector<int32_t> parameter_input_v_;
std::vector<StrategyPtr> sp_vector_;
Dimensions input_strategy_;
TensorLayout input_layout_;
TensorLayout output_layout_;
bool input_layout_set_flag_;
bool output_layout_set_flag_;
std::string pre_operator_name_;
std::string next_operator_name_;
};
} // namespace parallel
} // namespace mindspore

View File

@@ -39,6 +39,7 @@
#include "parallel/auto_parallel/rec_core/rec_partition.h"
#include "parallel/context.h"
#include "parallel/ops_info/tmp_identity_info.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_parallel.h"
#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/parse/python_adapter.h"
@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
EdgePtr edge_ptr;
MS_LOG(INFO) << "Creating edge: " << edge_name;
bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
(ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
if (follow_strategy) {
// Redistribution is not allowed on the edge.
// Elementwise operators have the same strategy as their previous operators.
@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
}
}
bool FindReshape(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
return false;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
if (prim->name() != RESHAPE) {
return false;
}
return true;
}
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
// if the previous node is a parameter, handle it outside this function.
if (node->isa<Parameter>()) {
return false;
}
if (!node->isa<CNode>()) {
return false;
}
CNodePtr cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
*pre_operator_info = cnode->operator_info();
*out_index = 0;
return true;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (prim->name() == TUPLE_GETITEM) {
*out_index = GetTupleGetItemIndex(cnode);
// find tuple_get_item's previous node
auto pre_node = cnode->input(1);
if (!pre_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
}
CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
*pre_operator_info = pre_cnode->operator_info();
return true;
}
return false;
}
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (prim->name() == DEPEND && index != 1) {
continue;
}
if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
continue;
}
return true;
}
MS_LOG(WARNING) << "FindPreNodeStraCosts failed; if reshape is not the first primitive, there must be some error";
return false;
}
// find next node, then obtain its strategy_cost_ vector to get its layout vector.
// if reshape's output connects to several primitives, return the first layout found
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cnode->func_graph());
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[cnode];
for (auto &node_pair : node_set) {
CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(node_prim);
MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
*next_operator_info = use_apply->operator_info();
*in_index = node_pair.second - 1;
return true;
}
MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << (use_apply->operator_info() != nullptr);
if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
return true;
}
}
return false;
}
void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
Shape shape = pre_out_tensor_info.shape();
Shape slice_shape = pre_out_tensor_info.slice_shape();
for (size_t i = 0; i < shape.size(); ++i) {
if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
}
int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
(*stra).push_back(dim);
}
}
void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
for (auto node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (!FindReshape(cnode)) {
continue;
}
MS_ASSERT(cnode->inputs().size() == 3);
// get previous node's strategy_cost_
auto pre_node = cnode->input(1);
int32_t out_index = 0;
OperatorInfoPtr pre_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
if (pre_node->isa<Parameter>()) {
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->SetCostForReshapeWithParameter();
pre_operator_info = reshape_info;
pre_stra_costs = reshape_info->strategy_cost();
} else {
if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
}
pre_stra_costs = pre_operator_info->strategy_cost();
}
// get next node's strategy_cost_
int32_t in_index = 0;
OperatorInfoPtr next_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
if (!find_next_node) {
MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
}
// set input_layout and output_layout for reshape.
// init reshape and set cost for each input_layout and output_layout.
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->set_pre_operator_name(pre_operator_info->name());
reshape_info->set_pre_operator_index(out_index);
if (find_next_node) {
next_stra_costs = next_operator_info->strategy_cost();
reshape_info->set_next_operator_name(next_operator_info->name());
reshape_info->set_next_operator_index(in_index);
}
for (auto pre_stra_cost : pre_stra_costs) {
std::vector<TensorInfo> pre_out_tensor_infos;
if (pre_node->isa<Parameter>()) {
pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
} else {
pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
}
if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
}
TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
reshape_info->SetInputLayout(pre_out_tensor_layout);
// infer pre_node's output strategy from its output tensor info.
Dimensions stra;
InferStraByTensorInfo(pre_out_tensor_info, &stra);
std::vector<Dimensions> stra_inputs = {stra};
StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
if (next_stra_costs.empty()) {
if (reshape_info->Init(nullptr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
}
// set cost for each input_layout and output_layout pairs.
reshape_info->SetCostForReshape(reshape_stra);
continue;
}
for (auto next_stra_cost : next_stra_costs) {
std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
}
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
reshape_info->SetOutputLayout(next_in_tensor_layout);
if (reshape_info->Init(nullptr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
}
// set cost for each input_layout and output_layout pairs.
reshape_info->SetCostForReshape(reshape_stra);
}
}
}
}
Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
// There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
}
}
// the reshape operator needs the next node's input_layout as its output_layout,
// and the previous node's output_layout as its input_layout.
ReshapeCostCompute(all_nodes);
// Step 2
ConstructCostGraphEdges(all_nodes);
MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
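ReshapeCostCompute above effectively enumerates every combination of the previous operator's output layouts and the next operator's input layouts, initializing the reshape once per pair and recording a cost for it. A compact sketch of that enumeration pattern, with placeholder types standing in for StrategyWithCost, the layout setters, and the init/cost calls:

#include <iostream>
#include <string>
#include <vector>

// Placeholder stand-ins for the framework's layout type and reshape operator (illustration only).
struct Layout { std::string name; };

struct CandidateReshape {
  Layout input, output;
  void SetInputLayout(const Layout &l) { input = l; }
  void SetOutputLayout(const Layout &l) { output = l; }
  void InitAndRecordCost() {
    // Stands in for reshape_info->Init(nullptr) followed by SetCostForReshape(...).
    std::cout << "cost recorded for (" << input.name << " -> " << output.name << ")\n";
  }
};

int main() {
  // One layout per strategy of the previous / next operator (made-up labels).
  std::vector<Layout> pre_out_layouts = {{"[8,1]"}, {"[1,8]"}};
  std::vector<Layout> next_in_layouts = {{"[8,1]"}, {"[2,4]"}};

  CandidateReshape reshape;
  for (const auto &pre : pre_out_layouts) {
    reshape.SetInputLayout(pre);
    if (next_in_layouts.empty()) {
      // No usable next operator: record a cost for the input layout alone.
      reshape.InitAndRecordCost();
      continue;
    }
    for (const auto &next : next_in_layouts) {
      reshape.SetOutputLayout(next);
      reshape.InitAndRecordCost();  // one cost entry per (input, output) layout pair
    }
  }
  return 0;
}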

View File

@@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);
void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);
Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);

View File

@@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) {
Status ret = reshape->Init(strategy);
ASSERT_EQ(ret, SUCCESS);
}
TEST_F(TestReshapeInfo, AutoStrategy1) {
ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS);
std::vector<std::shared_ptr<StrategyWithCost>> sc = reshape->GetStrategyCost();
Shapes splittable_inputs = {{1, 0, 0, 0}};
std::vector<StrategyPtr> sp_vector;
Shapes inputs_shape = {{32, 512, 7, 7}};
GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
ASSERT_EQ(sc.size(), sp_vector.size());
for (auto stra : sp_vector) {
auto stra0 = stra->GetInputDim()[0];
ASSERT_EQ(stra0[1], 1);
ASSERT_EQ(stra0[2], 1);
ASSERT_EQ(stra0[3], 1);
}
}
} // namespace parallel
} // namespace mindspore

View File

@@ -65,6 +65,193 @@ def test_reshape_matmul():
    net.set_auto_parallel()
    _executor.compile(net, x)


def test_reshape_auto_1():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_2():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (128, 32))
            out = out + self.add_weight
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_3():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (8, 8, 8, 8))
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_4():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28*64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            w = self.reshape(self.matmul_weight, (28, 64))
            out = self.matmul(out, w)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_5():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024*8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (4, 1024*8, 1))
            w_id = self.relu(x)
            wx = self.mul(w_id, mask)
            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
            deep_id = x + self.wide_w
            vx = self.mul(deep_id, mask)
            deep_in = self.reshape(vx, (-1, 1024*8*64))
            out = wide_out + deep_in
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024*size, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024*size]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)


def test_reshape_auto_6():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)