!41639 Open auto dynamic shape

Merge pull request !41639 from zjun/open_auto_dynamic
Author: i-robot, 2022-09-20 06:28:34 +00:00 (committed by Gitee)
Commit: c1893ff7e5
12 changed files with 157 additions and 148 deletions

View File

@@ -1096,7 +1096,7 @@ void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool ha
   if (has_sens_arg) {
     auto sens_param = tape_->add_parameter();
     sens_param->debug_info()->set_name("sens");
-    sens_param->set_abstract(last_node_adjoint_iter->second->out()->ToAbstract()->Broaden());
+    sens_param->set_abstract(last_node_->abstract()->Broaden());
     // Set dout of last node to sens;
     last_node_adjoint_iter->second->AccumulateDout(sens_param);
   } else {

View File

@@ -22,6 +22,8 @@
 namespace mindspore {
 namespace pynative {
 const char kSensInfo[] = "SensInfo";
+static const ShapeValueDType UNKNOWN_DIM = -1;
+static const ShapeValueDType UNKNOWN_RANK = -2;

 ShapeVector DynamicShape::GetTensorShape(const ValuePtr &v) const {
   MS_EXCEPTION_IF_NULL(v);
@@ -134,6 +136,10 @@ void DynamicShape::SaveDynShapeAbsForMsFunction(const py::args &args, const py::
 void DynamicShape::SaveOutputDynamicShape(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v) {
   MS_EXCEPTION_IF_NULL(op_run_info);
   MS_EXCEPTION_IF_NULL(v);
+  // Do not use the GetNext output abstract; the top cell changes to dynamic shape automatically when tensor shapes differ
+  if (op_run_info->base_op_run_info.op_name == kGetNextOpName) {
+    return;
+  }
   // Save dynamic abs
   if (op_run_info->base_op_run_info.has_dynamic_output) {
     SaveIdWithDynamicAbstract(v, op_run_info->base_op_run_info.abstract);
@@ -151,7 +157,7 @@ void DynamicShape::SetDynamicInput(const py::object &cell, const py::args &args)
   }
 }

-void DynamicShape::SetFeedDynamicInputAbs(const py::object &cell, const py::args &args, bool is_auto) {
+void DynamicShape::SetFeedDynamicInputAbs(const py::object &cell, const py::args &args) {
   if (!HasFeedDynamicInput()) {
     return;
   }
@@ -164,7 +170,6 @@ void DynamicShape::SetFeedDynamicInputAbs(const py::object &cell, const py::args
     MS_LOG(DEBUG) << "Dynamic input size " << it->second.size() << " is not equal to real input size " << args.size();
     return;
   }
-  bool id_changed = false;
   for (size_t i = 0; i < args.size(); i++) {
     auto abs = it->second.at(i);
     MS_EXCEPTION_IF_NULL(abs);
@@ -172,17 +177,13 @@
     MS_EXCEPTION_IF_NULL(shape);
     if (shape->IsDynamic()) {
       const auto &arg_id = PyNativeAlgo::PyParser::GetIdByPyObj(args[i]);
-      MS_LOG(DEBUG) << "Set arg " << i << ", id " << arg_id << " to be dynamic shape; Arg self abs: "
+      MS_LOG(DEBUG) << "Set cur arg " << i << ", id " << arg_id << " to be dynamic shape; Arg self abs: "
                     << PyNativeAlgo::DataConvert::PyObjToValue(args[i])->ToAbstract()->Broaden()->ToString()
                     << ", dynamic abs: " << abs->ToString();
       id_with_dynamic_abs_[arg_id] = abs;
       PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->EraseFromNodeAbsMap(arg_id);
-      id_changed = true;
     }
   }
-  if (id_changed && !is_auto) {
-    CheckPreviousTopCellCanBeDynamicShape(cell, args);
-  }
 }

 py::object DynamicShape::GetDynamicInput(const py::object &actual_input) const {
@@ -313,24 +314,19 @@ TopCellInfoPtr DynamicShape::ChangeTopCellToDynamicShapeBySetInputs(const TopCel
                                                                     const py::object &cell) {
   MS_EXCEPTION_IF_NULL(top_cell);
   // Change args shape
+  const auto it = feed_dynamic_input_.find(PyNativeAlgo::PyParser::GetIdByPyObj(cell));
+  if (it == feed_dynamic_input_.end()) {
+    return nullptr;
+  }
   for (size_t i = 0; i < new_args_shape.size(); ++i) {
     top_cell->cell_self_info()->args_shape[i] = std::make_shared<abstract::Shape>(new_args_shape[i]);
-  }
-  auto it = feed_dynamic_input_.find(PyNativeAlgo::PyParser::GetIdByPyObj(cell));
-  if (it != feed_dynamic_input_.end()) {
-    for (size_t i = 0; i < new_args_shape.size(); i++) {
-      auto abs = it->second.at(i);
-      MS_EXCEPTION_IF_NULL(abs);
-      auto shape = abs->BuildShape();
-      MS_EXCEPTION_IF_NULL(shape);
-      if (shape->IsDynamic()) {
-        const auto &arg_id = top_cell->cell_self_info()->args_id[i];
-        MS_LOG(DEBUG) << "Set arg " << i << ", id " << arg_id << ", dynamic abs: " << abs->ToString();
-        id_with_dynamic_abs_[arg_id] = abs;
-        PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->EraseFromNodeAbsMap(arg_id);
-      }
-    }
-  }
+    const auto &arg_id = top_cell->cell_self_info()->args_id[i];
+    it->second.at(i)->set_shape(top_cell->cell_self_info()->args_shape[i]);
+    MS_LOG(DEBUG) << "Change cur top cell arg " << i << ", id " << arg_id
+                  << ", dynamic abs: " << it->second.at(i)->ToString();
+    id_with_dynamic_abs_[arg_id] = it->second.at(i);
+    PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->EraseFromNodeAbsMap(arg_id);
+  }
   top_cell->ChangeTopCellInfo(new_args_shape.size());
   return top_cell;
 }
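Note: the set_inputs-driven path above (SetFeedDynamicInputAbs / ChangeTopCellToDynamicShapeBySetInputs) is triggered from Python by Cell.set_inputs. A minimal usage sketch, following the pattern in the updated tests at the end of this diff (the Net class here is illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops, Tensor

    class Net(nn.Cell):
        def construct(self, x, y):
            return ops.add(x, y)

    ms.set_context(mode=ms.PYNATIVE_MODE)
    net = Net()
    # Declare every dimension dynamic; SetFeedDynamicInputAbs maps each real argument
    # id to this dynamic abstract on the next forward run.
    net.set_inputs(Tensor(shape=[None, None, None, None], dtype=ms.float32),
                   Tensor(shape=[None, None, None, None], dtype=ms.float32))
    out = net(Tensor(np.ones((2, 3, 6, 8), np.float32)),
              Tensor(np.ones((2, 3, 6, 8), np.float32)))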
@@ -352,7 +348,6 @@ TopCellInfoPtr DynamicShape::ChangeTopCellToDynamicShapeByAuto(const TopCellInfo
   // Set to feed dynamic map, later shapes can match it
   MS_LOG(DEBUG) << "Set dynamic input for auto dynamic shape";
   SetDynamicInput(cell, args);
-  SetFeedDynamicInputAbs(cell, args, true);
   top_cell->ChangeTopCellInfo(new_args_shape.size());
   return top_cell;
 }
@@ -397,30 +392,23 @@ void DynamicShape::UpdateTopCellId(const py::args &args) const {
 }

 TopCellInfoPtr DynamicShape::GetTopCellWithDynamicShape(const py::object &cell, const py::args &args, bool is_auto) {
-  // Current return nullptr for disable auto dynamic shape feature; Later after a complete test will enable this
-  if (is_auto && !py::isinstance<py::none>(cell)) {
-    return nullptr;
-  }
   const auto &cell_self_id = PyNativeAlgo::PyParser::GetIdByPyObj(cell);
-  const auto &top_cell_list = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->top_cell_list();
-  const auto it = std::find_if(top_cell_list.begin(), top_cell_list.end(), [&cell_self_id](const TopCellInfoPtr &elem) {
-    return elem->cell_self_info() != nullptr && elem->cell_self_info()->cell_self_id == cell_self_id;
-  });
-  if (it != top_cell_list.end()) {
-    const auto &elem = *it;
-    if (elem->dynamic_shape()) {
-      MS_LOG(DEBUG) << "Elem has already dynamic shape";
-      return nullptr;
-    }
+  const auto grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  size_t grad_order = grad_executor->grad_order();
+  std::vector<TopCellInfoPtr> match_top_cell_list;
+  std::copy_if(grad_executor->top_cell_list().begin(), grad_executor->top_cell_list().end(),
+               std::back_inserter(match_top_cell_list), [&cell_self_id, grad_order](const TopCellInfoPtr &elem) {
+                 return elem->cell_self_info() != nullptr && elem->cell_self_info()->cell_self_id == cell_self_id &&
+                        grad_order == elem->grad_order();
+               });
+  for (const auto &it : match_top_cell_list) {
     std::vector<ShapeVector> new_args_shape;
-    FindMatchTopCell(elem, args, &new_args_shape);
+    FindMatchTopCell(it, args, &new_args_shape);
     // Change top cell to be dynamic
-    if (new_args_shape.size() == args.size()) {
-      if (is_auto) {
-        return ChangeTopCellToDynamicShapeByAuto(elem, new_args_shape, cell, args);
-      } else {
-        return ChangeTopCellToDynamicShapeBySetInputs(elem, new_args_shape, cell);
-      }
+    if (is_auto) {
+      return ChangeTopCellToDynamicShapeByAuto(it, new_args_shape, cell, args);
+    } else {
+      return ChangeTopCellToDynamicShapeBySetInputs(it, new_args_shape, cell);
     }
   }
   UpdateTopCellId(args);
@@ -434,7 +422,7 @@ void DynamicShape::CheckPreviousTopCellCanBeDynamicShape(const py::object &cell,
   // In ms_function, new graph run before construct, so top cell create first; After that, set_dynamic_input call
   // in construct, here change top cell to dynamic.
   if (GetTopCellWithDynamicShape(cell, args, false) != nullptr) {
-    MS_LOG(DEBUG) << "Convert ms_function top cell to dynamic shape.";
+    MS_LOG(DEBUG) << "Convert cur top cell to dynamic shape.";
   }
 }
@@ -478,29 +466,21 @@ void DynamicShape::FindMatchTopCell(const TopCellInfoPtr &top_cell, const py::ar
     }
     // Check shape
     const auto &cur_shape = GetShapeFromAbstract(cur_value_abs)->shape();
-    auto elem_shape = top_cell->cell_self_info()->args_shape[i]->shape();
+    const auto elem_shape = top_cell->cell_self_info()->args_shape[i]->shape();
+    ShapeVector new_shape;
+    // Rank dynamic
     if (cur_shape.size() != elem_shape.size()) {
       MS_LOG(DEBUG) << "The " << i << "th args shape size is not the same, cur is " << cur_shape.size()
-                    << " and the elem is " << elem_shape.size();
-      return;
+                    << " and the elem is " << elem_shape.size() << ", change shape to dynamic rank";
+      new_shape.emplace_back(UNKNOWN_RANK);
+      continue;
     }
-    ShapeVector new_shape;
+    // Shape dynamic
     for (size_t j = 0; j < cur_shape.size(); ++j) {
-      if (cur_shape[j] == elem_shape[j]) {
-        (void)new_shape.emplace_back(cur_shape[j]);
-      } else {
-        (void)new_shape.emplace_back(-1);
-      }
+      (void)new_shape.emplace_back(UNKNOWN_DIM);
     }
-    // All shape can not be -1, and all shape can not be actual.
-    bool is_any_unknown = std::any_of(new_shape.begin(), new_shape.end(), [](int64_t s) { return s == -1; });
-    bool is_any_actual = std::any_of(new_shape.begin(), new_shape.end(), [](int64_t s) { return s != -1; });
-    if (is_any_unknown && is_any_actual) {
-      (void)new_args_shape->emplace_back(new_shape);
-    } else {
-      MS_LOG(DEBUG) << "Not support all shape unknown or actual.Cur shape " << cur_shape << ", elem shape "
-                    << elem_shape << ", and new shape is " << new_shape;
-    }
+    (void)new_args_shape->emplace_back(new_shape);
+    MS_LOG(DEBUG) << "Cur shape " << cur_shape << ", elem shape " << elem_shape << ", new shape " << new_shape;
   }
 }
 }  // namespace pynative
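Reference for the new matching rule in FindMatchTopCell: when a later run's argument rank differs from the recorded top cell, the shape falls back to dynamic rank (UNKNOWN_RANK = -2); when the rank matches, every dimension is generalized to UNKNOWN_DIM = -1. A minimal Python sketch of that rule (the helper name is illustrative; type checks and other bookkeeping are omitted):

    UNKNOWN_DIM = -1
    UNKNOWN_RANK = -2

    def generalize_arg_shapes(prev_shapes, cur_shapes):
        """Return the dynamic shape each argument of the top cell is switched to."""
        new_shapes = []
        for prev, cur in zip(prev_shapes, cur_shapes):
            if len(prev) != len(cur):
                # Rank changed between runs: fall back to dynamic rank.
                new_shapes.append([UNKNOWN_RANK])
            else:
                # Same rank: every dimension becomes dynamic.
                new_shapes.append([UNKNOWN_DIM] * len(cur))
        return new_shapes

    print(generalize_arg_shapes([[2, 3, 6, 8]], [[2, 3, 6, 16]]))  # [[-1, -1, -1, -1]]
    print(generalize_arg_shapes([[2, 3, 6, 8]], [[2, 3, 6]]))      # [[-2]]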

View File

@@ -37,7 +37,7 @@ class DynamicShape {
   DynamicShape() = default;
   ~DynamicShape() = default;
   void SetDynamicInput(const py::object &cell, const py::args &args);
-  void SetFeedDynamicInputAbs(const py::object &cell, const py::args &args, bool is_auto);
+  void SetFeedDynamicInputAbs(const py::object &cell, const py::args &args);
   py::object GetDynamicInput(const py::object &actual_input) const;
   ValuePtr GetSensValueForDynamicShapeOutput(const TopCellInfoPtr &top_cell, const ValuePtr &v,
                                              const AnfNodePtr &node) const;

View File

@@ -369,7 +369,7 @@ void ForwardExecutor::ProcessBeforeNewGraph(const py::object &cell, const py::ar
   if (py::isinstance<Cell>(cell)) {
     PushForwardCell(cell);
   }
-  dynamic_shape()->SetFeedDynamicInputAbs(cell, args, false);
+  dynamic_shape()->SetFeedDynamicInputAbs(cell, args);
 }

 void ForwardExecutor::ProcessBeforeEndGraph(const py::object &cell, const py::args &args) {

View File

@@ -181,7 +181,7 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) {
   for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
     MS_EXCEPTION_IF_NULL(*it);
     const auto &top_cell_id = (*it)->cell_id();
-    const auto &already_run_cell_id = (*it)->already_run_cell_id();
+    auto already_run_cell_id = (*it)->already_run_cell_id();
     if (IsCellObjIdEq(cell_id, top_cell_id)) {
       MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
       (*it)->Clear();
@@ -286,7 +286,7 @@ void GradExecutor::NewGraphInner(const py::object *ret, const py::object &cell,
   auto top_it = already_run_top_cell_.find(already_run_cell_id);
   if (top_it != already_run_top_cell_.end()) {
     // Top cell forward run.
-    const auto &pre_top_cell = top_it->second;
+    auto pre_top_cell = top_it->second;
     MS_EXCEPTION_IF_NULL(pre_top_cell);
     MS_LOG(DEBUG) << "Pre top cell, hook_changed " << pre_top_cell->hook_changed() << ", is_dynamic_structure "
                   << pre_top_cell->is_dynamic_structure();
@@ -349,26 +349,28 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::object &cell
     MS_LOG(WARNING) << "Too many top cell has been built, please check if the cell " << cell.cast<CellPtr>()->ToString()
                     << " is repeatedly defined in each step/epoch, or the net input shape changes frequently.";
   }
-  // Create top cell
+  // Find matched dynamic shape top cell
   auto fg = std::make_shared<FuncGraph>();
   auto df_builder = std::make_shared<FuncGraph>();
   auto resource = std::make_shared<pipeline::Resource>();
-  const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
-  auto top_cell =
-    std::make_shared<TopCellInfo>(is_topest, grad_order_, resource, fg, df_builder, cell_id, already_run_cell_id);
+  auto top_cell = dynamic_shape()->GetTopCellWithDynamicShape(cell, args, true);
+  if (top_cell == nullptr) {
+    const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
+    top_cell =
+      std::make_shared<TopCellInfo>(is_topest, grad_order_, cell_id, already_run_cell_id, resource, fg, df_builder);
+    top_cell->SetCellSelfInfoForTopCell(cell, args);
+    (void)top_cell_list_.emplace_back(top_cell);
+  } else {
+    auto new_top_cell = std::make_shared<TopCellInfo>(*top_cell, resource, fg, df_builder);
+    top_cell->Clear();
+    EraseTopCellFromTopCellList(top_cell);
+    top_cell = new_top_cell;
+    MS_LOG(INFO) << "The shape change of the network input tensor is detected, "
+                    "and the dynamic shape process is triggered. The bprop graph needs to be recompiled, "
+                    "which may take some time";
+  }
   top_cell->set_forward_already_run(true);
   top_cell->set_input_args_id(input_args_id);
-  TopCellInfoPtr top_cell_with_dynamic_shape = dynamic_shape()->GetTopCellWithDynamicShape(cell, args, true);
-  if (top_cell_with_dynamic_shape != nullptr) {
-    top_cell->set_cell_id(top_cell_with_dynamic_shape->cell_id());
-    top_cell->set_already_run_cell_id(top_cell_with_dynamic_shape->already_run_cell_id());
-    top_cell->set_cell_self_info(top_cell_with_dynamic_shape->cell_self_info());
-    EraseTopCellFromTopCellList(top_cell_with_dynamic_shape);
-    MS_LOG(DEBUG) << "Pre top cell and current top cell merged to one top cell with dynamic shape";
-  } else {
-    top_cell->SetCellSelfInfoForTopCell(cell, args);
-  }
-  (void)top_cell_list_.emplace_back(top_cell);
   PushHighOrderGraphStack(top_cell);
   set_top_cell(top_cell);
   MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
@@ -387,17 +389,16 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
     output_node = GetObjNode(value_ptr, PyNativeAlgo::Common::GetIdByValue(value_ptr));
   }
   MS_EXCEPTION_IF_NULL(output_node);
-  if (top_cell()->dynamic_shape()) {
   abstract::AbstractBasePtr last_node_abs = nullptr;
   if (output_node->abstract() == nullptr) {
     last_node_abs = v->ToAbstract()->Broaden();
+    output_node->set_abstract(last_node_abs);
   } else {
     last_node_abs = output_node->abstract();
   }
   MS_EXCEPTION_IF_NULL(last_node_abs);
   // Set last output abstract and will be used for sens
   top_cell()->set_last_output_abs(last_node_abs);
-  }
   // Set last node and sens for build adjoint
   const auto &sens_value = dynamic_shape()->GetSensValueForDynamicShapeOutput(top_cell(), v, output_node);
   auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
@@ -811,16 +812,12 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
   ss << "grad{" << arg_size << "}";
   bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   bprop_graph->debug_info()->set_name(ss.str());
-  // Get the parameters items and add the value to args_spec
-  if (top_cell()->dynamic_shape() && grad->sens_param()) {
+  // Set sens abstract
+  if (grad->sens_param()) {
     MS_EXCEPTION_IF_NULL(top_cell()->last_output_abs());
-    auto shape = top_cell()->last_output_abs()->BuildShape();
-    MS_EXCEPTION_IF_NULL(shape);
-    if (shape->IsDynamic()) {
     const auto &sens_id = PyNativeAlgo::PyParser::GetIdByPyObj(args[arg_size - 1]);
     dynamic_shape()->SetIdWithDynamicAbs(sens_id, top_cell()->last_output_abs());
-    }
   }
   UpdateParamAbsByArgs(PyNativeAlgo::PyParser::FilterTensorArgs(args, grad->sens_param_), bprop_graph);
   // Dynamic shape graph need add some other pass
   if (top_cell()->dynamic_shape()) {
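The user-visible effect of the MakeNewTopGraph change above: running the same cell under GradOperation with a new input shape no longer caches an unrelated static top cell; the existing one is generalized to dynamic shape and the bprop graph is recompiled. A hedged sketch of that flow (the network and shapes are illustrative, following the test pattern later in this diff):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops, Tensor

    class Net(nn.Cell):
        def construct(self, x, y):
            return ops.mul(x, y)

    ms.set_context(mode=ms.PYNATIVE_MODE)
    net = Net()
    grad_op = ops.GradOperation(get_all=True, sens_param=True)

    # First shape builds a static top cell.
    x1 = Tensor(np.random.rand(2, 3, 6, 8).astype(np.float32))
    y1 = Tensor(np.random.rand(2, 3, 6, 8).astype(np.float32))
    out1 = net(x1, y1)
    _ = grad_op(net)(x1, y1, out1)

    # Second shape matches the same cell and grad order, so the top cell is switched
    # to dynamic shape and the bprop graph is recompiled (the INFO log above).
    x2 = Tensor(np.random.rand(2, 3, 6, 16).astype(np.float32))
    y2 = Tensor(np.random.rand(2, 3, 6, 16).astype(np.float32))
    out2 = net(x2, y2)
    _ = grad_op(net)(x2, y2, out2)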

View File

@@ -109,6 +109,7 @@ class GradExecutor {
   CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
   py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                              const py::object &grad_position, const py::args &args);
+  std::string GetAlreadyRunCellId(const std::string &cell_id) const;
   void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v) const;
   AnfNodePtr GetInput(const ValuePtr &v) const;
   void ClearGrad(const py::object &cell, const py::args &args);
@@ -159,7 +160,6 @@ class GradExecutor {
   void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
   void DoGradForCustomBprop(const py::object &cell, const py::args &args, const ValuePtr &out,
                             const std::string &out_id);
-  std::string GetAlreadyRunCellId(const std::string &cell_id) const;
   std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args) const;
   void GradNetInner(const py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                     const py::object &weights, const py::object &grad_position, const py::args &args);

View File

@@ -36,8 +36,8 @@ void TopCellInfo::SetCellSelfInfoForTopCell(const py::object &cell, const py::ar
     (void)args_shape.emplace_back(shape_ptr);
     (void)args_type.emplace_back(abs->BuildType());
   }
-  set_cell_self_info(
-    std::make_shared<CellSelfInfo>(PyNativeAlgo::PyParser::GetIdByPyObj(cell), args_id, args_shape, args_type));
+  cell_self_info_ =
+    std::make_shared<CellSelfInfo>(PyNativeAlgo::PyParser::GetIdByPyObj(cell), args_id, args_shape, args_type);
 }

 bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
@@ -108,21 +108,31 @@ void TopCellInfo::ClearDeviceMemory() const {
 void TopCellInfo::Clear() {
   MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
-  op_num_ = 0;
   is_dynamic_structure_ = false;
+  is_real_dynamic_structure_ = false;
   vm_compiled_ = false;
+  hook_changed_ = false;
   ms_function_flag_ = false;
   is_init_kpynative_ = false;
-  need_compile_graph_ = false;
   forward_already_run_ = false;
+  need_compile_graph_ = false;
+  op_num_ = 0;
+  grad_order_ = 0;
+  fg_ = nullptr;
+  df_builder_ = nullptr;
+  k_pynative_cell_ptr_ = nullptr;
+  last_output_abs_ = nullptr;
+  cell_self_info_ = nullptr;
+  cell_id_.clear();
+  already_run_cell_id_.clear();
   input_args_id_.clear();
   all_op_info_.clear();
-  resource_ = nullptr;
-  df_builder_ = nullptr;
-  fg_ = nullptr;
-  k_pynative_cell_ptr_ = nullptr;
+  grad_operation_.clear();
   graph_info_map_.clear();
   sub_cell_list_.clear();
+  sub_cell_hook_changed_.clear();
+  cell_backward_hook_op_.clear();
   op_info_with_tensor_id_.clear();
   tensor_id_with_tensor_object_.clear();
   op_info_with_ms_func_forward_tensors_.clear();
@@ -232,24 +242,16 @@ void TopCellInfo::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const Va
   }
 }

-std::string TopCellInfo::GetAlreadyRunCellId(const std::string &cell_id) const {
-  std::string already_run_cell_id(cell_id);
-  size_t grad_order = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->grad_order();
-  already_run_cell_id += std::to_string(grad_order == 0 ? 1 : grad_order);
-  already_run_cell_id += "_" + grad_operation_;
-  MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
-  return already_run_cell_id;
-}
-
 void TopCellInfo::ChangeTopCellInfo(size_t args_size) {
+  dynamic_shape_ = true;
   std::string new_cell_id = this->cell_self_info()->cell_self_id;
   for (size_t i = 0; i < args_size; ++i) {
     new_cell_id += "_" + this->cell_self_info()->args_shape[i]->ToString();
-    new_cell_id += this->cell_self_info()->args_type[i]->ToString();
+    new_cell_id += cell_self_info_->args_type[i]->ToString();
   }
-  MS_LOG(DEBUG) << "Change top cell " << this->cell_id() << " to be dynamic " << new_cell_id;
-  set_cell_id(new_cell_id);
-  set_already_run_cell_id(GetAlreadyRunCellId(new_cell_id));
+  MS_LOG(DEBUG) << "Change pre top cell " << cell_id_ << " to be dynamic " << new_cell_id;
+  cell_id_ = new_cell_id;
+  already_run_cell_id_ = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->GetAlreadyRunCellId(new_cell_id);
 }
 }  // namespace pynative
 }  // namespace mindspore
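For reference, ChangeTopCellInfo above rebuilds the cell id from the generalized argument shapes and types, so later runs with other concrete shapes resolve to the same, now dynamic, top cell. A rough Python analogue (the function and the exact id formatting are illustrative):

    def build_cell_id(cell_self_id, args_shape, args_type):
        """Concatenate the cell object id with each argument's shape and type."""
        cell_id = cell_self_id
        for shape, dtype in zip(args_shape, args_type):
            cell_id += "_" + str(shape) + dtype
        return cell_id

    static_id = build_cell_id("net_obj_1", [[2, 3, 6, 8]], ["Float32"])
    dynamic_id = build_cell_id("net_obj_1", [[-1, -1, -1, -1]], ["Float32"])
    print(static_id)   # net_obj_1_[2, 3, 6, 8]Float32
    print(dynamic_id)  # net_obj_1_[-1, -1, -1, -1]Float32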

View File

@@ -78,17 +78,26 @@ using CellSelfInfoPtr = std::shared_ptr<CellSelfInfo>;
 class TopCellInfo {
  public:
-  TopCellInfo() = default;
   ~TopCellInfo() = default;
-  TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr fg, FuncGraphPtr df,
-              std::string cellid, std::string already_run_cell_id)
+  TopCellInfo(bool topest, size_t grad_order, std::string cellid, std::string already_run_cell_id,
+              pipeline::ResourcePtr r, FuncGraphPtr fg, FuncGraphPtr df)
       : is_topest_(topest),
         grad_order_(grad_order),
+        cell_id_(std::move(cellid)),
+        already_run_cell_id_(std::move(already_run_cell_id)),
         resource_(std::move(r)),
         fg_(std::move(fg)),
-        df_builder_(std::move(df)),
-        cell_id_(std::move(cellid)),
-        already_run_cell_id_(std::move(already_run_cell_id)) {}
+        df_builder_(std::move(df)) {}
+  TopCellInfo(const TopCellInfo &top_cell, pipeline::ResourcePtr r, FuncGraphPtr fg, FuncGraphPtr df)
+      : is_topest_(top_cell.is_topest_),
+        grad_order_(top_cell.grad_order_),
+        cell_id_(top_cell.cell_id_),
+        already_run_cell_id_(top_cell.already_run_cell_id_),
+        cell_self_info_(top_cell.cell_self_info_),
+        resource_(std::move(r)),
+        fg_(std::move(fg)),
+        df_builder_(std::move(df)) {}

   bool is_init_kpynative() const { return is_init_kpynative_; }
   void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
@@ -126,14 +135,12 @@ class TopCellInfo {
   const std::string &cell_id() const { return cell_id_; }
   void set_cell_id(const std::string &cell_id) { cell_id_ = cell_id; }
   const std::string &already_run_cell_id() const { return already_run_cell_id_; }
-  void set_already_run_cell_id(const std::string &already_run_cell_id) { already_run_cell_id_ = already_run_cell_id; }
   void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
   const std::string &grad_operation() const { return grad_operation_; }
   void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
   const abstract::AbstractBasePtr &last_output_abs() const { return last_output_abs_; }
   void set_last_output_abs(const abstract::AbstractBasePtr &last_output_abs) { last_output_abs_ = last_output_abs; }
   CellSelfInfoPtr cell_self_info() const { return cell_self_info_; }
-  void set_cell_self_info(const CellSelfInfoPtr &cell_self_info) { cell_self_info_ = cell_self_info; }
   void SetCellSelfInfoForTopCell(const py::object &cell, const py::args &args);
   void EraseFromSubCellList(const std::string &cell_id) { (void)sub_cell_list_.erase(cell_id); }
   void SetSubCellList(const std::string &cell_id) { (void)sub_cell_list_.emplace(cell_id); }
@@ -171,7 +178,6 @@ class TopCellInfo {
                                   const std::vector<int64_t> &index);
   void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const ValuePtr &v, const AnfNodePtr &node,
                                   bool is_param = false);
-  std::string GetAlreadyRunCellId(const std::string &cell_id) const;
   void ChangeTopCellInfo(size_t args_size);
   void ClearDeviceMemory() const;
   void Clear();
@@ -199,17 +205,17 @@ class TopCellInfo {
   bool need_compile_graph_{false};
   size_t op_num_{0};
   size_t grad_order_{0};
-  pipeline::ResourcePtr resource_{nullptr};
-  FuncGraphPtr fg_{nullptr};
-  FuncGraphPtr df_builder_{nullptr};
-  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
   std::string cell_id_;
   std::string already_run_cell_id_;
   std::string input_args_id_;
   std::string all_op_info_;
   std::string grad_operation_;
-  abstract::AbstractBasePtr last_output_abs_;
   CellSelfInfoPtr cell_self_info_{nullptr};
+  pipeline::ResourcePtr resource_{nullptr};
+  FuncGraphPtr fg_{nullptr};
+  FuncGraphPtr df_builder_{nullptr};
+  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
+  abstract::AbstractBasePtr last_output_abs_{nullptr};
   OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
   mindspore::HashSet<std::string> sub_cell_list_;
   // Record `register hook` or `remove hook` function has been called by sub cell

View File

@@ -181,6 +181,30 @@ def stop_gradient(x):
     return x


+def TensorShape(x):
+    """Implement `TensorShape`."""
+    return Tensor(F.shape(x))
+
+
+def DynamicBroadcastGradientArgs(x, y):
+    """Implement `DynamicBroadcastGradientArgs`."""
+    return -1, -1
+
+
+def StridedSlice(x, begin, end, stride):
+    """Implement `StridedSlice`."""
+    if isinstance(x, Tensor):
+        x = x.asnumpy()
+        ret = x
+        for i in range(len(x)):
+            ret[i:] = x[begin[i], end[i], stride[i]]
+        return Tensor(ret)
+    ret = x
+    for i in range(len(x)):
+        ret[i:] = x[begin[i], end[i], stride[i]]
+    return ret
+
+
 hyper_map = C.HyperMap()

View File

@@ -64,18 +64,18 @@ class CommonFunc():
         self.np_net = np_net

         self.input_np = input_np
-        self.input_np_bp = input_np
+        self.input_np_t = Tensor(input_np)
         self.out_np = np.array(1).astype(input_np.dtype)

     def forward_cmp(self):
-        out_ms = self.ms_net(Tensor(self.input_np))
+        out_ms = self.ms_net(self.input_np_t)
         self.out_np = self.np_net(self.input_np)
         assert np.all(out_ms.asnumpy() == self.out_np)

     def grad_impl(self):
         grad_net = GradOfFirstInput(self.ms_net)
         grad_net.set_train()
-        grad_net(Tensor(self.input_np_bp), Tensor(self.out_np))
+        grad_net(self.input_np_t, Tensor(self.out_np))


 @pytest.mark.level0

View File

@@ -70,14 +70,14 @@ class CommonFunc():
             8 * 16 * 3).reshape(8, 16, 3).astype(np.float32)
         self.input_np1 = np.arange(
             16 * 32 * 3).reshape(16, 32, 3).astype(np.float32)
-        self.input_np0_bp = self.input_np0.copy()
-        self.input_np1_bp = self.input_np1.copy()
+        self.input_np0_t = Tensor(self.input_np0)
+        self.input_np1_t = Tensor(self.input_np1)
         self.out_np0 = np.array(1).astype(self.input_np0.dtype)
         self.out_np1 = np.array(1).astype(self.input_np1.dtype)

     def forward_cmp(self):
         out_ms0, out_ms1 = self.ms_net(
-            Tensor(self.input_np0), Tensor(self.input_np1))
+            self.input_np0_t, self.input_np1_t)
         self.out_np0, self.out_np1 = self.np_net(
             self.input_np0, self.input_np1)
         assert np.all(out_ms0.asnumpy() == self.out_np0)
@@ -86,7 +86,7 @@ class CommonFunc():
     def grad_impl(self):
         grad_net = GradOfAllInputs(self.ms_net)
         grad_net.set_train()
-        grad_net(Tensor(self.input_np0_bp), Tensor(self.input_np1_bp),
+        grad_net(self.input_np0_t, self.input_np1_t,
                  (Tensor(self.out_np0), Tensor(self.out_np1)))

View File

@@ -165,8 +165,8 @@ def test_pynative_auto_dynamic_shape_mixing_static_shape_and_dynamic_shape_1():
     # run second shape
     input_x2 = Tensor(np.random.rand(2, 3, 6, 16).astype(np.float32) * 2)
     input_y2 = Tensor(np.random.rand(2, 3, 6, 16).astype(np.float32) * 5)
-    net.set_inputs(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
-                   Tensor(shape=[2, 3, None, None], dtype=ms.float32))
+    net.set_inputs(Tensor(shape=[None, None, None, None], dtype=ms.float32),
+                   Tensor(shape=[None, None, None, None], dtype=ms.float32))
     out = net(input_x2, input_y2)
     _ = grad_op(net)(input_x2, input_y2, out)
@@ -204,8 +204,8 @@ def test_pynative_auto_dynamic_shape_mixing_static_shape_and_dynamic_shape_2():
     # run first shape
     input_x = Tensor(np.random.rand(2, 3, 6, 8).astype(np.float32) * 2)
    input_y = Tensor(np.random.rand(2, 3, 6, 8).astype(np.float32) * 5)
-    net.set_inputs(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
-                   Tensor(shape=[2, 3, None, None], dtype=ms.float32))
+    net.set_inputs(Tensor(shape=[None, None, None, None], dtype=ms.float32),
+                   Tensor(shape=[None, None, None, None], dtype=ms.float32))
     out = net(input_x, input_y)
     _ = grad_op(net)(input_x, input_y, out)