!32616 set value node fracz group
Merge pull request !32616 from jjfeing/master
This commit is contained in:
commit
a6d8b185e7
|
@ -71,14 +71,6 @@ bool HasFraczGroupAttrAndSet(const AnfNodePtr &node, size_t index, int64_t group
|
|||
param->set_fracz_group(groups);
|
||||
return false;
|
||||
}
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node->fracz_group() != 1) {
|
||||
return true;
|
||||
}
|
||||
value_node->set_fracz_group(groups);
|
||||
return false;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto node_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
|
@ -168,7 +160,7 @@ std::vector<KernelWithIndex> GetNeighborFraczNodes(const FuncGraphManagerPtr &ma
|
|||
size_t index, int64_t groups) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
auto node_user = manager->node_users();
|
||||
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
|
||||
if (node->isa<Parameter>()) {
|
||||
std::transform(node_user[node].begin(), node_user[node].end(), std::back_inserter(ret),
|
||||
[](const KernelWithIndex &node_index) {
|
||||
return KernelWithIndex{node_index.first, node_index.second - 1};
|
||||
|
@ -218,8 +210,7 @@ bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const T ¶m) {
|
||||
bool SetAttrFraczGroup(const FuncGraphPtr &func_graph, const ParameterPtr ¶m) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto groups = param->fracz_group();
|
||||
|
@ -265,12 +256,6 @@ bool SetFraczGroupAttr::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(param);
|
||||
changed = SetAttrFraczGroup(func_graph, param) || changed;
|
||||
}
|
||||
if (node->isa<ValueNode>()) {
|
||||
// transmit fracz_group attr through multi graph by value node
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
changed = SetAttrFraczGroup(func_graph, value_node) || changed;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
|
|
|
@ -116,6 +116,7 @@ const AnfNodePtr Conv2dBackpropFilterMul::Process(const FuncGraphPtr &func_graph
|
|||
}
|
||||
// CreateAssitValueNode
|
||||
auto value_node = CreateAssistNode(func_graph, node, shape, matrix_size);
|
||||
value_node->set_fracz_group(groups);
|
||||
MS_LOG(INFO) << "Create assist value node success.";
|
||||
// CreateMulNode
|
||||
std::vector<AnfNodePtr> mul_inputs{NewValueNode(std::make_shared<Primitive>(kMulOpName)), node, value_node};
|
||||
|
|
Loading…
Reference in New Issue