forked from mindspore-Ecosystem/mindspore
!28193 Valuetuple support monad.
Merge pull request !28193 from gaoyong10/runtime_second12
This commit is contained in:
commit
4a7d08a7d6
|
@ -146,7 +146,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
|
|||
type = kNumberTypeInt32;
|
||||
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
|
||||
std::string error_info = "Invalid value:" + value->ToString();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
auto type_size = GetTypeByte(TypeIdToType(type));
|
||||
|
|
|
@ -34,9 +34,28 @@ bool IsPartial(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
// Check if node is a value node need to create a device tensor.
|
||||
bool IsFrontValueNode(const AnfNodePtr &node) {
|
||||
bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
|
||||
const auto &node = node_with_index.first;
|
||||
size_t index = node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node));
|
||||
if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsValueNode<ValueTuple>(node)) {
|
||||
return !HasAbstractMonad(node);
|
||||
}
|
||||
|
||||
const auto &abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.size() <= index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(sub_abstracts[index]);
|
||||
return !sub_abstracts[index]->isa<abstract::AbstractMonad>();
|
||||
}
|
||||
|
||||
// Get funcgraph in partial structure.
|
||||
|
@ -599,7 +618,6 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
// 5 Other.
|
||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
|
||||
const auto &get_item_cnode = real_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
||||
|
@ -625,9 +643,23 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
results.emplace_back(real_node, i);
|
||||
return results;
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
if (output_num == 1) {
|
||||
results.emplace_back(real_node, 0);
|
||||
return results;
|
||||
}
|
||||
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
size_t index = 0;
|
||||
for (const auto &sub_abstract : sub_abstracts) {
|
||||
MS_EXCEPTION_IF_NULL(sub_abstract);
|
||||
if (!sub_abstract->isa<abstract::AbstractMonad>()) {
|
||||
results.emplace_back(real_node, index++);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
|
@ -1011,8 +1043,13 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<AnfNode
|
|||
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
|
||||
for (auto sub_graph : sub_graphs) {
|
||||
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
|
||||
func_graph_to_device_contexts_[sub_graph] =
|
||||
std::vector<const DeviceContext *>(sub_graph->parameters().size(), default_context);
|
||||
size_t output_num = 0;
|
||||
for (const auto ¶meter : sub_graph->parameters()) {
|
||||
const auto &abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
output_num += AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
}
|
||||
func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1277,8 +1314,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
|
||||
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
|
||||
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
|
||||
const auto &real_parameter = real_parameter_with_index.first;
|
||||
if (!IsFrontValueNode(real_parameter)) {
|
||||
if (!IsFrontValueNode(real_parameter_with_index)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1299,7 +1335,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
|
||||
const auto &front_node = front_to_backend_parameters.first.first;
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) {
|
||||
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) {
|
||||
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
|
||||
const auto &device_context = front_to_backend_parameters.second.begin()->second;
|
||||
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
|
||||
|
@ -1323,7 +1359,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
}
|
||||
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
|
||||
const auto &input_with_index = input_with_indexs[i];
|
||||
if (IsFrontValueNode(input_with_index.first) &&
|
||||
if (IsFrontValueNode(input_with_index) &&
|
||||
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
|
||||
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
|
||||
front_value_nodes_.emplace(input_with_index, iter->second[i]);
|
||||
|
|
Loading…
Reference in New Issue