forked from mindspore-Ecosystem/mindspore
fix bug of loadding ctrl input tensors failed in control sink mode
This commit is contained in:
parent
017cdbe865
commit
729ea8cc55
|
@ -375,18 +375,16 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
|
|||
return assign_add_one;
|
||||
}
|
||||
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
if (!NeedInsertSwitch()) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
auto input_nodes = kernel_graph_ptr->inputs();
|
||||
std::vector<tensor::TensorPtr> inputs;
|
||||
LoadSwitchInputs(&inputs);
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
|
||||
context->SetResult(session::kInputCtrlTensors, inputsPtr);
|
||||
kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
|
||||
size_t input_ctrl_size = inputs.size();
|
||||
// inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one.
|
||||
// deal four ctrl nodes.
|
||||
|
|
|
@ -53,8 +53,7 @@ class KernelAdjust {
|
|||
}
|
||||
void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);
|
||||
static bool NeedInsertSwitch();
|
||||
CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
|
|
|
@ -517,7 +517,7 @@ void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input
|
|||
|
||||
void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(context_, kernel_graph);
|
||||
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
|
||||
|
|
|
@ -107,6 +107,12 @@ class KernelGraph : public FuncGraph {
|
|||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
|
||||
// checkout whether current graph is leaf graph
|
||||
bool IsLeafGraph() const;
|
||||
// set input_tensors pointer of control parameter
|
||||
void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
|
||||
input_ctrl_tensors_ = input_tensors_ptr;
|
||||
}
|
||||
// get input_tensors pointer of control parameter
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
@ -150,6 +156,8 @@ class KernelGraph : public FuncGraph {
|
|||
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
|
||||
// child graph execute order in root graph
|
||||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
|
||||
// input_tensors of control parameter
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
|
@ -266,23 +266,12 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input,
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
bool NeedInsertSwitch() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
|
||||
ConfigManager::GetInstance().iter_num() > 1);
|
||||
}
|
||||
|
||||
size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (!NeedInsertSwitch()) {
|
||||
(void)context->results_.erase(kInputCtrlTensors);
|
||||
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
|
||||
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
||||
auto inputs_params = graph->input_ctrl_tensors();
|
||||
if (inputs_params == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
||||
auto inputs_params =
|
||||
context->GetResult(kInputCtrlTensors).cast<const std::shared_ptr<std::vector<tensor::TensorPtr>>>();
|
||||
MS_EXCEPTION_IF_NULL(inputs_params);
|
||||
if (inputs_params->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
|
||||
}
|
||||
|
@ -686,11 +675,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||
size_t input_ctrl_size = 1;
|
||||
MS_EXCEPTION_IF_NULL(context_);
|
||||
if (context_->HasResult(kInputCtrlTensors)) {
|
||||
input_ctrl_size = LoadCtrlInputTensor(context_, &inputs);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->input_ctrl_tensors()) {
|
||||
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
||||
}
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||
|
|
|
@ -39,8 +39,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
|
|||
} // namespace ascend
|
||||
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
|
||||
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
return true;
|
||||
}
|
||||
bool KernelAdjust::NeedInsertSwitch() { return true; }
|
||||
|
|
Loading…
Reference in New Issue