forked from mindspore-Ecosystem/mindspore
session-code-review
This commit is contained in:
parent
932b7649e7
commit
4de9e2506a
|
@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
auto py_param = param_value->value();
|
||||
return py_param.ptr();
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
|
||||
if (input_idx > input_tensors.size()) {
|
||||
if (input_idx >= input_tensors.size()) {
|
||||
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
|
||||
}
|
||||
if (graph.inputs()[input_idx] == node) {
|
||||
|
@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
|||
}
|
||||
|
||||
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto value_node = anf->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
|
@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
auto value_node = std::make_shared<ValueNode>(input_tensor);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
// construct abstract of value node
|
||||
auto type_of_tensor = input_tensor->Dtype();
|
||||
auto shape_of_tensor = input_tensor->shape();
|
||||
|
@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
|
|||
|
||||
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
|
||||
int tensor_mask) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto param = graph->NewParameter();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (tensor_mask == kParameterWeightTensorMask) {
|
||||
|
@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
|
|||
}
|
||||
|
||||
bool ExistSummaryNode(const KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto all_nodes = DeepLinkedGraphSearch(ret);
|
||||
|
@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto m_tensor = GetParamDefaultInputTensor(anf);
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
|
@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
|
||||
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
||||
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
|
||||
if (parameters.empty()) {
|
||||
|
@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
|
||||
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto value_node = anf->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
|
||||
|
@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
|
|||
|
||||
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
|||
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
||||
auto graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
||||
size_t from_other_graph_depend_num = 0;
|
||||
for (const auto &node : lst) {
|
||||
|
@ -583,6 +593,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
auto graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
front_backend_graph_map_[func_graph] = graph;
|
||||
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
||||
|
||||
|
@ -722,8 +733,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
}
|
||||
auto anf_outputs = kernel_graph->outputs();
|
||||
for (auto &item : anf_outputs) {
|
||||
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
|
||||
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
|
||||
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
|
||||
continue;
|
||||
|
@ -759,6 +770,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
|
|||
auto node = cnode->input(kSummaryGetItem);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
|
||||
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
||||
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
|
||||
}
|
||||
|
@ -810,6 +822,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> output_args;
|
||||
for (const auto &output : outputs) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_LOG(INFO) << "output:" << output->DebugString();
|
||||
}
|
||||
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
|
@ -881,7 +894,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
|
|||
}
|
||||
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
|
||||
inputs.push_back(parameter);
|
||||
graph->MutableInputs()->push_back(parameter);
|
||||
auto mutable_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(mutable_inputs);
|
||||
mutable_inputs->push_back(parameter);
|
||||
}
|
||||
// set execution order
|
||||
auto cnode = graph->NewCNode(inputs);
|
||||
|
|
|
@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
|
|||
|
||||
class SessionBasic {
|
||||
public:
|
||||
SessionBasic() : device_id_(0) {
|
||||
graphs_ = {};
|
||||
run_op_graphs_ = {};
|
||||
summary_callback_ = nullptr;
|
||||
}
|
||||
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {}
|
||||
|
||||
virtual void Init(uint32_t device_id) { device_id_ = device_id; }
|
||||
|
||||
|
|
Loading…
Reference in New Issue