add partial internal parameters

This commit is contained in:
kswang 2021-04-26 20:13:32 +08:00
parent c0690309aa
commit abf295560c
3 changed files with 107 additions and 40 deletions

View File

@ -160,7 +160,7 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
int output_index = SizeToInt(node_output_pair.second);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
@ -172,25 +172,24 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
tensor::TensorPtr tensor = nullptr;
std::vector<int64_t> temp_shape;
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_sync_status(kNoNeedSync);
} else {
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor;
bool is_internal_output = graph->IsInternalOutput(node, output_index);
if (is_internal_output) {
tensor = graph->GetInternalOutputTensor(node, output_index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
bool is_internal_output = graph->IsInternalOutput(node, output_index);
if (is_internal_output) {
graph->AddInternalOutputTensor(node, output_index, tensor);
}
graph->AddInternalOutputTensor(node, output_index, tensor);
}
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
} else {
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
}
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
if (is_internal_output) {
tensor->set_sync_status(kNoNeedSync);
} else {
// if in pynative mode,data only copied to host when user want to print data
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@ -682,16 +681,20 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
if (param_value != nullptr) {
new_parameter = param_value->parameter();
if (new_parameter == nullptr) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
param_value->set_parameter(new_parameter);
}
} else {
}
if (new_parameter == nullptr) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
}
auto input_node_iter = partial_parameters_map_.find(anf);
if (input_node_iter != partial_parameters_map_.end()) {
InitInternalOutputParameter(input_node_iter->second, new_parameter);
}
if (param_value != nullptr) {
param_value->set_parameter(new_parameter);
}
}
new_parameter->IncreaseUsedGraphCount();
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(true);
@ -1771,10 +1774,11 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &front_node) {
MS_EXCEPTION_IF_NULL(front_func_graph_manager);
auto &users = front_func_graph_manager->node_users()[front_node];
std::vector<AnfNodePtr> result;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) {
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend)) {
auto depend_cnode = user.first->cast<CNodePtr>();
if (depend_cnode == nullptr) {
continue;
@ -1784,9 +1788,12 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
}
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
result.insert(result.end(), res.begin(), res.end());
continue;
} else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
(void)result.insert(result.end(), res.begin(), res.end());
} else {
(void)result.emplace_back(user.first);
}
(void)result.emplace_back(user.first);
}
return result;
}
@ -1812,10 +1819,54 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
}
return nullptr;
}
} // namespace
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) {
constexpr auto kMixTarget = "MixTarget";
constexpr auto kNoTarget = "NoTarget";
std::string SessionBasic::AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &partial_node) {
MS_EXCEPTION_IF_NULL(partial_node);
auto iter = partial_target_map_.find(partial_node);
if (iter != partial_target_map_.end()) {
return iter->second;
}
auto partial_cnode = partial_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
MS_EXCEPTION_IF_NULL(partial_graph);
auto parameters = partial_graph->parameters();
auto partial_inputs = partial_cnode->inputs();
if (parameters.size() + 2 != partial_inputs.size()) {
return kMixTarget;
}
for (size_t i = 0; i < parameters.size(); ++i) {
partial_parameters_map_[parameters[i]] = partial_inputs[2 + i];
}
auto graph_nodes = TopoSort(partial_graph->get_return());
std::string graph_target = kNoTarget;
for (auto &node : graph_nodes) {
if (!node->isa<CNode>()) {
continue;
}
if (!AnfAlgo::IsRealKernel(node)) {
continue;
}
std::string cur_target = GetCNodeTarget(node);
if (graph_target == kNoTarget) {
graph_target = cur_target;
}
if (graph_target != cur_target) {
graph_target = kMixTarget;
break;
}
}
(void)partial_target_map_.insert({partial_node, graph_target});
return graph_target;
}
void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) {
auto front_node = GetSupportedInternalNode(input_front_node);
if (front_node == nullptr) {
return;
@ -1839,7 +1890,14 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &
}
if (internal_output) {
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
for (auto user : users) {
for (auto &user : users) {
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial)) {
auto partial_target = AddPartialParametersMap(front_func_graph_manager, user);
if (partial_target != kNoTarget && partial_target != kernel_target) {
unique_target = false;
}
continue;
}
if (!CNodeFirstInputIsPrimitive(user)) {
internal_output = false;
break;
@ -1859,7 +1917,6 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
}
}
} // namespace
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph);
@ -1868,7 +1925,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
MS_EXCEPTION_IF_NULL(output);
MS_LOG(INFO) << "Output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) {
auto context_ptr = MsContext::GetInstance();

View File

@ -156,6 +156,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph);
std::string AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &partial_node);
protected:
friend class Executor;
@ -255,6 +260,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
std::unordered_map<AnfNodePtr, std::string> partial_target_map_;
std::shared_ptr<Context> context_;
CallBackFunc summary_callback_;
static GraphId graph_sum_;

View File

@ -189,16 +189,19 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
MS_EXCEPTION_IF_NULL(address);
TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index);
TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index);
tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, index);
ShapeVector temp_shape;
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
if (is_internal_output) {
kernel_graph->AddInternalOutputTensor(node, index, tensor);
auto shape = AnfAlgo::GetOutputInferShape(node, index);
ShapeVector temp_shape;
tensor::TensorPtr tensor;
bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
if (is_internal_output) {
tensor = kernel_graph->GetInternalOutputTensor(node, index);
if (tensor == nullptr) {
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
}
kernel_graph->AddInternalOutputTensor(node, index, tensor);
} else {
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
}
tensor->set_device_address(address);
if (bound_addresses_.find(address) == bound_addresses_.end()) {