session-code-review

This commit is contained in:
lvliang 2020-06-20 14:21:32 +08:00
parent 932b7649e7
commit 4de9e2506a
2 changed files with 20 additions and 9 deletions

View File

@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return nullptr;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
return py_param.ptr();
}
@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
if (input_idx > input_tensors.size()) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph.inputs()[input_idx] == node) {
@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
}
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensor);
auto value_node = std::make_shared<ValueNode>(input_tensor);
MS_EXCEPTION_IF_NULL(value_node);
// construct abstract of value node
auto type_of_tensor = input_tensor->Dtype();
auto shape_of_tensor = input_tensor->shape();
@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph,
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
int tensor_mask) {
MS_EXCEPTION_IF_NULL(graph);
auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
}
bool ExistSummaryNode(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto all_nodes = DeepLinkedGraphSearch(ret);
@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
MS_EXCEPTION_IF_NULL(graph);
auto m_tensor = GetParamDefaultInputTensor(anf);
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) {
@ -583,6 +593,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(func_graph);
auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph);
front_backend_graph_map_[func_graph] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
@ -722,8 +733,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
}
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue;
@ -759,6 +770,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
@ -810,6 +822,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_LOG(INFO) << "output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
@ -881,7 +894,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
}
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
inputs.push_back(parameter);
graph->MutableInputs()->push_back(parameter);
auto mutable_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(mutable_inputs);
mutable_inputs->push_back(parameter);
}
// set execution order
auto cnode = graph->NewCNode(inputs);

View File

@ -48,11 +48,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic {
public:
SessionBasic() : device_id_(0) {
graphs_ = {};
run_op_graphs_ = {};
summary_callback_ = nullptr;
}
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {}
virtual void Init(uint32_t device_id) { device_id_ = device_id; }