Fix getitem in getitem.

This commit is contained in:
gaoyong10 2022-01-18 12:06:50 +08:00
parent ab50cf5b21
commit e446b2e3ed
7 changed files with 90 additions and 54 deletions

View File

@ -307,7 +307,8 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
bool skip_nop_node,
const std::vector<PrimitivePtr> &return_types) {
const std::vector<PrimitivePtr> &return_types,
abstract::AbstractBasePtr *abstract) {
MS_EXCEPTION_IF_NULL(anf_node);
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
return CheckPrimitiveType(anf_node, prim_type);
@ -320,8 +321,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types);
abstract::AbstractBasePtr abs = nullptr;
auto item_with_index_tmp = VisitKernelWithReturnType(
GetTupleGetItemRealInput(cnode), GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types, &abs);
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
@ -334,6 +336,32 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
}
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, skip_nop_node, return_types);
}
if (IsCallNode(item_with_index_tmp.first)) {
size_t real_index = item_with_index_tmp.second;
if (abs == nullptr) {
abs = item_with_index_tmp.first->abstract();
real_index = 0;
}
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
auto sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.size() <= GetTupleGetItemOutIndex(cnode)) {
MS_LOG(EXCEPTION) << "Invalid index:" << GetTupleGetItemOutIndex(cnode)
<< " for abstract:" << abs->ToString();
}
for (size_t i = 0; i < GetTupleGetItemOutIndex(cnode); ++i) {
MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
real_index += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
}
if (abstract != nullptr) {
(*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)];
MS_EXCEPTION_IF_NULL((*abstract));
}
return {item_with_index_tmp.first, real_index};
}
}
return item_with_index_tmp;
}
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {

View File

@ -77,10 +77,10 @@ class AnfRuntimeAlgorithm {
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
// get input_anf_node's real kernel by recurse
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
bool skip_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
prim::kPrimMakeTuple});
static KernelWithIndex VisitKernelWithReturnType(
const AnfNodePtr &input_anf_node, size_t output_index, bool skip_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {prim::kPrimMakeTuple},
abstract::AbstractBasePtr *abstract = nullptr);
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {});
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);

View File

