!28193 Valuetuple support monad.

Merge pull request !28193 from gaoyong10/runtime_second12
This commit is contained in:
i-robot 2021-12-25 11:12:13 +00:00 committed by Gitee
commit 4a7d08a7d6
2 changed files with 50 additions and 13 deletions

View File

@ -146,7 +146,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
type = kNumberTypeInt32;
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
} else {
MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
std::string error_info = "Invalid value:" + value->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto type_size = GetTypeByte(TypeIdToType(type));

View File

@ -34,9 +34,28 @@ bool IsPartial(const AnfNodePtr &node) {
}
// Check if node is a value node need to create a device tensor.
bool IsFrontValueNode(const AnfNodePtr &node) {
bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
const auto &node = node_with_index.first;
size_t index = node_with_index.second;
MS_EXCEPTION_IF_NULL(node);
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node));
if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
return false;
}
if (!IsValueNode<ValueTuple>(node)) {
return !HasAbstractMonad(node);
}
const auto &abstract = node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.size() <= index) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString();
}
MS_EXCEPTION_IF_NULL(sub_abstracts[index]);
return !sub_abstracts[index]->isa<abstract::AbstractMonad>();
}
// Get funcgraph in partial structure.
@ -599,7 +618,6 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
}
// 5 Other.
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
const auto &get_item_cnode = real_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(get_item_cnode);
@ -625,9 +643,23 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
}
} else {
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(real_node, i);
return results;
}
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
if (output_num == 1) {
results.emplace_back(real_node, 0);
return results;
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t index = 0;
for (const auto &sub_abstract : sub_abstracts) {
MS_EXCEPTION_IF_NULL(sub_abstract);
if (!sub_abstract->isa<abstract::AbstractMonad>()) {
results.emplace_back(real_node, index++);
}
}
return results;
@ -1011,8 +1043,13 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<AnfNode
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
for (auto sub_graph : sub_graphs) {
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
func_graph_to_device_contexts_[sub_graph] =
std::vector<const DeviceContext *>(sub_graph->parameters().size(), default_context);
size_t output_num = 0;
for (const auto &parameter : sub_graph->parameters()) {
const auto &abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
output_num += AnfAlgo::GetOutputNumByAbstract(abstract);
}
func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context);
}
}
}
@ -1277,8 +1314,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
const auto &real_parameter = real_parameter_with_index.first;
if (!IsFrontValueNode(real_parameter)) {
if (!IsFrontValueNode(real_parameter_with_index)) {
continue;
}
@ -1299,7 +1335,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_node = front_to_backend_parameters.first.first;
MS_EXCEPTION_IF_NULL(front_node);
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) {
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) {
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
const auto &device_context = front_to_backend_parameters.second.begin()->second;
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
@ -1323,7 +1359,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
}
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
const auto &input_with_index = input_with_indexs[i];
if (IsFrontValueNode(input_with_index.first) &&
if (IsFrontValueNode(input_with_index) &&
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
front_value_nodes_.emplace(input_with_index, iter->second[i]);