forked from mindspore-Ecosystem/mindspore
!27790 Fix if and while bugs in closure.
Merge pull request !27790 from gaoyong10/runtime_second12
This commit is contained in:
commit
fa26ec66c8
|
@ -202,6 +202,14 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
if (AnfUtils::IsRealCNodeKernel(node)) {
|
||||
outputs_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
}
|
||||
|
||||
// If the node is a call, the outputs num should get from the abstract.
|
||||
if (AnfAlgo::IsCallNode(node)) {
|
||||
auto abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
}
|
||||
|
||||
// The output may be the tuple of node, so need visit all the outputs of node.
|
||||
for (size_t i = 0; i < outputs_num; ++i) {
|
||||
// Maybe this scene: tupleGetItem + depend + makeTuple, can be done correctly in VisitKernelWithReturnType.
|
||||
|
@ -371,11 +379,21 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
|
|||
|
||||
size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
|
||||
MS_EXCEPTION_IF_NULL(node_abstract);
|
||||
size_t result = 0;
|
||||
if (node_abstract->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto csr_tensor_abstract = node_abstract->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor_abstract);
|
||||
result += GetOutputNumByAbstract(csr_tensor_abstract->indptr());
|
||||
result += GetOutputNumByAbstract(csr_tensor_abstract->indices());
|
||||
result += GetOutputNumByAbstract(csr_tensor_abstract->values());
|
||||
result += GetOutputNumByAbstract(csr_tensor_abstract->dense_shape());
|
||||
return result;
|
||||
}
|
||||
|
||||
if (!node_abstract->isa<abstract::AbstractTuple>()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t result = 0;
|
||||
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
|
|
|
@ -134,14 +134,14 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
|
|||
std::vector<DeviceTensor *> new_device_tensors;
|
||||
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
|
||||
auto input_device_tensor = input_device_tensors_[i];
|
||||
MS_EXCEPTION_IF_NULL(input_device_tensor);
|
||||
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
|
||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
if (!is_need_copy_device_tensors_[i]) {
|
||||
new_device_tensors.emplace_back(input_device_tensor);
|
||||
continue;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_device_tensor);
|
||||
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
|
||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
|
||||
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
|
||||
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
|
||||
|
|
|
@ -142,6 +142,9 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with
|
|||
} else if (value->isa<Int64Imm>()) {
|
||||
type = kNumberTypeInt64;
|
||||
(reinterpret_cast<int64_t *>(host_addr.get()))[0] = GetValue<int64_t>(value);
|
||||
} else if (value->isa<Int32Imm>()) {
|
||||
type = kNumberTypeInt32;
|
||||
(reinterpret_cast<int32_t *>(host_addr.get()))[0] = GetValue<int32_t>(value);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "runtime/framework/control_node_parser.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
|
@ -31,6 +32,12 @@ bool IsPartial(const AnfNodePtr &node) {
|
|||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial);
|
||||
}
|
||||
|
||||
// Check if node is a value node need to create a device tensor.
|
||||
bool IsFrontValueNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node));
|
||||
}
|
||||
|
||||
// Get funcgraph in partial structure.
|
||||
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
|
||||
// the funcgraph in the return value of the inner call needs to be returned.
|
||||
|
@ -195,6 +202,64 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) {
|
|||
return {};
|
||||
}
|
||||
|
||||
// Fetch all the output index in the sub-abstract of abstract.
|
||||
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
AbstractBasePtr dst_abstract = abstract;
|
||||
size_t pre_abstract_num = 0;
|
||||
|
||||
// Fetch the dest abstract by index, and the abstracts num before the dest abstract.
|
||||
if (abstract->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto csr_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_abs);
|
||||
switch (index) {
|
||||
case kCsrTensorIndPtrIndex:
|
||||
dst_abstract = csr_abs->indptr();
|
||||
pre_abstract_num = kCsrTensorIndPtrIndex;
|
||||
break;
|
||||
case kCsrTensorIndicesIndex:
|
||||
dst_abstract = csr_abs->indices();
|
||||
pre_abstract_num = kCsrTensorIndicesIndex;
|
||||
break;
|
||||
case kCsrTensorValuesIndex:
|
||||
dst_abstract = csr_abs->values();
|
||||
pre_abstract_num = kCsrTensorValuesIndex;
|
||||
break;
|
||||
case kCsrTensorDenseShapeIndex:
|
||||
dst_abstract = csr_abs->dense_shape();
|
||||
pre_abstract_num = kCsrTensorDenseShapeIndex;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
|
||||
break;
|
||||
}
|
||||
} else if (abstract->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.size() <= index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
for (size_t i = 0; i < index; ++i) {
|
||||
pre_abstract_num += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
|
||||
}
|
||||
dst_abstract = sub_abstracts[index];
|
||||
} else {
|
||||
if (index != 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dst_abstract);
|
||||
|
||||
// Fetch real output index.
|
||||
size_t ouput_num = AnfAlgo::GetOutputNumByAbstract(dst_abstract);
|
||||
std::set<size_t> real_indexs;
|
||||
for (size_t i = pre_abstract_num; i < ouput_num + pre_abstract_num; ++i) {
|
||||
real_indexs.emplace(i);
|
||||
}
|
||||
return real_indexs;
|
||||
}
|
||||
|
||||
// Get all the real parameters corresponding to node.
|
||||
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *real_parameters,
|
||||
std::set<KernelWithIndex> *invalid_call_nodes,
|
||||
|
@ -384,7 +449,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
|
|||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
|
||||
const auto &node_value = front_node->cast<ValueNodePtr>()->value();
|
||||
if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>()) && (!node_value->isa<BoolImm>())) {
|
||||
if (node_value->isa<FuncGraph>() || node_value->isa<Primitive>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -468,6 +533,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
const auto &node_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
|
||||
auto real_node = node_with_index.first;
|
||||
size_t real_index = node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
std::vector<KernelWithIndex> results;
|
||||
// 2. MakeTuple.
|
||||
|
@ -481,13 +547,51 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
return results;
|
||||
}
|
||||
|
||||
// 3. One output node.
|
||||
// 3. kPrimMakeCSRTensor.
|
||||
if (IsCsrNode(real_node)) {
|
||||
const auto &cnode = real_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kMakeCSRTensorInputStartPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
// Fetch output put index.
|
||||
const auto &prim_node = inputs[0]->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_node);
|
||||
const auto &prim_value = prim_node->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_value);
|
||||
const auto &src_node = inputs[kMakeCSRTensorInputStartPos];
|
||||
MS_EXCEPTION_IF_NULL(src_node);
|
||||
|
||||
const auto iter = sparse_attr_map.find(prim_value->name());
|
||||
// Csr node from the make csr tensor node.
|
||||
if (AnfAlgo::CheckPrimitiveType(src_node, prim::kPrimMakeCSRTensor)) {
|
||||
const auto &make_csr_tensor_cnode = src_node->cast<CNodePtr>();
|
||||
const auto &csr_tensor_inputs = make_csr_tensor_cnode->inputs();
|
||||
if (csr_tensor_inputs.size() <= kMakeCSRTensorInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
|
||||
}
|
||||
const auto &sub_results = FetchInputNodeByNode(csr_tensor_inputs[iter->second + kMakeCSRTensorInputStartPos]);
|
||||
results.insert(results.end(), sub_results.begin(), sub_results.end());
|
||||
} else {
|
||||
// Csr node from parameter or call node.
|
||||
auto abstract = src_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto real_indexs = FetchRealIndexByAbstract(abstract, iter->second);
|
||||
(void)std::transform(real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||
[&src_node](const auto &index) { return KernelWithIndex(src_node, index); });
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// 4. One output node.
|
||||
const auto &abstract = real_node->abstract();
|
||||
if (abstract == nullptr || (!abstract->isa<abstract::AbstractTuple>())) {
|
||||
if (abstract == nullptr ||
|
||||
((!abstract->isa<abstract::AbstractTuple>()) && (!abstract->isa<abstract::AbstractCSRTensor>()))) {
|
||||
if (abstract == nullptr) {
|
||||
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
|
||||
}
|
||||
return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)};
|
||||
return {AnfAlgo::VisitKernelWithReturnType(real_node, real_index)};
|
||||
}
|
||||
|
||||
// 4. Abstract is Tuple.
|
||||
|
@ -528,6 +632,17 @@ bool HasAbstractRef(const AnfNodePtr &node) {
|
|||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
bool IsCsrNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
return AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndptr) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndices) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetValues) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape);
|
||||
}
|
||||
|
||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
|
||||
|
@ -589,6 +704,22 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
|
|||
|
||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto csr_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_abs);
|
||||
if (index == kCsrTensorIndPtrIndex) {
|
||||
return csr_abs->indptr();
|
||||
} else if (index == kCsrTensorIndicesIndex) {
|
||||
return csr_abs->indices();
|
||||
} else if (index == kCsrTensorValuesIndex) {
|
||||
return csr_abs->values();
|
||||
} else if (index >= kCsrTensorDenseShapeIndex) {
|
||||
return FetchAbstractByIndex(csr_abs->dense_shape(), index - kCsrTensorDenseShapeIndex);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
if (!abstract->isa<abstract::AbstractTuple>()) {
|
||||
if (index != 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
|
@ -906,13 +1037,20 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP
|
|||
const auto &cnode = control_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() - kCallInputStartPos > iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input size:" << inputs.size() << " context size:" << iter->second.size()
|
||||
size_t call_input_num = 0;
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
const auto &abstract = inputs[i]->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
call_input_num += AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
}
|
||||
|
||||
if (call_input_num > iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input size:" << call_input_num << " context size:" << iter->second.size()
|
||||
<< "for funcgraph" << func_graph->ToString() << " for call node:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
// Fetch the device contexts for the real parameters on the call node.
|
||||
for (size_t i = iter->second.size() - inputs.size() + kCallInputStartPos; i < iter->second.size(); ++i) {
|
||||
for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) {
|
||||
if (i >= iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString()
|
||||
<< " device context size:" << iter->second.size()
|
||||
|
@ -1076,7 +1214,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
|
||||
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
|
||||
const auto &real_parameter = real_parameter_with_index.first;
|
||||
if (!real_parameter->isa<ValueNode>() || IsValueNode<FuncGraph>(real_parameter)) {
|
||||
if (!IsFrontValueNode(real_parameter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1097,7 +1235,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
|
||||
const auto &front_node = front_to_backend_parameters.first.first;
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
if (front_node->isa<ValueNode>() && (!front_to_backend_parameters.second.empty())) {
|
||||
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) {
|
||||
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
|
||||
const auto &device_context = front_to_backend_parameters.second.begin()->second;
|
||||
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
|
||||
|
@ -1107,18 +1245,21 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
|
||||
// Create device tensors for those value nodes which direct return by a return node.
|
||||
for (const auto &control_node : control_nodes) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto input_with_indexs = FetchInputNodeByCNode(control_node);
|
||||
auto iter = control_node_to_device_contexts_.find(control_node);
|
||||
if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString();
|
||||
if (iter == control_node_to_device_contexts_.end() || iter->second.size() < input_with_indexs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString()
|
||||
<< " need:" << input_with_indexs.size() << " current:"
|
||||
<< (iter == control_node_to_device_contexts_.end() ? "null"
|
||||
: std::to_string(iter->second.size()));
|
||||
}
|
||||
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
|
||||
const auto &input_with_index = input_with_indexs[i];
|
||||
if (input_with_index.first->isa<ValueNode>() && (!IsValueNode<FuncGraph>(input_with_index.first)) &&
|
||||
if (IsFrontValueNode(input_with_index.first) &&
|
||||
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
|
||||
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
|
||||
front_value_nodes_.emplace(input_with_index, iter->second[i]);
|
||||
|
|
|
@ -56,6 +56,13 @@ constexpr size_t kCNodeInputStartPos = 1;
|
|||
constexpr size_t kReturnInputPos = 1;
|
||||
constexpr size_t kSingleControlNode = 1;
|
||||
|
||||
constexpr size_t kCsrTensorIndPtrIndex = 0;
|
||||
constexpr size_t kCsrTensorIndicesIndex = 1;
|
||||
constexpr size_t kCsrTensorValuesIndex = 2;
|
||||
constexpr size_t kCsrTensorDenseShapeIndex = 3;
|
||||
constexpr size_t kMakeCSRTensorInputStartPos = 1;
|
||||
constexpr size_t kMakeCSRTensorInputNum = 4;
|
||||
|
||||
const char kEntranceActorNameSuffix[] = "_EntranceActor";
|
||||
const char kExitActorNameSuffix[] = "_ExitActor";
|
||||
const char kStackActorNameSuffix[] = "_StackActor";
|
||||
|
@ -87,6 +94,8 @@ using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
|
|||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
// Check whether the node is a csr node.
|
||||
bool IsCsrNode(const AnfNodePtr &node);
|
||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||
// corresponding cnode.
|
||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph);
|
||||
|
|
|
@ -257,7 +257,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
|
|||
std::vector<const DeviceContext *> device_contexts;
|
||||
|
||||
for (const auto &node_with_context : kernel_graph_group_info->front_output_nodes_) {
|
||||
if (HasAbstractMonad(node_with_context.first.first)) {
|
||||
if (HasAbstractMonad(node_with_context.first.first) || IsCsrNode(node_with_context.first.first)) {
|
||||
continue;
|
||||
}
|
||||
// Collect inputs of exit actor.
|
||||
|
@ -666,7 +666,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
|
|||
branch_id);
|
||||
}
|
||||
}
|
||||
if (abstract->isa<abstract::AbstractFunction>()) {
|
||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||
to_actor->input_partials_num_++;
|
||||
} else {
|
||||
to_actor->input_datas_num_++;
|
||||
|
|
|
@ -954,13 +954,24 @@ BaseRef MindRTBackend::ConstructOutputByAbstract(const abstract::AbstractBasePtr
|
|||
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num
|
||||
<< " total:" << output_tensors.size();
|
||||
}
|
||||
VectorRef outputs;
|
||||
|
||||
if (abstract->isa<abstract::AbstractCSRTensor>()) {
|
||||
auto csr_tensor_abstract = abstract->cast<abstract::AbstractCSRTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor_abstract);
|
||||
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indptr(), output_tensors, output_position));
|
||||
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->indices(), output_tensors, output_position));
|
||||
outputs.emplace_back(ConstructOutputByAbstract(csr_tensor_abstract->values(), output_tensors, output_position));
|
||||
outputs.emplace_back(
|
||||
ConstructOutputByAbstract(csr_tensor_abstract->dense_shape(), output_tensors, output_position));
|
||||
return outputs;
|
||||
}
|
||||
|
||||
if (!abstract->isa<abstract::AbstractTuple>()) {
|
||||
(*output_position)++;
|
||||
return output_tensors[(*output_position) - 1];
|
||||
}
|
||||
|
||||
VectorRef outputs;
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
|
|
Loading…
Reference in New Issue