optimize converted list to tuple
This commit is contained in:
parent
f115941b24
commit
9108c6f60b
|
@ -172,7 +172,7 @@ void AddDynamicShapeAttrPass(const std::shared_ptr<session::KernelGraph> &kernel
|
|||
|
||||
PassManagerPtr GetEliminateIllegalDataTypePassManager() {
|
||||
auto pm = std::make_shared<PassManager>("common_eliminate_illegal_data_type_pm");
|
||||
pm->AddPass(std::make_shared<ConvertListToTuple>());
|
||||
pm->AddPass(std::make_shared<ConvertListToTuple>("convert_list_to_tuple"));
|
||||
pm->AddPass(std::make_shared<EliminateFuncDataType>());
|
||||
return pm;
|
||||
}
|
||||
|
|
|
@ -22,89 +22,30 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
static const std::map<std::string, std::string> kOpListToTupleNames = {{prim::kMakeListNew, prim::kMakeTuple},
|
||||
{prim::kListGetItem, prim::kTupleGetItem},
|
||||
{prim::kListSetItem, prim::kTupleSetItem}};
|
||||
static const size_t kMaxRecursiveDepth = 6;
|
||||
|
||||
const AnfNodePtr ConvertListToTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Value list --> Value tuple.
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
bool need_convert = false;
|
||||
auto convert_value = ConvertValueSequenceToValueTuple(value_node->value(), &need_convert);
|
||||
if (need_convert) {
|
||||
return std::make_shared<ValueNode>(convert_value);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
// List name --> tuple name.
|
||||
auto old_full_name = cnode->fullname_with_scope();
|
||||
auto old_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
auto iter = kOpListToTupleNames.find(old_name);
|
||||
if (iter != kOpListToTupleNames.end()) {
|
||||
auto primitive = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
primitive->set_name(iter->second);
|
||||
// Reset full scope name.
|
||||
cnode->set_fullname_with_scope("");
|
||||
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
|
||||
MS_LOG(DEBUG) << "Rename op from " << old_name << " to " << iter->second << " for op " << old_full_name << " to "
|
||||
<< cnode->fullname_with_scope();
|
||||
}
|
||||
|
||||
// List abstract --> tuple abstract.
|
||||
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
|
||||
if (new_abs != nullptr) {
|
||||
node->set_abstract(new_abs);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrAbstractAdaptationProcessed, MakeValue(true), cnode);
|
||||
MS_LOG(DEBUG) << "Convert sequence abstract to tuple abstract for op " << old_full_name << ", new op name "
|
||||
<< cnode->fullname_with_scope();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// ValueSequence --> ValueTuple.
|
||||
ValuePtr ConvertListToTuple::ConvertValueSequenceToValueTuple(const ValuePtr &value, bool *need_convert,
|
||||
size_t depth) const {
|
||||
MS_EXCEPTION_IF_NULL(need_convert);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (depth > kMaxRecursiveDepth) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxRecursiveDepth << " levels.";
|
||||
}
|
||||
|
||||
if (value->isa<ValueSequence>()) {
|
||||
std::vector<ValuePtr> elements;
|
||||
auto value_seq = value->cast<ValueSequencePtr>();
|
||||
(void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(elements),
|
||||
[&](const ValuePtr &value) -> ValuePtr {
|
||||
bool is_convert = false;
|
||||
auto convert_value = ConvertValueSequenceToValueTuple(value, &is_convert, depth + 1);
|
||||
*need_convert |= is_convert;
|
||||
return convert_value;
|
||||
});
|
||||
*need_convert |= value->isa<ValueList>();
|
||||
if (*need_convert) {
|
||||
return std::make_shared<ValueTuple>(elements);
|
||||
bool ConvertListToTuple::Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
|
||||
for (auto node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// List abstract --> tuple abstract.
|
||||
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
|
||||
if (new_abs != nullptr) {
|
||||
node->set_abstract(new_abs);
|
||||
if (node->isa<CNode>()) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrAbstractAdaptationProcessed, MakeValue(true), node);
|
||||
}
|
||||
MS_LOG(INFO) << "Convert sequence abstract to tuple abstract for op:" << node->fullname_with_scope()
|
||||
<< ",debug name:" << node->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
return value;
|
||||
return true;
|
||||
}
|
||||
|
||||
// AbstractSequence --> AbstractTuple.
|
||||
AbstractBasePtr ConvertListToTuple::ConvertSequenceAbsToTupleAbs(const AbstractBasePtr &abs, size_t depth) const {
|
||||
if (abs == nullptr) {
|
||||
if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -113,48 +54,45 @@ AbstractBasePtr ConvertListToTuple::ConvertSequenceAbsToTupleAbs(const AbstractB
|
|||
}
|
||||
|
||||
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (abs_seq != nullptr) {
|
||||
// Dynamic length sequence convert by the dynamic abs.
|
||||
if (abs_seq->dynamic_len() && abs_seq->isa<abstract::AbstractList>()) {
|
||||
auto converted_dynamic_abs_tuple =
|
||||
std::make_shared<abstract::AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
|
||||
converted_dynamic_abs_tuple->set_dynamic_len(true);
|
||||
converted_dynamic_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
|
||||
return converted_dynamic_abs_tuple;
|
||||
}
|
||||
const auto &seq_elements = abs_seq->elements();
|
||||
// First we check if elements should be converted,
|
||||
// changed_elements maps old element to new element.
|
||||
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
|
||||
for (const auto &element : seq_elements) {
|
||||
auto new_element = ConvertSequenceAbsToTupleAbs(element, depth + 1);
|
||||
if (new_element != nullptr) {
|
||||
(void)changed_elements.emplace(element, new_element);
|
||||
}
|
||||
}
|
||||
if (changed_elements.empty()) {
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
// If no elements changed and it is an AbstractTuple, do not convert.
|
||||
return nullptr;
|
||||
}
|
||||
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
|
||||
return std::make_shared<abstract::AbstractTuple>(seq_elements);
|
||||
}
|
||||
// Always make new AbstractTuple when elements changed.
|
||||
std::vector<AbstractBasePtr> elements;
|
||||
elements.reserve(seq_elements.size());
|
||||
for (const auto &element : seq_elements) {
|
||||
auto iter = changed_elements.find(element);
|
||||
if (iter != changed_elements.end()) {
|
||||
(void)elements.emplace_back(iter->second);
|
||||
} else {
|
||||
(void)elements.emplace_back(element);
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(std::move(elements));
|
||||
MS_EXCEPTION_IF_NULL(abs_seq);
|
||||
// Dynamic length sequence convert by the dynamic abs.
|
||||
if (abs_seq->dynamic_len() && abs_seq->isa<abstract::AbstractList>()) {
|
||||
auto converted_dynamic_abs_tuple =
|
||||
std::make_shared<abstract::AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
|
||||
converted_dynamic_abs_tuple->set_dynamic_len(true);
|
||||
converted_dynamic_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
|
||||
return converted_dynamic_abs_tuple;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
const auto &seq_elements = abs_seq->elements();
|
||||
// First we check if elements should be converted,
|
||||
// changed_elements maps old element to new element.
|
||||
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
|
||||
for (const auto &element : seq_elements) {
|
||||
auto new_element = ConvertSequenceAbsToTupleAbs(element, depth + 1);
|
||||
if (new_element != nullptr) {
|
||||
(void)changed_elements.emplace(element, new_element);
|
||||
}
|
||||
}
|
||||
if (changed_elements.empty()) {
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
// If no elements changed and it is an AbstractTuple, do not convert.
|
||||
return nullptr;
|
||||
}
|
||||
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
|
||||
return std::make_shared<abstract::AbstractTuple>(seq_elements);
|
||||
}
|
||||
// Always make new AbstractTuple when elements changed.
|
||||
std::vector<AbstractBasePtr> elements;
|
||||
elements.reserve(seq_elements.size());
|
||||
for (const auto &element : seq_elements) {
|
||||
auto iter = changed_elements.find(element);
|
||||
if (iter != changed_elements.end()) {
|
||||
(void)elements.emplace_back(iter->second);
|
||||
} else {
|
||||
(void)elements.emplace_back(element);
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(std::move(elements));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,15 +23,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BACKEND_EXPORT ConvertListToTuple : public PatternProcessPass {
|
||||
class BACKEND_EXPORT ConvertListToTuple : public Pass {
|
||||
public:
|
||||
explicit ConvertListToTuple(bool multigraph = true) : PatternProcessPass("convert_list_to_tuple", multigraph) {}
|
||||
explicit ConvertListToTuple(const std::string &name) : Pass("convert_list_to_tuple") {}
|
||||
~ConvertListToTuple() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
// ValueSequence --> ValueTuple.
|
||||
ValuePtr ConvertValueSequenceToValueTuple(const ValuePtr &value, bool *need_convert, size_t depth = 0) const;
|
||||
// AbstractSequence --> AbstractTuple.
|
||||
AbstractBasePtr ConvertSequenceAbsToTupleAbs(const AbstractBasePtr &abs, size_t depth = 0) const;
|
||||
};
|
||||
|
|
|
@ -994,7 +994,7 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
|
|||
auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
|
||||
auto node_abs = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(node_abs);
|
||||
if (node_abs->isa<abstract::AbstractTuple>()) {
|
||||
if (node_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(
|
||||
i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
|
||||
MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()
|
||||
|
|
|
@ -398,31 +398,6 @@ void MindRTBackendBase::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void ExchangeRealTupleGetItem(const FuncGraphPtr &root_graph) {
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
MS_EXCEPTION_IF_NULL(root_graph->manager());
|
||||
FuncGraphSet graphs = root_graph->manager()->func_graphs();
|
||||
for (const auto &graph : graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
for (const auto &node : nodes) {
|
||||
if (node == nullptr || (!node->isa<CNode>())) {
|
||||
continue;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == prim::kTupleGetItem &&
|
||||
cnode->inputs().size() == kTupleGetItemInputSize &&
|
||||
(!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>())) {
|
||||
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(prim::kRealTupleGetItem)));
|
||||
MS_LOG(INFO) << "Exchange tuple get item to real tuple get item for node:" << cnode->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -431,10 +406,9 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
|
|||
|
||||
auto root_graph = WrapPrimitives(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
UnifyMindIR(root_graph);
|
||||
root_graph_ = root_graph;
|
||||
|
||||
ExchangeRealTupleGetItem(root_graph);
|
||||
|
||||
// Register a summary callback function, which is called in the final stages of summary.
|
||||
graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||
|
||||
|
@ -487,6 +461,49 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
|
|||
return actor_info;
|
||||
}
|
||||
|
||||
void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) {
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
MS_EXCEPTION_IF_NULL(root_graph->manager());
|
||||
const std::map<std::string, std::string> kOpListToTupleNames = {{prim::kMakeListNew, prim::kMakeTuple},
|
||||
{prim::kListGetItem, prim::kTupleGetItem},
|
||||
{prim::kListSetItem, prim::kTupleSetItem}};
|
||||
|
||||
FuncGraphSet graphs = root_graph->manager()->func_graphs();
|
||||
for (const auto &graph : graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
for (const auto &node : nodes) {
|
||||
if (node == nullptr || (!node->isa<CNode>())) {
|
||||
continue;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// List name --> tuple name.
|
||||
auto iter = kOpListToTupleNames.find(common::AnfAlgo::GetCNodeName(cnode));
|
||||
if (iter != kOpListToTupleNames.end()) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
|
||||
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(iter->second)));
|
||||
// Reset full scope name.
|
||||
cnode->set_fullname_with_scope("");
|
||||
MS_LOG(INFO) << "Rename op from " << iter->first << " to " << iter->second << " for op "
|
||||
<< cnode->fullname_with_scope() << ", debug name:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
// TupleGetItem --> RealTupleGetItem.
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == prim::kTupleGetItem &&
|
||||
cnode->inputs().size() == kTupleGetItemInputSize &&
|
||||
(!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>())) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
|
||||
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(prim::kRealTupleGetItem)));
|
||||
// Reset full scope name.
|
||||
cnode->set_fullname_with_scope("");
|
||||
MS_LOG(INFO) << "Rename op from TupleGetItem to RealTupleGetItem for op " << cnode->fullname_with_scope()
|
||||
<< ", debug name:" << cnode->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
|
||||
auto root_graph = WrapPrimitives(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
|
|
|
@ -98,6 +98,9 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
|
|||
const VectorRef &args, VectorRef *outputs) {}
|
||||
|
||||
protected:
|
||||
// Convert the nodes which are not supported in the backend.
|
||||
void UnifyMindIR(const FuncGraphPtr &func_graph);
|
||||
|
||||
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
|
||||
// The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_.
|
||||
void CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode);
|
||||
|
|
|
@ -238,6 +238,9 @@ void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *devic
|
|||
<< " addr:" << address;
|
||||
address->set_from_persistent_mem(true);
|
||||
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
|
||||
} else {
|
||||
MS_LOG(INFO) << "No device address for value node:" << value_node->fullname_with_scope()
|
||||
<< ", debug name:" << common::AnfAlgo::GetNodeDebugString(value_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue