!27790 Fix if and while bugs in closure.

Merge pull request !27790 from gaoyong10/runtime_second12
i-robot 2021-12-17 03:08:15 +00:00 committed by Gitee
commit fa26ec66c8
7 changed files with 202 additions and 20 deletions

@@ -202,6 +202,14 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
if (AnfUtils::IsRealCNodeKernel(node)) {
outputs_num = AnfAlgo::GetOutputTensorNum(node);
}
// If the node is a call, the number of outputs should be obtained from the abstract.
if (AnfAlgo::IsCallNode(node)) {
auto abstract = node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
}
// The output may be a tuple of the node, so all the outputs of the node need to be visited.
for (size_t i = 0; i < outputs_num; ++i) {
// A scene like tupleGetItem + depend + makeTuple can be handled correctly in VisitKernelWithReturnType.
@@ -371,11 +379,21 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
MS_EXCEPTION_IF_NULL(node_abstract);
size_t result = 0;
if (node_abstract->isa<abstract::AbstractCSRTensor>()) {
auto csr_tensor_abstract = node_abstract->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_tensor_abstract);
result += GetOutputNumByAbstract(csr_tensor_abstract->indptr());
result += GetOutputNumByAbstract(csr_tensor_abstract->indices());
result += GetOutputNumByAbstract(csr_tensor_abstract->values());
result += GetOutputNumByAbstract(csr_tensor_abstract->dense_shape());
return result;
}
if (!node_abstract->isa<abstract::AbstractTuple>()) {
return 1;
}
size_t result = 0;
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
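
The new AbstractCSRTensor branch makes GetOutputNumByAbstract count a CSR tensor's flattened outputs by recursing into its indptr, indices, values and dense_shape abstracts, just as the existing tuple branch sums over its elements. A minimal self-contained sketch of that recursive counting, using a simplified, hypothetical stand-in type (not the real MindSpore abstract classes):

#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical simplified stand-in for a node abstract: a single value, a
// tuple of sub-abstracts, or a CSR tensor whose four components (indptr,
// indices, values, dense_shape) are stored in `elements`.
struct Abstract {
  enum class Kind { kSingle, kTuple, kCsrTensor };
  Kind kind = Kind::kSingle;
  std::vector<std::shared_ptr<Abstract>> elements;
};

// Recursive counting in the spirit of GetOutputNumByAbstract: tuples and CSR
// tensors contribute the sum over their components, anything else counts as 1.
std::size_t GetOutputNum(const std::shared_ptr<Abstract> &abs) {
  if (abs->kind == Abstract::Kind::kSingle) {
    return 1;
  }
  std::size_t result = 0;
  for (const auto &sub : abs->elements) {
    result += GetOutputNum(sub);
  }
  return result;
}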

@@ -134,14 +134,14 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
std::vector<DeviceTensor *> new_device_tensors;
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
auto input_device_tensor = input_device_tensors_[i];
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!is_need_copy_device_tensors_[i]) {
new_device_tensors.emplace_back(input_device_tensor);
continue;
}
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(

@@ -142,6 +142,9 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
} else if (value->isa<Int64Imm>()) {
type = kNumberTypeInt64;
(reinterpret_cast<int64_t *>(host_addr.get()))[0] = GetValue<int64_t>(value);
} else if (value->isa<Int32Imm>()) {
type = kNumberTypeInt32;
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
} else {
MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
}
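
The added Int32Imm branch mirrors the existing Int64Imm handling: the scalar is written into the freshly allocated host buffer before being synced to the device address. A tiny illustrative sketch of the same packing pattern, using a hypothetical helper rather than the MindSpore API:

#include <cstddef>
#include <cstdint>
#include <memory>

// Pack a 32-bit scalar into a host buffer, as the Int32Imm branch does before
// the value is copied to the device address. (Illustrative helper only.)
std::unique_ptr<uint8_t[]> PackInt32Scalar(int32_t value, std::size_t *size) {
  *size = sizeof(int32_t);
  auto host_addr = std::make_unique<uint8_t[]>(*size);
  (reinterpret_cast<int32_t *>(host_addr.get()))[0] = value;
  return host_addr;
}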

@@ -16,6 +16,7 @@
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/actor_common.h"
#include "utils/convert_utils.h"
#include "abstract/utils.h"
#include "ir/tensor.h"
@@ -31,6 +32,12 @@ bool IsPartial(const AnfNodePtr &node) {
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial);
}
// Check whether the node is a value node that needs a device tensor created for it.
bool IsFrontValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node));
}
// Get funcgraph in partial structure.
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
// the funcgraph in the return value of the inner call needs to be returned.
@@ -195,6 +202,64 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) {
return {};
}
// Fetch all the output indexes covered by the sub-abstract of the abstract.
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
AbstractBasePtr dst_abstract = abstract;
size_t pre_abstract_num = 0;
// Fetch the dest abstract by index, and the number of outputs before the dest abstract.
if (abstract->isa<abstract::AbstractCSRTensor>()) {
auto csr_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_abs);
switch (index) {
case kCsrTensorIndPtrIndex:
dst_abstract = csr_abs->indptr();
pre_abstract_num = kCsrTensorIndPtrIndex;
break;
case kCsrTensorIndicesIndex:
dst_abstract = csr_abs->indices();
pre_abstract_num = kCsrTensorIndicesIndex;
break;
case kCsrTensorValuesIndex:
dst_abstract = csr_abs->values();
pre_abstract_num = kCsrTensorValuesIndex;
break;
case kCsrTensorDenseShapeIndex:
dst_abstract = csr_abs->dense_shape();
pre_abstract_num = kCsrTensorDenseShapeIndex;
break;
default:
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
break;
}
} else if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.size() <= index) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
}
for (size_t i = 0; i < index; ++i) {
pre_abstract_num += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
}
dst_abstract = sub_abstracts[index];
} else {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
}
MS_EXCEPTION_IF_NULL(dst_abstract);
// Fetch the real output indexes.
size_t output_num = AnfAlgo::GetOutputNumByAbstract(dst_abstract);
std::set<size_t> real_indexs;
for (size_t i = pre_abstract_num; i < output_num + pre_abstract_num; ++i) {
real_indexs.emplace(i);
}
return real_indexs;
}
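
FetchRealIndexByAbstract first locates the sub-abstract selected by index, counts the flattened outputs contributed by the sub-abstracts before it, and then returns the contiguous range of real output indexes the selected sub-abstract occupies. A small self-contained sketch of that arithmetic, with a hypothetical helper that works on plain per-component counts instead of abstracts:

#include <cstddef>
#include <numeric>
#include <set>
#include <vector>

// Given the flattened output count of each sub-abstract and the position of
// the selected one, return the set of real output indexes it covers.
std::set<std::size_t> RealIndexRange(const std::vector<std::size_t> &sub_output_nums, std::size_t index) {
  // Outputs contributed by the sub-abstracts before the selected one.
  std::size_t pre_num = std::accumulate(sub_output_nums.begin(),
                                        sub_output_nums.begin() + static_cast<std::ptrdiff_t>(index),
                                        static_cast<std::size_t>(0));
  std::set<std::size_t> real_indexes;
  for (std::size_t i = pre_num; i < pre_num + sub_output_nums[index]; ++i) {
    real_indexes.insert(i);
  }
  return real_indexes;
}

// Example: component counts {1, 1, 1, 2} and index 3 (dense_shape) give {3, 4}.
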
// Get all the real parameters corresponding to node.
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *real_parameters,
std::set<KernelWithIndex> *invalid_call_nodes,
@@ -384,7 +449,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
MS_EXCEPTION_IF_NULL(front_node);
const auto &node_value = front_node->cast<ValueNodePtr>()->value();
if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>()) && (!node_value->isa<BoolImm>())) {
if (node_value->isa<FuncGraph>() || node_value->isa<Primitive>()) {
return;
}
@@ -468,6 +533,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
const auto &node_with_index =
AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
auto real_node = node_with_index.first;
size_t real_index = node_with_index.second;
MS_EXCEPTION_IF_NULL(real_node);
std::vector<KernelWithIndex> results;
// 2. MakeTuple.
@@ -481,13 +547,51 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
return results;
}
// 3. One output node.
// 3. kPrimMakeCSRTensor.
if (IsCsrNode(real_node)) {
const auto &cnode = real_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.size() <= kMakeCSRTensorInputStartPos) {
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
}
// Fetch the output index.
const auto &prim_node = inputs[0]->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_node);
const auto &prim_value = prim_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim_value);
const auto &src_node = inputs[kMakeCSRTensorInputStartPos];
MS_EXCEPTION_IF_NULL(src_node);
const auto iter = sparse_attr_map.find(prim_value->name());
// Csr node from the make csr tensor node.
if (AnfAlgo::CheckPrimitiveType(src_node, prim::kPrimMakeCSRTensor)) {
const auto &make_csr_tensor_cnode = src_node->cast<CNodePtr>();
const auto &csr_tensor_inputs = make_csr_tensor_cnode->inputs();
if (csr_tensor_inputs.size() <= kMakeCSRTensorInputNum) {
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
}
const auto &sub_results = FetchInputNodeByNode(csr_tensor_inputs[iter->second + kMakeCSRTensorInputStartPos]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
} else {
// Csr node from parameter or call node.
auto abstract = src_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
auto real_indexs = FetchRealIndexByAbstract(abstract, iter->second);
(void)std::transform(real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
[&src_node](const auto &index) { return KernelWithIndex(src_node, index); });
}
return results;
}
// 4. One output node.
const auto &abstract = real_node->abstract();
if (abstract == nullptr || (!abstract->isa<abstract::AbstractTuple>())) {
if (abstract == nullptr ||
((!abstract->isa<abstract::AbstractTuple>()) && (!abstract->isa<abstract::AbstractCSRTensor>()))) {
if (abstract == nullptr) {
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
}
return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)};
return {AnfAlgo::VisitKernelWithReturnType(real_node, real_index)};
}
// 4. Abstract is Tuple.
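
For a CSR getter node (CSRTensorGetIndptr/Indices/Values/DenseShape) the new branch uses sparse_attr_map to find which component the getter extracts: if the source is a MakeCSRTensor node it recurses into the corresponding MakeCSRTensor input, otherwise it maps the component to the flattened output indexes of the parameter or call node via FetchRealIndexByAbstract. The contents of sparse_attr_map are not shown in this diff, so the sketch below is a hypothetical stand-in that assumes the mapping follows the kCsrTensor*Index constants:

#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the sparse_attr_map lookup: map a CSR getter
// primitive name to the position of the component it extracts. That position
// selects both the MakeCSRTensor input (offset by kMakeCSRTensorInputStartPos)
// and the sub-abstract index used by FetchRealIndexByAbstract.
std::size_t CsrGetterToComponentPos(const std::string &prim_name) {
  static const std::map<std::string, std::size_t> getter_to_pos = {
      {"CSRTensorGetIndptr", 0},
      {"CSRTensorGetIndices", 1},
      {"CSRTensorGetValues", 2},
      {"CSRTensorGetDenseShape", 3},
  };
  auto iter = getter_to_pos.find(prim_name);
  if (iter == getter_to_pos.end()) {
    throw std::out_of_range("not a CSR getter: " + prim_name);
  }
  return iter->second;
}
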
@@ -528,6 +632,17 @@ bool HasAbstractRef(const AnfNodePtr &node) {
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
}
bool IsCsrNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
return AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndptr) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndices) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetValues) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape);
}
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
@@ -589,6 +704,22 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<abstract::AbstractCSRTensor>()) {
auto csr_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_abs);
if (index == kCsrTensorIndPtrIndex) {
return csr_abs->indptr();
} else if (index == kCsrTensorIndicesIndex) {
return csr_abs->indices();
} else if (index == kCsrTensorValuesIndex) {
return csr_abs->values();
} else if (index >= kCsrTensorDenseShapeIndex) {
return FetchAbstractByIndex(csr_abs->dense_shape(), index - kCsrTensorDenseShapeIndex);
} else {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
}
}
if (!abstract->isa<abstract::AbstractTuple>()) {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
@@ -906,13 +1037,20 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP
const auto &cnode = control_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() - kCallInputStartPos > iter->second.size()) {
MS_LOG(EXCEPTION) << "Invalid input size:" << inputs.size() << " context size:" << iter->second.size()
size_t call_input_num = 0;
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
const auto &abstract = inputs[i]->abstract();
MS_EXCEPTION_IF_NULL(abstract);
call_input_num += AnfAlgo::GetOutputNumByAbstract(abstract);
}
if (call_input_num > iter->second.size()) {
MS_LOG(EXCEPTION) << "Invalid input size:" << call_input_num << " context size:" << iter->second.size()
<< "for funcgraph" << func_graph->ToString() << " for call node:" << cnode->DebugString();
}
// Fetch the device contexts for the real parameters on the call node.
for (size_t i = iter->second.size() - inputs.size() + kCallInputStartPos; i < iter->second.size(); ++i) {
for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) {
if (i >= iter->second.size()) {
MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString()
<< " device context size:" << iter->second.size()
@@ -1076,7 +1214,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
const auto &real_parameter = real_parameter_with_index.first;
if (!real_parameter->isa<ValueNode>() || IsValueNode<FuncGraph>(real_parameter)) {
if (!IsFrontValueNode(real_parameter)) {
continue;
}
@@ -1097,7 +1235,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_node = front_to_backend_parameters.first.first;
MS_EXCEPTION_IF_NULL(front_node);
if (front_node->isa<ValueNode>() && (!front_to_backend_parameters.second.empty())) {
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) {
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
const auto &device_context = front_to_backend_parameters.second.begin()->second;
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
@@ -1107,18 +1245,21 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
// Create device tensors for those value nodes which are directly returned by a return node.
for (const auto &control_node : control_nodes) {
if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) {
continue;
}
auto input_with_indexs = FetchInputNodeByCNode(control_node);
auto iter = control_node_to_device_contexts_.find(control_node);
if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) {
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString();
if (iter == control_node_to_device_contexts_.end() || iter->second.size() < input_with_indexs.size()) {
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString()
<< " need:" << input_with_indexs.size() << " current:"
<< (iter == control_node_to_device_contexts_.end() ? "null"
: std::to_string(iter->second.size()));
}
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
const auto &input_with_index = input_with_indexs[i];
if (input_with_index.first->isa<ValueNode>() && (!IsValueNode<FuncGraph>(input_with_index.first)) &&
if (IsFrontValueNode(input_with_index.first) &&
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
front_value_nodes_.emplace(input_with_index, iter->second[i]);

@@ -56,6 +56,13 @@ constexpr size_t kCNodeInputStartPos = 1;
constexpr size_t kReturnInputPos = 1;
constexpr size_t kSingleControlNode = 1;
constexpr size_t kCsrTensorIndPtrIndex = 0;
constexpr size_t kCsrTensorIndicesIndex = 1;
constexpr size_t kCsrTensorValuesIndex = 2;
constexpr size_t kCsrTensorDenseShapeIndex = 3;
constexpr size_t kMakeCSRTensorInputStartPos = 1;
constexpr size_t kMakeCSRTensorInputNum = 4;
const char kEntranceActorNameSuffix[] = "_EntranceActor";
const char kExitActorNameSuffix[] = "_ExitActor";
const char kStackActorNameSuffix[] = "_StackActor";
@@ -87,6 +94,8 @@ using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
// it is determined whether it is a weight.
bool HasAbstractRef(const AnfNodePtr &node);
// Check whether the node is a csr node.
bool IsCsrNode(const AnfNodePtr &node);
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph);

@@ -257,7 +257,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
std::vector<const DeviceContext *> device_contexts;
for (const auto &node_with_context : kernel_graph_group_info->front_output_nodes_) {
if (HasAbstractMonad(node_with_context.first.first)) {
if (HasAbstractMonad(node_with_context.first.first) || IsCsrNode(node_with_context.first.first)) {
continue;
}
// Collect inputs of exit actor.
@@ -666,7 +666,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
branch_id);
}
}
if (abstract->isa<abstract::AbstractFunction>()) {
if (real_abstract->isa<abstract::AbstractFunction>()) {
to_actor->input_partials_num_++;
} else {
to_actor->input_datas_num_++;

@@ -954,13 +954,24 @@ BaseRef MindRTBackend::ConstructOutputByAbstract(const abstract::AbstractBasePtr
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num
<< " total:" << output_tensors.size();
}
VectorRef outputs;
if (abstract->isa<abstract::AbstractCSRTensor>()) {
auto csr_tensor_abstract = abstract->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_tensor_abstract);
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indptr(), output_tensors, output_position));
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indices(), output_tensors, output_position));
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->values(), output_tensors, output_position));
outputs.emplace_back(
ConstructOutputByAbstract(csr_tensor_abstract->dense_shape(), output_tensors, output_position));
return outputs;
}
if (!abstract->isa<abstract::AbstractTuple>()) {
(*output_position)++;
return output_tensors[(*output_position) - 1];
}
VectorRef outputs;
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
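
The new AbstractCSRTensor branch in ConstructOutputByAbstract rebuilds a CSR tensor output as a nested VectorRef of its four components, each recursive call consuming output tensors from the flat list and advancing output_position. A simplified sketch of that consumption order, assuming one tensor per component and hypothetical plain types (in the real code dense_shape may itself recurse and consume several):

#include <cstddef>
#include <vector>

// Consume the four CSR components (indptr, indices, values, dense_shape) from
// the flat output tensor list in order, advancing the shared output position.
std::vector<int> ConsumeCsrOutputs(const std::vector<int> &output_tensors, std::size_t *output_position) {
  std::vector<int> csr_outputs;
  for (int component = 0; component < 4; ++component) {
    csr_outputs.push_back(output_tensors.at((*output_position)++));
  }
  return csr_outputs;
}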