forked from mindspore-Ecosystem/mindspore
add partial internal parameters
This commit is contained in:
parent
c0690309aa
commit
abf295560c
|
@ -160,7 +160,7 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
|
|||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
|
||||
auto &node = node_output_pair.first;
|
||||
auto &output_index = node_output_pair.second;
|
||||
int output_index = SizeToInt(node_output_pair.second);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
|
||||
|
@ -172,25 +172,24 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
|
|||
if (type_id == kTypeUnknown) {
|
||||
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
|
||||
}
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
std::vector<int64_t> temp_shape;
|
||||
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
|
||||
temp_shape.emplace_back(1);
|
||||
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
||||
tensor->set_sync_status(kNoNeedSync);
|
||||
} else {
|
||||
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
|
||||
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
|
||||
tensor::TensorPtr tensor;
|
||||
bool is_internal_output = graph->IsInternalOutput(node, output_index);
|
||||
if (is_internal_output) {
|
||||
tensor = graph->GetInternalOutputTensor(node, output_index);
|
||||
if (tensor == nullptr) {
|
||||
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
|
||||
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
|
||||
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
bool is_internal_output = graph->IsInternalOutput(node, output_index);
|
||||
if (is_internal_output) {
|
||||
graph->AddInternalOutputTensor(node, output_index, tensor);
|
||||
}
|
||||
graph->AddInternalOutputTensor(node, output_index, tensor);
|
||||
}
|
||||
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
||||
} else {
|
||||
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
}
|
||||
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
||||
if (is_internal_output) {
|
||||
tensor->set_sync_status(kNoNeedSync);
|
||||
} else {
|
||||
// In PyNative mode, data is copied to the host only when the user wants to print it
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -682,16 +681,20 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
// If the parameter's Python parameter already has a backend parameter, reuse the existing one
|
||||
if (param_value != nullptr) {
|
||||
new_parameter = param_value->parameter();
|
||||
if (new_parameter == nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
|
||||
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
param_value->set_parameter(new_parameter);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
if (new_parameter == nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
|
||||
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
}
|
||||
|
||||
auto input_node_iter = partial_parameters_map_.find(anf);
|
||||
if (input_node_iter != partial_parameters_map_.end()) {
|
||||
InitInternalOutputParameter(input_node_iter->second, new_parameter);
|
||||
}
|
||||
|
||||
if (param_value != nullptr) {
|
||||
param_value->set_parameter(new_parameter);
|
||||
}
|
||||
}
|
||||
new_parameter->IncreaseUsedGraphCount();
|
||||
graph_inputs->push_back(new_parameter);
|
||||
valid_inputs->push_back(true);
|
||||
|
@ -1771,10 +1774,11 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
|
|||
|
||||
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const AnfNodePtr &front_node) {
|
||||
MS_EXCEPTION_IF_NULL(front_func_graph_manager);
|
||||
auto &users = front_func_graph_manager->node_users()[front_node];
|
||||
std::vector<AnfNodePtr> result;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend)) {
|
||||
auto depend_cnode = user.first->cast<CNodePtr>();
|
||||
if (depend_cnode == nullptr) {
|
||||
continue;
|
||||
|
@ -1784,9 +1788,12 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
|
|||
}
|
||||
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
|
||||
result.insert(result.end(), res.begin(), res.end());
|
||||
continue;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
|
||||
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
|
||||
(void)result.insert(result.end(), res.begin(), res.end());
|
||||
} else {
|
||||
(void)result.emplace_back(user.first);
|
||||
}
|
||||
(void)result.emplace_back(user.first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -1812,10 +1819,54 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
|
||||
const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const std::shared_ptr<KernelGraph> &backend_graph) {
|
||||
constexpr auto kMixTarget = "MixTarget";
|
||||
constexpr auto kNoTarget = "NoTarget";
|
||||
std::string SessionBasic::AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const AnfNodePtr &partial_node) {
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
auto iter = partial_target_map_.find(partial_node);
|
||||
if (iter != partial_target_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
auto partial_cnode = partial_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
|
||||
MS_EXCEPTION_IF_NULL(partial_graph);
|
||||
auto parameters = partial_graph->parameters();
|
||||
auto partial_inputs = partial_cnode->inputs();
|
||||
if (parameters.size() + 2 != partial_inputs.size()) {
|
||||
return kMixTarget;
|
||||
}
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
partial_parameters_map_[parameters[i]] = partial_inputs[2 + i];
|
||||
}
|
||||
auto graph_nodes = TopoSort(partial_graph->get_return());
|
||||
std::string graph_target = kNoTarget;
|
||||
for (auto &node : graph_nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
std::string cur_target = GetCNodeTarget(node);
|
||||
if (graph_target == kNoTarget) {
|
||||
graph_target = cur_target;
|
||||
}
|
||||
if (graph_target != cur_target) {
|
||||
graph_target = kMixTarget;
|
||||
break;
|
||||
}
|
||||
}
|
||||
(void)partial_target_map_.insert({partial_node, graph_target});
|
||||
return graph_target;
|
||||
}
|
||||
|
||||
void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
|
||||
const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const std::shared_ptr<KernelGraph> &backend_graph) {
|
||||
auto front_node = GetSupportedInternalNode(input_front_node);
|
||||
if (front_node == nullptr) {
|
||||
return;
|
||||
|
@ -1839,7 +1890,14 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &
|
|||
}
|
||||
if (internal_output) {
|
||||
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
|
||||
for (auto user : users) {
|
||||
for (auto &user : users) {
|
||||
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial)) {
|
||||
auto partial_target = AddPartialParametersMap(front_func_graph_manager, user);
|
||||
if (partial_target != kNoTarget && partial_target != kernel_target) {
|
||||
unique_target = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!CNodeFirstInputIsPrimitive(user)) {
|
||||
internal_output = false;
|
||||
break;
|
||||
|
@ -1859,7 +1917,6 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &
|
|||
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -1868,7 +1925,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_LOG(INFO) << "Output:" << output->DebugString();
|
||||
}
|
||||
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
||||
if (backend_anf != nullptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
|
|
@ -156,6 +156,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
|
||||
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
|
||||
const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const std::shared_ptr<KernelGraph> &backend_graph);
|
||||
std::string AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const AnfNodePtr &partial_node);
|
||||
|
||||
protected:
|
||||
friend class Executor;
|
||||
|
@ -255,6 +260,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
std::unordered_map<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
|
||||
std::unordered_map<AnfNodePtr, std::string> partial_target_map_;
|
||||
std::shared_ptr<Context> context_;
|
||||
CallBackFunc summary_callback_;
|
||||
static GraphId graph_sum_;
|
||||
|
|
|
@ -189,16 +189,19 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
|
|||
MS_EXCEPTION_IF_NULL(address);
|
||||
TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index);
|
||||
TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index);
|
||||
tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index);
|
||||
if (tensor == nullptr) {
|
||||
auto shape = AnfAlgo::GetOutputInferShape(node, index);
|
||||
ShapeVector temp_shape;
|
||||
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
|
||||
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
|
||||
bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
|
||||
if (is_internal_output) {
|
||||
kernel_graph->AddInternalOutputTensor(node, index, tensor);
|
||||
auto shape = AnfAlgo::GetOutputInferShape(node, index);
|
||||
ShapeVector temp_shape;
|
||||
tensor::TensorPtr tensor;
|
||||
bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
|
||||
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
|
||||
if (is_internal_output) {
|
||||
tensor = kernel_graph->GetInternalOutputTensor(node, index);
|
||||
if (tensor == nullptr) {
|
||||
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
|
||||
}
|
||||
kernel_graph->AddInternalOutputTensor(node, index, tensor);
|
||||
} else {
|
||||
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
|
||||
}
|
||||
tensor->set_device_address(address);
|
||||
if (bound_addresses_.find(address) == bound_addresses_.end()) {
|
||||
|
|
Loading…
Reference in New Issue