forked from mindspore-Ecosystem/mindspore
reshape strategy search
This commit is contained in:
parent 08d86c483c
commit f0bf438a55
@@ -13,9 +13,6 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parallel/auto_parallel/graph_costmodel.h"

#include <algorithm>
#include <cstdlib>
#include <iterator>
@@ -24,6 +21,10 @@
#include <utility>
#include <vector>

#include "parallel/auto_parallel/graph_costmodel.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_auto_parallel.h"

namespace mindspore {
namespace parallel {
CostGraphPtr entire_costgraph = nullptr;
@@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";
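// Used to recognize reshape operators by their OperatorInfo name in the cost graph.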

void CostGraph::SetDeviceMemoryAndCostParameter() {
  MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
  return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
}

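// Record the edge u --> v in the pairwise edge map as well as in the per-node
// out-edge and in-edge maps.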
void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
  std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
  curr_edges.push_back(edge);
  edges_[{u_node, v_node}] = curr_edges;

  std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
  curr_out_edges.push_back(edge);
  out_edges_[u_node] = curr_out_edges;

  std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
  curr_in_edges.push_back(edge);
  in_edges_[v_node] = curr_in_edges;
}

bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
  for (auto &edge_pair : edges_) {
    auto edges = edge_pair.second;
@@ -1338,11 +1354,51 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
Status CostGraph::InitSelectedStrategy() {
  for (auto &op : ops_) {
    MS_EXCEPTION_IF_NULL(op);
    if (op->name().find(RESHAPEINFO) != std::string::npos) {
      continue;
    }
    auto result = op->InitSelectedStrategy(op->selected_strategy());
    if (result != SUCCESS) {
      return result;
    }
  }
  // Reshape init should be applied after the init of its previous node and next node.
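  // For each reshape node, locate the recorded previous/next operators among the original
  // in/out edges, reuse the previous operator's tensor layout as the reshape input layout and
  // the next operator's input layout as the reshape output layout, then derive the reshape
  // strategy from the previous operator's tensor info.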
  for (size_t i = 0; i < ops_.size(); ++i) {
    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
      auto in_edges = GetOriginalPrevEdges(ops_[i]);
      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
      });
      auto out_edges = GetOriginalNextEdges(ops_[i]);
      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
        return edge->next_operator()->name() == reshape_info->next_operator_name();
      });
      if (pre_iter != in_edges.end()) {
        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
        int32_t pre_index = reshape_info->pre_operator_index();
        Dimensions stra;
        TensorInfo pre_info;
        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
        } else {
          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
        }
        reshape_info->SetInputLayout(pre_info.tensor_layout());
        InferStraByTensorInfo(pre_info, &stra);
        std::vector<Dimensions> stra_inputs = {stra};
        StrategyPtr reshape_stra =
          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
        reshape_info->set_strategy(reshape_stra);
      }
      if (next_iter != out_edges.end()) {
        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
        int32_t next_index = reshape_info->next_operator_index();
        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
      }
      return reshape_info->Init(nullptr);
    }
  }
  return SUCCESS;
}

@@ -87,11 +87,9 @@ class CostGraph {
  void RemoveOperator(const OperatorInfoPtr &op);
  bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
  // the edge is in the form: u --> v
  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
    std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
    curr_edges.push_back(edge);
    edges_[{u_node, v_node}] = curr_edges;
  }
  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
  std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
  std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
  // An edge is uniquely identified by its name, and its output index and input index.
  bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
@@ -219,6 +217,8 @@ class CostGraph {
  std::vector<OperatorInfoPtr> ops_;
  std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
  std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
  std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
  std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
};
} // namespace parallel
} // namespace mindspore

@@ -111,6 +111,7 @@ class OperatorInfo {
  Shape dev_matrix_shape() const { return dev_matrix_shape_; }
  std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
  std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }
  RankList global_device_list() const { return global_device_list_; }

@@ -22,6 +22,7 @@
#include "parallel/device_manager.h"
#include "parallel/device_matrix.h"
#include "parallel/step_parallel.h"
#include "parallel/auto_parallel/graph_costmodel.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

@@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
    }
    return FAILED;
  }
  std::vector<Dimensions> stra = strategy->GetInputDim();
  for (size_t i = 0; i < strategy_size; ++i) {
    Shape sub_strategy = stra.at(i);
    size_t strategy_len = sub_strategy.size();
    bool flag = false;
    for (size_t j = 0; j < strategy_len; ++j) {
      int32_t strategy_value = sub_strategy.at(j);
      if (strategy_value > 1) {
        if (flag) {
          if (is_auto_parallel_) {
            MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
          } else {
            MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
          }
          return FAILED;
        }
        flag = true;
      }
    }
  }
  return SUCCESS;
}

@@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr
  return SUCCESS;
}

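// Called when the input of the reshape is a parameter: a cost entry is registered for
// every strategy that the reshape generated for itself (kept in sp_vector_).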
void ReshapeInfo::SetCostForReshapeWithParameter() {
  size_t success = 0;
  for (auto &sp : sp_vector_) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
}

void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  int32_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ +
    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
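  // i.e. only the parameter-related share of the communication cost, scaled by COST_MODEL_GAMMA,
  // is added on top of the parameter-free communication cost.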

  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
}

Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
@@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  }
  is_auto_parallel_ = true;
  Shape input0_split;
  input0_split.emplace_back(1);
  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  // These strategies are used only when the input node is a parameter;
  // otherwise the input node's output_layout is used as the reshape's input_layout.
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
    return FAILED;
  }
  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}
} // namespace parallel

@@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo {
    output_layout_ = output_layout;
    output_layout_set_flag_ = true;
  }
  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
  void SetCostForReshapeWithParameter();
  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int32_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  std::string pre_operator_name() const { return pre_operator_name_; }
  std::string next_operator_name() const { return next_operator_name_; }
  int32_t pre_operator_index() const { return pre_operator_index_; }
  int32_t next_operator_index() const { return next_operator_index_; }

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo {
  Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);

  int32_t dev_num_;
  int32_t pre_operator_index_;
  int32_t next_operator_index_;
  std::vector<int32_t> parameter_input_v_;
  std::vector<StrategyPtr> sp_vector_;
  Dimensions input_strategy_;
  TensorLayout input_layout_;
  TensorLayout output_layout_;
  bool input_layout_set_flag_;
  bool output_layout_set_flag_;
  std::string pre_operator_name_;
  std::string next_operator_name_;
};
} // namespace parallel
} // namespace mindspore

@@ -39,6 +39,7 @@
#include "parallel/auto_parallel/rec_core/rec_partition.h"
#include "parallel/context.h"
#include "parallel/ops_info/tmp_identity_info.h"
#include "parallel/ops_info/reshape_info.h"
#include "parallel/step_parallel.h"
#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/parse/python_adapter.h"
@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
      EdgePtr edge_ptr;
      MS_LOG(INFO) << "Creating edge: " << edge_name;

      bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
      bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
                             (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
      if (follow_strategy) {
        // Redistribution is not allowed on the edge.
        // Elementwise operators have the same strategy as their previous operators.
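        // Edges that enter or leave a Reshape are treated the same way, so no
        // redistribution is inserted around the reshape operator.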
@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  }
}

bool FindReshape(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
    return false;
  }
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  OperatorInfoPtr operator_info = cnode->operator_info();
  if (operator_info == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
  }
  if (prim->name() != RESHAPE) {
    return false;
  }
  return true;
}

// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
  // If the previous node is a parameter, handle it outside this function.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
    *pre_operator_info = cnode->operator_info();
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
      *pre_operator_info = pre_cnode->operator_info();
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
  return false;
}

// find next node, then obtain its strategy_cost_ vector to get its layout vector.
// if reshape's output connects to several primitives, return the first layout found.
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = use_apply->operator_info();
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                  << " " << (use_apply->operator_info() != nullptr);

    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}

void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
  Shape shape = pre_out_tensor_info.shape();
  Shape slice_shape = pre_out_tensor_info.slice_shape();
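  // For example, shape {32, 512} with slice_shape {8, 512} yields the strategy {4, 1}:
  // the tensor is split 4-way along dimension 0 and left unsplit along dimension 1.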
  for (size_t i = 0; i < shape.size(); ++i) {
    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
    }
    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
    (*stra).push_back(dim);
  }
}

void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
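  // For every Reshape node: take candidate layouts from the previous operator's strategy_cost_
  // (or from the reshape's own strategies when its input is a parameter), optionally pair each
  // of them with a next operator's input layout, and register a cost entry for every resulting
  // (input_layout, output_layout) combination.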
  for (auto node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!FindReshape(cnode)) {
      continue;
    }
    MS_ASSERT(cnode->inputs().size() == 3);
    // get previous node's strategy_cost_
    auto pre_node = cnode->input(1);
    int32_t out_index = 0;
    OperatorInfoPtr pre_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
    if (pre_node->isa<Parameter>()) {
      OperatorInfoPtr operator_info = cnode->operator_info();
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info->SetCostForReshapeWithParameter();
      pre_operator_info = reshape_info;
      pre_stra_costs = reshape_info->strategy_cost();
    } else {
      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
      }
      pre_stra_costs = pre_operator_info->strategy_cost();
    }
    // get next node's strategy_cost_
    int32_t in_index = 0;
    OperatorInfoPtr next_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
    if (!find_next_node) {
      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
    }
    // set input_layout and output_layout for reshape.
    // init reshape and set cost for each input_layout and output_layout.
    OperatorInfoPtr operator_info = cnode->operator_info();
    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
    reshape_info->set_pre_operator_name(pre_operator_info->name());
    reshape_info->set_pre_operator_index(out_index);
    if (find_next_node) {
      next_stra_costs = next_operator_info->strategy_cost();
      reshape_info->set_next_operator_name(next_operator_info->name());
      reshape_info->set_next_operator_index(in_index);
    }
    for (auto pre_stra_cost : pre_stra_costs) {
      std::vector<TensorInfo> pre_out_tensor_infos;
      if (pre_node->isa<Parameter>()) {
        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
      } else {
        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
      }
      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
      }
      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
      reshape_info->SetInputLayout(pre_out_tensor_layout);
      // infer pre_node output strategy from output_layout.
      Dimensions stra;
      InferStraByTensorInfo(pre_out_tensor_info, &stra);
      std::vector<Dimensions> stra_inputs = {stra};
      StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
      if (next_stra_costs.empty()) {
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
        }
        // set cost for each input_layout and output_layout pairs.
        reshape_info->SetCostForReshape(reshape_stra);
        continue;
      }
      for (auto next_stra_cost : next_stra_costs) {
        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
        }
        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
        reshape_info->SetOutputLayout(next_in_tensor_layout);
        if (reshape_info->Init(nullptr) == FAILED) {
          MS_LOG(EXCEPTION) << "Failure:operator reshape init failed";
        }
        // set cost for each input_layout and output_layout pairs.
        reshape_info->SetCostForReshape(reshape_stra);
      }
    }
  }
}

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
  // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
    }
  }

  // The reshape operator needs the next node's input_layout as its output_layout
  // and the previous node's output_layout as its input_layout.
  ReshapeCostCompute(all_nodes);
  // Step 2
  ConstructCostGraphEdges(all_nodes);
  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()

@@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);

void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);

void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);

Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);

Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);

@@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) {
  Status ret = reshape->Init(strategy);
  ASSERT_EQ(ret, SUCCESS);
}

TEST_F(TestReshapeInfo, AutoStrategy1) {
  ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS);
  std::vector<std::shared_ptr<StrategyWithCost>> sc = reshape->GetStrategyCost();

  Shapes splittable_inputs = {{1, 0, 0, 0}};
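  // Only the first (batch) dimension is marked as splittable, so every generated strategy
  // leaves dimensions 1-3 unsplit, which the assertions below check.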
  std::vector<StrategyPtr> sp_vector;
  Shapes inputs_shape = {{32, 512, 7, 7}};
  GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
  ASSERT_EQ(sc.size(), sp_vector.size());
  for (auto stra : sp_vector) {
    auto stra0 = stra->GetInputDim()[0];
    ASSERT_EQ(stra0[1], 1);
    ASSERT_EQ(stra0[2], 1);
    ASSERT_EQ(stra0[3], 1);
  }
}
} // namespace parallel
} // namespace mindspore

@@ -65,6 +65,193 @@ def test_reshape_matmul():
    net.set_auto_parallel()
    _executor.compile(net, x)

def test_reshape_auto_1():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

if __name__ == '__main__':
    test_reshape_matmul()
        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)
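    # with size = 8 the input is [64, 28, 1, 1]; the reshape flattens it to [64, 28] for the matmul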

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)

def test_reshape_auto_2():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (128, 32))
            out = out + self.add_weight
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)

def test_reshape_auto_3():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (8, 8, 8, 8))
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)

def test_reshape_auto_4():
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28*64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            w = self.reshape(self.matmul_weight, (28, 64))
            out = self.matmul(out, w)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)


def test_reshape_auto_5():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024*8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (4, 1024*8, 1))
            w_id = self.relu(x)
            wx = self.mul(w_id, mask)
            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1,1))
            deep_id = x + self.wide_w
            vx = self.mul(deep_id, mask)
            deep_in = self.reshape(vx, (-1, 1024*8*64))
            out = wide_out + deep_in
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024*size, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024*size,]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)

def test_reshape_auto_6():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            w = self.reshape(self.wide_w, (4,1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)