forked from mindspore-Ecosystem/mindspore
!28193 Valuetuple support monad.
Merge pull request !28193 from gaoyong10/runtime_second12
This commit is contained in:
commit
4a7d08a7d6
|
@ -146,7 +146,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
|
||||||
type = kNumberTypeInt32;
|
type = kNumberTypeInt32;
|
||||||
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
|
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
|
std::string error_info = "Invalid value:" + value->ToString();
|
||||||
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type_size = GetTypeByte(TypeIdToType(type));
|
auto type_size = GetTypeByte(TypeIdToType(type));
|
||||||
|
|
|
@ -34,9 +34,28 @@ bool IsPartial(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if node is a value node need to create a device tensor.
|
// Check if node is a value node need to create a device tensor.
|
||||||
bool IsFrontValueNode(const AnfNodePtr &node) {
|
bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
|
||||||
|
const auto &node = node_with_index.first;
|
||||||
|
size_t index = node_with_index.second;
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node));
|
if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!IsValueNode<ValueTuple>(node)) {
|
||||||
|
return !HasAbstractMonad(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &abstract = node->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
|
const auto &sub_abstracts = tuple_abstract->elements();
|
||||||
|
if (sub_abstracts.size() <= index) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString();
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(sub_abstracts[index]);
|
||||||
|
return !sub_abstracts[index]->isa<abstract::AbstractMonad>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get funcgraph in partial structure.
|
// Get funcgraph in partial structure.
|
||||||
|
@ -599,7 +618,6 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5 Other.
|
// 5 Other.
|
||||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
|
||||||
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
|
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
|
||||||
const auto &get_item_cnode = real_node->cast<CNodePtr>();
|
const auto &get_item_cnode = real_node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
||||||
|
@ -625,9 +643,23 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||||
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||||
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
||||||
}
|
}
|
||||||
} else {
|
return results;
|
||||||
for (size_t i = 0; i < output_num; ++i) {
|
}
|
||||||
results.emplace_back(real_node, i);
|
|
||||||
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||||
|
if (output_num == 1) {
|
||||||
|
results.emplace_back(real_node, 0);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
|
const auto &sub_abstracts = tuple_abstract->elements();
|
||||||
|
size_t index = 0;
|
||||||
|
for (const auto &sub_abstract : sub_abstracts) {
|
||||||
|
MS_EXCEPTION_IF_NULL(sub_abstract);
|
||||||
|
if (!sub_abstract->isa<abstract::AbstractMonad>()) {
|
||||||
|
results.emplace_back(real_node, index++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return results;
|
return results;
|
||||||
|
@ -1011,8 +1043,13 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<AnfNode
|
||||||
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
|
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
|
||||||
for (auto sub_graph : sub_graphs) {
|
for (auto sub_graph : sub_graphs) {
|
||||||
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
|
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
|
||||||
func_graph_to_device_contexts_[sub_graph] =
|
size_t output_num = 0;
|
||||||
std::vector<const DeviceContext *>(sub_graph->parameters().size(), default_context);
|
for (const auto ¶meter : sub_graph->parameters()) {
|
||||||
|
const auto &abstract = parameter->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
output_num += AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||||
|
}
|
||||||
|
func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1277,8 +1314,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
||||||
|
|
||||||
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
|
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
|
||||||
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
|
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
|
||||||
const auto &real_parameter = real_parameter_with_index.first;
|
if (!IsFrontValueNode(real_parameter_with_index)) {
|
||||||
if (!IsFrontValueNode(real_parameter)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1299,7 +1335,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
||||||
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
|
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
|
||||||
const auto &front_node = front_to_backend_parameters.first.first;
|
const auto &front_node = front_to_backend_parameters.first.first;
|
||||||
MS_EXCEPTION_IF_NULL(front_node);
|
MS_EXCEPTION_IF_NULL(front_node);
|
||||||
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) {
|
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) {
|
||||||
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
|
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
|
||||||
const auto &device_context = front_to_backend_parameters.second.begin()->second;
|
const auto &device_context = front_to_backend_parameters.second.begin()->second;
|
||||||
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
|
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
|
||||||
|
@ -1323,7 +1359,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
|
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
|
||||||
const auto &input_with_index = input_with_indexs[i];
|
const auto &input_with_index = input_with_indexs[i];
|
||||||
if (IsFrontValueNode(input_with_index.first) &&
|
if (IsFrontValueNode(input_with_index) &&
|
||||||
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
|
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
|
||||||
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
|
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
|
||||||
front_value_nodes_.emplace(input_with_index, iter->second[i]);
|
front_value_nodes_.emplace(input_with_index, iter->second[i]);
|
||||||
|
|
Loading…
Reference in New Issue