!32616 set value node fracz group

Merge pull request !32616 from jjfeing/master
This commit is contained in:
i-robot 2022-04-07 01:07:59 +00:00 committed by Gitee
commit a6d8b185e7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 3 additions and 17 deletions

View File

@ -71,14 +71,6 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group
param->set_fracz_group(groups);
return false;
}
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node->fracz_group() != 1) {
return true;
}
value_node->set_fracz_group(groups);
return false;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto node_name = common::AnfAlgo::GetCNodeName(cnode);
@ -168,7 +160,7 @@ std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &ma
size_t index, int64_t groups) {
std::vector<KernelWithIndex> ret;
auto node_user = manager->node_users();
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
if (node->isa<Parameter>()) {
std::transform(node_user[node].begin(), node_user[node].end(), std::back_inserter(ret),
[](const KernelWithIndex &node_index) {
return KernelWithIndex{node_index.first, node_index.second - 1};
@ -218,8 +210,7 @@ bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
return true;
}
template <typename T>
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const T &param) {
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const ParameterPtr &param) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param);
auto groups = param->fracz_group();
@ -265,12 +256,6 @@ bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(param);
changed = SetAttrFraczGroup(func_graph, param) || changed;
}
if (node->isa<ValueNode>()) {
// transmit fracz_group attr through multi graph by value node
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
changed = SetAttrFraczGroup(func_graph, value_node) || changed;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {

View File

@ -116,6 +116,7 @@ const AnfNodePtr Conv2dBackpropFilterMul::Process(const FuncGraphPtr &func_graph
}
// CreateAssitValueNode
auto value_node = CreateAssistNode(func_graph, node, shape, matrix_size);
value_node->set_fracz_group(groups);
MS_LOG(INFO) << "Create assist value node success.";
// CreateMulNode
std::vector<AnfNodePtr> mul_inputs{NewValueNode(std::make_shared<Primitive>(kMulOpName)), node, value_node};