forked from mindspore-Ecosystem/mindspore
Fix getitem in getitem.
This commit is contained in:
parent
ab50cf5b21
commit
e446b2e3ed
|
@ -307,7 +307,8 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
|
|||
|
||||
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
|
||||
bool skip_nop_node,
|
||||
const std::vector<PrimitivePtr> &return_types) {
|
||||
const std::vector<PrimitivePtr> &return_types,
|
||||
abstract::AbstractBasePtr *abstract) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
|
||||
return CheckPrimitiveType(anf_node, prim_type);
|
||||
|
@ -320,8 +321,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
|
||||
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
|
||||
GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types);
|
||||
abstract::AbstractBasePtr abs = nullptr;
|
||||
auto item_with_index_tmp = VisitKernelWithReturnType(
|
||||
GetTupleGetItemRealInput(cnode), GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types, &abs);
|
||||
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
|
||||
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
|
||||
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
|
||||
|
@ -334,6 +336,32 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|||
}
|
||||
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, skip_nop_node, return_types);
|
||||
}
|
||||
if (IsCallNode(item_with_index_tmp.first)) {
|
||||
size_t real_index = item_with_index_tmp.second;
|
||||
if (abs == nullptr) {
|
||||
abs = item_with_index_tmp.first->abstract();
|
||||
real_index = 0;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abstract = abs->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
auto sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.size() <= GetTupleGetItemOutIndex(cnode)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << GetTupleGetItemOutIndex(cnode)
|
||||
<< " for abstract:" << abs->ToString();
|
||||
}
|
||||
for (size_t i = 0; i < GetTupleGetItemOutIndex(cnode); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
|
||||
real_index += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
|
||||
}
|
||||
if (abstract != nullptr) {
|
||||
(*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)];
|
||||
MS_EXCEPTION_IF_NULL((*abstract));
|
||||
}
|
||||
return {item_with_index_tmp.first, real_index};
|
||||
}
|
||||
}
|
||||
return item_with_index_tmp;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
|
||||
|
|
|
@ -77,10 +77,10 @@ class AnfRuntimeAlgorithm {
|
|||
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
|
||||
// get input_anf_node's real kernel by recurse
|
||||
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
|
||||
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
|
||||
bool skip_nop_node = false,
|
||||
const std::vector<PrimitivePtr> &return_types = {
|
||||
prim::kPrimMakeTuple});
|
||||
static KernelWithIndex VisitKernelWithReturnType(
|
||||
const AnfNodePtr &input_anf_node, size_t output_index, bool skip_nop_node = false,
|
||||
const std::vector<PrimitivePtr> &return_types = {prim::kPrimMakeTuple},
|
||||
abstract::AbstractBasePtr *abstract = nullptr);
|
||||
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {});
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
||||
|
|
|
@ -737,14 +737,13 @@ void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &con
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
|
||||
const auto &iter = front_to_backend_parameters.find({front_node, 0});
|
||||
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
|
||||
const auto &backend_parameter_with_context =
|
||||
control_node_parser->FetchBackendParameterWithContextByFrontParameter({front_node, 0});
|
||||
if (backend_parameter_with_context.first == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(front_node);
|
||||
}
|
||||
const auto &node_with_context = iter->second.begin();
|
||||
const auto &backend_node = node_with_context->first;
|
||||
const auto &device_context = node_with_context->second;
|
||||
const auto &backend_node = backend_parameter_with_context.first;
|
||||
const auto &device_context = backend_parameter_with_context.second;
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
|
|
|
@ -369,7 +369,8 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
|
|||
device::DeviceAddressPtr address =
|
||||
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address
|
||||
<< " size:" << tensor_size;
|
||||
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
|
||||
UpdateRefCount(address.get(), true);
|
||||
}
|
||||
|
@ -1231,6 +1232,25 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
NodeWithContext ControlNodeParser::FetchBackendParameterWithContextByFrontParameter(
|
||||
const KernelWithIndex &front_parameter_with_index) {
|
||||
const auto &iter = front_to_backend_parameters_.find(front_parameter_with_index);
|
||||
if (iter == front_to_backend_parameters_.end()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
for (const auto &node_with_context : iter->second) {
|
||||
MS_EXCEPTION_IF_NULL(node_with_context.first);
|
||||
if (AnfAlgo::GetOutputTensorMemSize(node_with_context.first, 0) != 0) {
|
||||
return node_with_context;
|
||||
}
|
||||
MS_LOG(WARNING) << "Backend node:" << node_with_context.first->DebugString()
|
||||
<< " for front node:" << front_parameter_with_index.first->DebugString()
|
||||
<< " index:" << front_parameter_with_index.second << " output size is 0.";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
|
||||
const DeviceContext *const default_context) {
|
||||
MS_EXCEPTION_IF_NULL(default_context);
|
||||
|
@ -1241,11 +1261,12 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index);
|
||||
if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) {
|
||||
(void)front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second);
|
||||
CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first,
|
||||
iter->second.begin()->second);
|
||||
const auto &backend_node_with_context =
|
||||
FetchBackendParameterWithContextByFrontParameter(real_parameter_with_index);
|
||||
if (backend_node_with_context.first != nullptr) {
|
||||
(void)front_value_nodes_.emplace(real_parameter_with_index, backend_node_with_context.second);
|
||||
CreateDeviceTensorForValueNode(real_parameter_with_index, backend_node_with_context.first,
|
||||
backend_node_with_context.second);
|
||||
} else {
|
||||
(void)front_value_nodes_.emplace(real_parameter_with_index, default_context);
|
||||
CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
|
||||
|
@ -1253,19 +1274,6 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
}
|
||||
}
|
||||
|
||||
// If the output of funcgraph is a value node, it will eventually be sent to the kernel as a real parameter.
|
||||
// These the value nodes also need to create a device address.
|
||||
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
|
||||
const auto &front_node = front_to_backend_parameters.first.first;
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) {
|
||||
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
|
||||
const auto &device_context = front_to_backend_parameters.second.begin()->second;
|
||||
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
|
||||
(void)front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
|
||||
}
|
||||
}
|
||||
|
||||
// Create device tensors for those value nodes which direct return by a return node.
|
||||
for (const auto &control_node : control_nodes) {
|
||||
if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) {
|
||||
|
|
|
@ -66,8 +66,8 @@ constexpr size_t kMakeCSRTensorInputNum = 4;
|
|||
const char kEntranceActorNameSuffix[] = "_EntranceActor";
|
||||
const char kExitActorNameSuffix[] = "_ExitActor";
|
||||
const char kStackActorNameSuffix[] = "_StackActor";
|
||||
|
||||
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
|
||||
using NodeWithContext = std::pair<AnfNodePtr, DeviceContext *>;
|
||||
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<NodeWithContext>>;
|
||||
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
|
||||
using FuncGraphToKernelGraphGroup = mindspore::HashMap<FuncGraphPtr, std::vector<std::vector<KernelGraphPtr>>>;
|
||||
using HostParameterToWeight = std::map<AnfNodePtr, std::set<AnfNodePtr>>;
|
||||
|
@ -145,6 +145,7 @@ class ControlNodeParser {
|
|||
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
|
||||
FuncGraphPtr FetchFuncGraphByKernelGraph(const KernelGraph *const graph);
|
||||
std::string FetchGroupNameByKernelGraph(const KernelGraphPtr &graph);
|
||||
NodeWithContext FetchBackendParameterWithContextByFrontParameter(const KernelWithIndex &front_parameter_with_index);
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
|
|
@ -1418,6 +1418,10 @@ void ControlNodeScheduler::LinkBranchIDArrow(ControlActor *const from_actor, Con
|
|||
|
||||
bool ControlNodeScheduler::CheckActorValid(const ActorSet *actor_set) const {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
if (actor_set->control_actors_ == nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
std::string exit_actor_name = "";
|
||||
for (const auto arrow : kernel_actor->output_data_arrows_) {
|
||||
|
|
|
@ -646,22 +646,16 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
}
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
|
||||
const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
|
||||
const auto parser = graph_compiler_info.control_node_parser_;
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
|
||||
// Initialize the parameter in the control node, first get all the front parameters in the control node, then find
|
||||
// the corresponding backend parameter from the map, and insert it into the host data source actor
|
||||
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
|
||||
// the corresponding backend parameter from the map, and insert it into the host data source actor.
|
||||
const auto &control_node_parameters = parser->control_node_parameters();
|
||||
for (const auto ¶meter : control_node_parameters) {
|
||||
if (IsPersistentDeviceTensor(parameter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto backend_iter = front_to_backend_parameter.find({parameter, 0});
|
||||
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
|
||||
}
|
||||
|
||||
if (host_queue_ds_actor == nullptr) {
|
||||
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
|
||||
|
@ -675,15 +669,18 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
if (node_map.find(parameter) != node_map.end()) {
|
||||
continue;
|
||||
}
|
||||
const auto &backend_node = backend_iter->second.begin()->first;
|
||||
const auto &backend_parameter_with_context =
|
||||
parser->FetchBackendParameterWithContextByFrontParameter({parameter, 0});
|
||||
const auto &backend_node = backend_parameter_with_context.first;
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
|
||||
if (iter != host_queue_ds_actor->data_nodes_.end()) {
|
||||
(void)node_map.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin());
|
||||
} else {
|
||||
(void)node_map.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
|
||||
(void)node_map.emplace(backend_iter->second.begin()->first, host_queue_ds_actor->data_nodes_.size());
|
||||
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
|
||||
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
|
||||
(void)node_map.emplace(backend_node, host_queue_ds_actor->data_nodes_.size());
|
||||
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_node);
|
||||
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_parameter_with_context.second);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1940,14 +1937,13 @@ void GraphScheduler::PersistDeviceTensorForControlNode(const GraphCompilerInfo &
|
|||
if ((!IsPersistentDeviceTensor(input_node)) || (!parser->IsRootGraphParameter(input_node))) {
|
||||
continue;
|
||||
}
|
||||
const auto &front_to_backend_parameters = parser->front_to_backend_parameters();
|
||||
const auto &iter = front_to_backend_parameters.find({input_node, 0});
|
||||
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
|
||||
const auto &backend_parameter_with_context =
|
||||
parser->FetchBackendParameterWithContextByFrontParameter({input_node, 0});
|
||||
if (backend_parameter_with_context.first == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << input_node->DebugString();
|
||||
}
|
||||
const auto &node_with_context = iter->second.begin();
|
||||
const auto &backend_node = node_with_context->first;
|
||||
const auto &device_context = node_with_context->second;
|
||||
const auto &backend_node = backend_parameter_with_context.first;
|
||||
const auto &device_context = backend_parameter_with_context.second;
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (!DeviceTensorStore::GetInstance().Fetch(input_node.get()).empty()) {
|
||||
|
|
Loading…
Reference in New Issue