@ -737,14 +737,13 @@ void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &con
continue;
}
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find({front_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
const auto &backend_parameter_with_context =
control_node_parser->FetchBackendParameterWithContextByFrontParameter({front_node, 0});
if (backend_parameter_with_context.first == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(front_node);
}
const auto &node_with_context = iter->second.begin();
const auto &backend_node = node_with_context->first;
const auto &device_context = node_with_context->second;
const auto &backend_node = backend_parameter_with_context.first;
const auto &device_context = backend_parameter_with_context.second;
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);

View File

@ -369,7 +369,8 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address
<< " size:" << tensor_size;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
UpdateRefCount(address.get(), true);
}
@ -1231,6 +1232,25 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c
return nullptr;
}
NodeWithContext ControlNodeParser::FetchBackendParameterWithContextByFrontParameter(
const KernelWithIndex &front_parameter_with_index) {
const auto &iter = front_to_backend_parameters_.find(front_parameter_with_index);
if (iter == front_to_backend_parameters_.end()) {
return {};
}
for (const auto &node_with_context : iter->second) {
MS_EXCEPTION_IF_NULL(node_with_context.first);
if (AnfAlgo::GetOutputTensorMemSize(node_with_context.first, 0) != 0) {
return node_with_context;
}
MS_LOG(WARNING) << "Backend node:" << node_with_context.first->DebugString()
<< " for front node:" << front_parameter_with_index.first->DebugString()
<< " index:" << front_parameter_with_index.second << " output size is 0.";
}
return {};
}
void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
const DeviceContext *const default_context) {
MS_EXCEPTION_IF_NULL(default_context);
@ -1241,11 +1261,12 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
continue;
}
const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index);
if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) {
(void)front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second);
CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first,
iter->second.begin()->second);
const auto &backend_node_with_context =
FetchBackendParameterWithContextByFrontParameter(real_parameter_with_index);
if (backend_node_with_context.first != nullptr) {
(void)front_value_nodes_.emplace(real_parameter_with_index, backend_node_with_context.second);
CreateDeviceTensorForValueNode(real_parameter_with_index, backend_node_with_context.first,
backend_node_with_context.second);
} else {
(void)front_value_nodes_.emplace(real_parameter_with_index, default_context);
CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
@ -1253,19 +1274,6 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
}
}
// If the output of funcgraph is a value node, it will eventually be sent to the kernel as a real parameter.
// These the value nodes also need to create a device address.
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
const auto &front_node = front_to_backend_parameters.first.first;
MS_EXCEPTION_IF_NULL(front_node);
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) {
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first;
const auto &device_context = front_to_backend_parameters.second.begin()->second;
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context);
(void)front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
}
}
// Create device tensors for those value nodes which direct return by a return node.
for (const auto &control_node : control_nodes) {
if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) {

View File

@ -66,8 +66,8 @@ constexpr size_t kMakeCSRTensorInputNum = 4;
const char kEntranceActorNameSuffix[] = "_EntranceActor";
const char kExitActorNameSuffix[] = "_ExitActor";
const char kStackActorNameSuffix[] = "_StackActor";
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using NodeWithContext = std::pair<AnfNodePtr, DeviceContext *>;
using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<NodeWithContext>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraphGroup = mindspore::HashMap<FuncGraphPtr, std::vector<std::vector<KernelGraphPtr>>>;
using HostParameterToWeight = std::map<AnfNodePtr, std::set<AnfNodePtr>>;
@ -145,6 +145,7 @@ class ControlNodeParser {
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
FuncGraphPtr FetchFuncGraphByKernelGraph(const KernelGraph *const graph);
std::string FetchGroupNameByKernelGraph(const KernelGraphPtr &graph);
NodeWithContext FetchBackendParameterWithContextByFrontParameter(const KernelWithIndex &front_parameter_with_index);
private:
friend class GraphScheduler;

View File

@ -1418,6 +1418,10 @@ void ControlNodeScheduler::LinkBranchIDArrow(ControlActor *const from_actor, Con
bool ControlNodeScheduler::CheckActorValid(const ActorSet *actor_set) const {
MS_EXCEPTION_IF_NULL(actor_set);
if (actor_set->control_actors_ == nullptr) {
return true;
}
for (const auto &kernel_actor : actor_set->kernel_actors_) {
std::string exit_actor_name = "";
for (const auto arrow : kernel_actor->output_data_arrows_) {

View File

@ -646,22 +646,16 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
}
}
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
const auto parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser);
// Initialize the parameter in the control node, first get all the front parameters in the control node, then find
// the corresponding backend parameter from the map, and insert it into the host data source actor
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
// the corresponding backend parameter from the map, and insert it into the host data source actor.
const auto &control_node_parameters = parser->control_node_parameters();
for (const auto &parameter : control_node_parameters) {
if (IsPersistentDeviceTensor(parameter)) {
continue;
}
auto backend_iter = front_to_backend_parameter.find({parameter, 0});
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
}
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
@ -675,15 +669,18 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
if (node_map.find(parameter) != node_map.end()) {
continue;
}
const auto &backend_node = backend_iter->second.begin()->first;
const auto &backend_parameter_with_context =
parser->FetchBackendParameterWithContextByFrontParameter({parameter, 0});
const auto &backend_node = backend_parameter_with_context.first;
MS_EXCEPTION_IF_NULL(backend_node);
auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
if (iter != host_queue_ds_actor->data_nodes_.end()) {
(void)node_map.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin());
} else {
(void)node_map.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
(void)node_map.emplace(backend_iter->second.begin()->first, host_queue_ds_actor->data_nodes_.size());
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
(void)node_map.emplace(backend_node, host_queue_ds_actor->data_nodes_.size());
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_node);
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_parameter_with_context.second);
}
}
@ -1940,14 +1937,13 @@ void GraphScheduler::PersistDeviceTensorForControlNode(const GraphCompilerInfo &
if ((!IsPersistentDeviceTensor(input_node)) || (!parser->IsRootGraphParameter(input_node))) {
continue;
}
const auto &front_to_backend_parameters = parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find({input_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
const auto &backend_parameter_with_context =
parser->FetchBackendParameterWithContextByFrontParameter({input_node, 0});
if (backend_parameter_with_context.first == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << input_node->DebugString();
}
const auto &node_with_context = iter->second.begin();
const auto &backend_node = node_with_context->first;
const auto &device_context = node_with_context->second;
const auto &backend_node = backend_parameter_with_context.first;
const auto &device_context = backend_parameter_with_context.second;
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);
if (!DeviceTensorStore::GetInstance().Fetch(input_node.get()).empty()) {