!20899 Code review fix for GPU optimizer

Merge pull request !20899 from chengang/codefix
i-robot 2021-07-27 13:20:33 +00:00 committed by Gitee
commit cae7f291c4
24 changed files with 159 additions and 88 deletions

View File

@ -34,6 +34,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
MS_EXCEPTION_IF_NULL(node);
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
@ -56,7 +57,9 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const size_t assign_index = 2;
const auto &n = node->cast<CNodePtr>()->input(assign_index);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &n = cnode->input(assign_index);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
@ -73,8 +76,10 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
const size_t monad_index = 1;
const size_t adam_index = 2;
(user.first)->cast<CNodePtr>()->set_input(monad_index, u_input);
(user.first)->cast<CNodePtr>()->set_input(adam_index, adam);
auto cnode_ptr = (user.first)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_ptr);
cnode_ptr->set_input(monad_index, u_input);
cnode_ptr->set_input(adam_index, adam);
break;
}
}
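
The change above is the template for most hunks in this merge request: a chained call such as node->cast<CNodePtr>()->input(assign_index) dereferences the result of cast<T>(), which is null whenever the node is not actually a CNode, so the chain is split into a named local that is guarded with MS_EXCEPTION_IF_NULL before any member access. A minimal, self-contained sketch of the same idea, using standard-library stand-ins (AnfNode, CNode and CHECK_NOT_NULL here are simplified illustrations, not the MindSpore API):

#include <iostream>
#include <memory>
#include <stdexcept>

// Illustrative stand-in for MS_EXCEPTION_IF_NULL: fail with a clear message
// instead of dereferencing a null pointer somewhere downstream.
#define CHECK_NOT_NULL(ptr) \
  do { if ((ptr) == nullptr) throw std::runtime_error("null pointer: " #ptr); } while (false)

struct AnfNode { virtual ~AnfNode() = default; };
struct CNode : AnfNode {
  std::shared_ptr<AnfNode> input(size_t) const { return std::make_shared<AnfNode>(); }
};

std::shared_ptr<AnfNode> SecondInput(const std::shared_ptr<AnfNode> &node) {
  // Before: std::dynamic_pointer_cast<CNode>(node)->input(2);  // crashes when the cast fails
  auto cnode = std::dynamic_pointer_cast<CNode>(node);  // may be nullptr
  CHECK_NOT_NULL(cnode);                                // turn that into a diagnosable error
  return cnode->input(2);
}

int main() {
  std::cout << (SecondInput(std::make_shared<CNode>()) != nullptr) << std::endl;
  return 0;
}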

View File

@ -34,6 +34,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
MS_EXCEPTION_IF_NULL(node);
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
@ -56,7 +57,8 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const size_t assign_index = 2;
const auto &n = node->cast<CNodePtr>()->input(assign_index);
auto cnode = node->cast<CNodePtr>();
const auto &n = cnode->input(assign_index);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
@ -73,8 +75,10 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
const size_t monad_index = 1;
const size_t adam_weight_decay_index = 2;
(user.first)->cast<CNodePtr>()->set_input(monad_index, u_input);
(user.first)->cast<CNodePtr>()->set_input(adam_weight_decay_index, adam_weight_decay);
auto cnode_ptr = (user.first)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_ptr);
cnode_ptr->set_input(monad_index, u_input);
cnode_ptr->set_input(adam_weight_decay_index, adam_weight_decay);
break;
}
}

View File

@ -65,6 +65,9 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
MS_EXCEPTION_IF_NULL(x1);
MS_EXCEPTION_IF_NULL(x2);
MS_EXCEPTION_IF_NULL(mask);
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(tensor_add);

View File

@ -64,6 +64,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
MS_EXCEPTION_IF_NULL(equiv);
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
MS_EXCEPTION_IF_NULL(x1);
MS_EXCEPTION_IF_NULL(x2);
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(tensor_add);

View File

@ -15,6 +15,10 @@
*/
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
@ -26,7 +30,9 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
auto shape_ptr = in->Shape();
MS_EXCEPTION_IF_NULL(shape_ptr);
auto shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
@ -35,7 +41,11 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
auto type_ptr = dyn_cast<TensorType>(dtype);
MS_EXCEPTION_IF_NULL(type_ptr);
auto element = type_ptr->element();
MS_EXCEPTION_IF_NULL(element);
auto element_type = element->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
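
IsScalar only accepts a rank-0 float32 tensor, and each hop on the way to that answer (Shape(), its cast to abstract::ShapePtr, the dyn_cast to TensorType, and element()) can come back null when the node's abstract information is incomplete, so the hunk names and checks every intermediate. A compact sketch of the predicate's logic with ordinary types; ShapeInfo, TensorInfo and IsFloat32Scalar are invented for illustration, and where the real pass raises through MS_EXCEPTION_IF_NULL this sketch simply returns false:

#include <cstdint>
#include <memory>
#include <vector>

enum class ElementType { kFloat32, kFloat16, kInt32 };

struct ShapeInfo { std::vector<int64_t> dims; };
struct TensorInfo { std::shared_ptr<ElementType> element; };

// Rank-0 shape (no dimensions) and a float32 element type, with every pointer checked.
bool IsFloat32Scalar(const std::shared_ptr<ShapeInfo> &shape,
                     const std::shared_ptr<TensorInfo> &tensor) {
  if (shape == nullptr || tensor == nullptr) return false;
  if (!shape->dims.empty()) return false;     // mirrors shape->shape().size() != 0
  if (tensor->element == nullptr) return false;
  return *tensor->element == ElementType::kFloat32;
}

int main() {
  auto shape = std::make_shared<ShapeInfo>();
  auto tensor = std::make_shared<TensorInfo>();
  tensor->element = std::make_shared<ElementType>(ElementType::kFloat32);
  return IsFloat32Scalar(shape, tensor) ? 0 : 1;
}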

View File

@ -15,6 +15,8 @@
*/
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
@ -22,28 +24,6 @@
namespace mindspore {
namespace opt {
bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
VectorRef weight_decay =

View File

@ -39,8 +39,6 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
static bool IsScalar(const BaseRef &n);
VarPtr monad_;
VarPtr weight_decay_;
VarPtr variable_;

View File

@ -32,16 +32,23 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
auto shape_ptr = in->Shape();
MS_EXCEPTION_IF_NULL(shape_ptr);
auto shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
MS_EXCEPTION_IF_NULL(dtype);
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
auto type_ptr = dyn_cast<TensorType>(dtype);
MS_EXCEPTION_IF_NULL(type_ptr);
auto element = type_ptr->element();
MS_EXCEPTION_IF_NULL(element);
auto element_type = element->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}

View File

@ -36,7 +36,7 @@ const BaseRef BatchNormAddReluFusion::DefinePattern() const {
}
const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
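
Several Process overrides in this merge request (here and in the BatchNormRelu, PostBatchNormAddRelu, RemoveFormatTransform, ReplaceAddN and ReplaceMomentumCast fusions below) keep the EquivPtr parameter only because the PatternProcessPass interface requires it; dropping the parameter name, and where present the matching MS_EXCEPTION_IF_NULL(equiv) check, records that the argument is unused and avoids unused-parameter warnings. A small sketch of the convention with illustrative types:

#include <iostream>

struct Pass {
  virtual ~Pass() = default;
  virtual int Process(int graph, int node, int equiv) const = 0;
};

struct FusionPass : Pass {
  // The third argument is required by the base-class signature but unused here,
  // so it is left unnamed: the override still matches exactly and the compiler
  // emits no unused-parameter warning.
  int Process(int graph, int node, int /* equiv */) const override { return graph + node; }
};

int main() {
  FusionPass pass;
  std::cout << pass.Process(1, 2, 3) << std::endl;  // the third value is ignored
  return 0;
}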

View File

@ -35,6 +35,7 @@ constexpr size_t kBNAddReluGradOutputNum = 4;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
MS_EXCEPTION_IF_NULL(bn_outputs);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -121,7 +122,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
return false;
}
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
if ((shape.back() % kBNChannelMultipleFactor) != 0) {
return false;
}
@ -188,7 +189,6 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
if (!GetValue<bool>(is_train)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
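
GetBatchNormOutputs fills a caller-supplied vector through the bn_outputs pointer, so the new MS_EXCEPTION_IF_NULL validates that out-parameter once at function entry instead of faulting on the first push_back; GetDealList and GetOptList further down receive the same treatment. A standalone sketch of the entry check; CollectEvenNumbers is an invented example, not a MindSpore helper:

#include <stdexcept>
#include <vector>

// Fills *out with the even values from input. The out-parameter is validated
// up front so a null pointer fails with a clear message rather than a crash
// deep inside the loop.
bool CollectEvenNumbers(const std::vector<int> &input, std::vector<int> *out) {
  if (out == nullptr) throw std::invalid_argument("out must not be null");
  for (int value : input) {
    if (value % 2 == 0) out->push_back(value);
  }
  return !out->empty();
}

int main() {
  std::vector<int> evens;
  return CollectEvenNumbers({1, 2, 3, 4}, &evens) ? 0 : 1;
}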

View File

@ -35,7 +35,7 @@ const BaseRef BatchNormReluFusion::DefinePattern() const {
}
const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

View File

@ -36,7 +36,7 @@ const BaseRef BatchNormReluGradFusion::DefinePattern() const {
}
const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto is_train = AnfAlgo::GetCNodePrimitive(node)->GetAttr("is_training");

View File

@ -23,12 +23,14 @@ namespace mindspore {
namespace opt {
namespace {
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
MS_EXCEPTION_IF_NULL(deal_list);
std::vector<AnfNodePtr> cast_32to16_list;
std::vector<AnfNodePtr> cast_16to32_list;
AnfNodePtr cast_32to16_load_monad = nullptr;
AnfNodePtr cast_16to32_load_monad = nullptr;
constexpr size_t second_input_index = 2;
for (auto &cast_node : node_list) {
MS_EXCEPTION_IF_NULL(cast_node);
// currently, we only deal with the construct : [Param->Cast->] to avoid being a cycle.
// { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }}
if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
@ -88,43 +90,51 @@ bool CastAllFusion::Run(const FuncGraphPtr &graph) {
auto prim = std::make_shared<Primitive>("CastAll");
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
// set inputs for CastAll
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_list[idx]), 0));
}
TraceGuard guard(std::make_shared<TraceOpt>(cast_list[0]->debug_info()));
auto cast_all = graph->NewCNode(inputs);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
cast_all->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(cast_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
abstract_list.push_back(cnode->abstract());
inputs.push_back(AnfAlgo::GetInputNode(cnode, 0));
}
auto kernel_build_info = GenerateKernelBuildInfo(cast_list);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cast_all.get());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
cast_all->set_abstract(abstract_tuple);
AnfAlgo::SetNodeAttr("n", MakeValue(cast_list.size()), cast_all);
// 3 replace all the cast by CastAllv tuplegetitem[castall, idx]
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
std::vector<AnfNodePtr> tuple_getitem_input;
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
tuple_getitem_input.push_back(cast_all);
auto index = NewValueNode(SizeToLong(idx));
auto imm = std::make_shared<Int64Imm>(idx);
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
MS_EXCEPTION_IF_NULL(abstract_scalar);
index->set_abstract(abstract_scalar);
tuple_getitem_input.push_back(index);
AnfNodePtr tuple_getitem = graph->NewCNode(tuple_getitem_input);
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_abstract(cast_list[idx]->abstract());
if (!manager->Replace(cast_list[idx], tuple_getitem)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
if (cast_list.size() > 0) {
TraceGuard guard(std::make_shared<TraceOpt>(cast_list[0]->debug_info()));
auto cast_all = graph->NewCNode(inputs);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(cast_all);
MS_EXCEPTION_IF_NULL(kernel_info);
cast_all->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(cast_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
abstract_list.push_back(cnode->abstract());
}
auto kernel_build_info = GenerateKernelBuildInfo(cast_list);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cast_all.get());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
cast_all->set_abstract(abstract_tuple);
AnfAlgo::SetNodeAttr("n", MakeValue(cast_list.size()), cast_all);
// 3 replace all the cast by CastAllv tuplegetitem[castall, idx]
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
std::vector<AnfNodePtr> tuple_getitem_input;
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
tuple_getitem_input.push_back(cast_all);
auto index = NewValueNode(SizeToLong(idx));
auto imm = std::make_shared<Int64Imm>(idx);
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
MS_EXCEPTION_IF_NULL(index);
MS_EXCEPTION_IF_NULL(abstract_scalar);
index->set_abstract(abstract_scalar);
tuple_getitem_input.push_back(index);
AnfNodePtr tuple_getitem = graph->NewCNode(tuple_getitem_input);
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_abstract(cast_list[idx]->abstract());
if (!manager->Replace(cast_list[idx], tuple_getitem)) {
MS_LOG(EXCEPTION) << "manager replace node failed";
}
}
} else {
MS_LOG(EXCEPTION) << "The size of cast_list is zero.";
}
}
return true;
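
The rewritten block above also fixes an ordering problem: cast_list[0]->debug_info() was read, and the fused CastAll node built, before anything established that cast_list holds at least one element, so that work now happens inside an explicit size check and an empty list raises instead of indexing out of bounds. A minimal sketch of guarding first-element access; FuseNames is an invented helper for illustration:

#include <stdexcept>
#include <string>
#include <vector>

// Concatenates the names into one fused label. Reading names[0] is only legal
// once the vector is known to be non-empty, so that is checked explicitly.
std::string FuseNames(const std::vector<std::string> &names) {
  if (names.empty()) {
    throw std::runtime_error("The size of names is zero.");
  }
  std::string fused = names[0];
  for (size_t i = 1; i < names.size(); ++i) {
    fused += "_" + names[i];
  }
  return fused;
}

int main() { return FuseNames({"cast_a", "cast_b"}) == "cast_a_cast_b" ? 0 : 1; }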

View File

@ -22,6 +22,7 @@
namespace mindspore {
namespace opt {
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
MS_EXCEPTION_IF_NULL(deal_list);
std::vector<AnfNodePtr> momentum;
std::vector<AnfNodePtr> momentum_decay;
for (auto &momentum_node : node_list) {
@ -55,6 +56,9 @@ bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) {
return false;
}
for (auto momentums : deal_list) {
if (momentums.size() == 0) {
MS_LOG(EXCEPTION) << "The size of momentums is zero.";
}
// 2 create node momentum
std::vector<AnfNodePtr> inputs = {};
if (AnfAlgo::GetCNodeName(momentums[0]) == kFusedScaleApplyMomentum) {
@ -70,12 +74,15 @@ bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) {
size_t input_num = AnfAlgo::GetInputTensorNum(momentums[0]);
for (auto mom : momentums) {
for (size_t i = 0; i < input_num; i++) {
inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(mom), i));
auto cnode = utils::cast<CNodePtr>(mom);
MS_EXCEPTION_IF_NULL(cnode);
inputs.push_back(AnfAlgo::GetInputNode(cnode, i));
}
}
TraceGuard guard(std::make_shared<TraceOpt>(momentums[0]->debug_info()));
auto combine_mom = graph->NewCNode(inputs);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(combine_mom);
MS_EXCEPTION_IF_NULL(kernel_info);
combine_mom->set_kernel_info(kernel_info);
AbstractBasePtrList abstract_list;

View File

@ -123,13 +123,19 @@ std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_n
}
auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
MS_EXCEPTION_IF_NULL(first_node_prim);
auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
MS_EXCEPTION_IF_NULL(first_node_channel);
size_t first_channel = first_node_channel->cast<Int64ImmPtr>()->value();
auto first_imm_ptr = first_node_channel->cast<Int64ImmPtr>();
MS_EXCEPTION_IF_NULL(first_imm_ptr);
size_t first_channel = first_imm_ptr->value();
auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
MS_EXCEPTION_IF_NULL(second_node_prim);
auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
MS_EXCEPTION_IF_NULL(second_node_channel);
size_t second_channel = second_node_channel->cast<Int64ImmPtr>()->value();
auto second_imm_ptr = second_node_channel->cast<Int64ImmPtr>();
MS_EXCEPTION_IF_NULL(second_imm_ptr);
size_t second_channel = second_imm_ptr->value();
size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
bool ret = ExistDependencyFromAcc2Cover(inplace_node, cover_index);
if (ret) {
@ -165,6 +171,8 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
// | | | | |
// Cover Acc | Acc |
// Cover---------------+
MS_EXCEPTION_IF_NULL(inplace_node);
MS_EXCEPTION_IF_NULL(graph);
size_t acc_index = cover_index == 1 ? 0 : 1;
const CNodePtr &cover_node = inplace_node->at(cover_index).node->cast<CNodePtr>();
const CNodePtr &acc_node = inplace_node->at(acc_index).node->cast<CNodePtr>();
@ -177,9 +185,11 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
bool ret = ExistRoute(acc_input, cover_node);
if (ret) {
auto new_input = graph->NewCNode(acc_input->inputs());
MS_EXCEPTION_IF_NULL(new_input);
new_input->set_abstract(acc_input->abstract());
CopyKernelInfo(acc_input, new_input);
auto new_inplace_node = graph->NewCNode({acc_node->input(0), new_input, acc_node->input(2)});
MS_EXCEPTION_IF_NULL(new_inplace_node);
new_inplace_node->set_abstract(acc_node->abstract());
CopyKernelInfo(acc_node, new_inplace_node);
auto manager = graph->manager();
@ -191,6 +201,10 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(skip_node);
MS_EXCEPTION_IF_NULL(inplace_node);
MS_EXCEPTION_IF_NULL(graph);
SetPrimAttr(aggregate_node.node, "aggregate", true);
SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
SetPrimAttr(skip_node, "skip", true);
@ -202,6 +216,7 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
for (size_t i = 0; i < inplace_node->size(); i++) {
auto algo = (i == cover_index) ? "cover" : "accumulation";
auto node = (*inplace_node)[i].node;
MS_EXCEPTION_IF_NULL(node);
SetPrimAttr(node, "inplace_algo", algo);
SetPrimAttr(node, "inplace_group", group);
SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
@ -209,10 +224,13 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
if (order_required && i != cover_index) {
auto acc_node = node;
auto cover_node = (*inplace_node)[cover_index].node;
auto acc_node_input = acc_node->cast<CNodePtr>()->input(1);
auto cnode = acc_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto acc_node_input = cnode->input(1);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
acc_node_input, cover_node};
auto depend_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(depend_node);
depend_node->set_abstract(acc_node_input->abstract());
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -224,6 +242,9 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
std::vector<AnfNodeIndex> *inplace) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inplace);
MS_EXCEPTION_IF_NULL(skip_node);
MS_EXCEPTION_IF_NULL(aggregate);
if (!node->isa<CNode>()) {
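
GetCoverIndex reads the out_channel attribute of both convolutions and previously chained GetAttr, the cast to Int64ImmPtr and value() in one expression; GetAttr can return null when the attribute is absent and the cast can return null when the value is not an int64, so both lookups are now checked before value() is called. A self-contained sketch with a plain map standing in for the primitive's attribute table (Value, Int64Value and GetInt64Attr are illustrative names, not MindSpore APIs):

#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Value { virtual ~Value() = default; };
struct Int64Value : Value {
  explicit Int64Value(int64_t v) : value_(v) {}
  int64_t value() const { return value_; }
  int64_t value_;
};

using AttrMap = std::map<std::string, std::shared_ptr<Value>>;

int64_t GetInt64Attr(const AttrMap &attrs, const std::string &name) {
  auto it = attrs.find(name);
  if (it == attrs.end() || it->second == nullptr) {
    throw std::runtime_error("missing attribute: " + name);       // the lookup gave us nothing
  }
  auto imm = std::dynamic_pointer_cast<Int64Value>(it->second);
  if (imm == nullptr) {
    throw std::runtime_error("attribute is not int64: " + name);  // wrong value type
  }
  return imm->value();
}

int main() {
  AttrMap attrs{{"out_channel", std::make_shared<Int64Value>(64)}};
  return GetInt64Attr(attrs, "out_channel") == 64 ? 0 : 1;
}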

View File

@ -77,9 +77,12 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
// Insert transpose op between node and used_node whose position is used_node_index.
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int64_t> &transpose_perm) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(used_node);
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
<< ", index: " << used_node_index;
MS_EXCEPTION_IF_NULL(graph);
// 0.Judge whether it is a fake transpose
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
@ -94,6 +97,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
// 2.Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
auto transpose_op = graph->NewCNode(transpose_input);
MS_EXCEPTION_IF_NULL(transpose_op);
// 3.Set the output info of transpose.
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
@ -144,6 +148,7 @@ const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, con
if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) {
// Find all nodes connected with node output, and change their inputs to transpose.
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
MS_EXCEPTION_IF_NULL(used_node_list);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
auto used_node_index = used_node_list->at(j).second - 1;
@ -166,6 +171,7 @@ void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, con
const std::vector<int64_t> &transpose_perm,
const std::string &transpose_format) const {
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
MS_EXCEPTION_IF_NULL(used_node_list);
for (size_t i = 0; i < used_node_list->size(); i++) {
auto used_node = used_node_list->at(i).first;
auto used_node_index = used_node_list->at(i).second - 1;
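
GetRealNodeUsedListByOutputIdx hands back a shared pointer to the list of consuming nodes, and the passes in this file (and in ReducePrecisionFusion and RemoveRedundantFormatTransform below) now verify that pointer before calling size() or at() on it. A short sketch of consuming such a possibly-null result; FindUsers, Edge and CountUsers are invented for the illustration:

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

using Edge = std::pair<int, int>;  // (user node id, input index)

// Stand-in for a lookup that returns nullptr when the node has no entry.
std::shared_ptr<std::vector<Edge>> FindUsers(int node_id) {
  if (node_id < 0) return nullptr;
  auto users = std::make_shared<std::vector<Edge>>();
  users->push_back({node_id + 1, 0});
  return users;
}

size_t CountUsers(int node_id) {
  auto users = FindUsers(node_id);
  if (users == nullptr) {                       // check before ->size()
    throw std::runtime_error("user list is null");
  }
  return users->size();
}

int main() { return CountUsers(3) == 1 ? 0 : 1; }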

View File

@ -36,7 +36,7 @@ const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
}
const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

View File

@ -55,8 +55,13 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
std::vector<std::vector<int64_t>> *string_pos_vec,
std::vector<std::vector<std::string>> *string_value_vec,
std::vector<std::vector<std::pair<int64_t, int64_t>>> *not_tensor_pos_vec) {
MS_EXCEPTION_IF_NULL(opt_list);
MS_EXCEPTION_IF_NULL(string_pos_vec);
MS_EXCEPTION_IF_NULL(string_value_vec);
for (auto &node : node_list) {
// {prim::kPrimPrint} reduction only applies on print with string, tensor(scalar or tuple)
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> string_pos;
std::vector<std::string> string_value;
std::vector<std::pair<int64_t, int64_t>> value_type;
@ -69,7 +74,10 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
continue;
}
auto value_node = current_node->cast<ValueNodePtr>();
auto shape_node = dyn_cast<abstract::Shape>(value_node->abstract()->GetShapeTrack());
MS_EXCEPTION_IF_NULL(value_node);
auto shape = value_node->abstract();
MS_EXCEPTION_IF_NULL(shape);
auto shape_node = dyn_cast<abstract::Shape>(shape->GetShapeTrack());
if (shape_node != nullptr) {
// a scalar or tuple
auto shape_size = shape_node->shape().size();
@ -84,7 +92,9 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
// not a string
continue;
}
if (node_value->type()->generic_type_id() == kObjectTypeString) {
auto type = node_value->type();
MS_EXCEPTION_IF_NULL(type);
if (type->generic_type_id() == kObjectTypeString) {
auto current_string_value = GetValue<std::string>(node_value);
string_pos.push_back(i);
string_value.push_back(std::string(current_string_value));
@ -122,6 +132,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
for (size_t idx = 0; idx < opt_list.size(); idx++) {
auto node = opt_list[idx];
CNodePtr cnode = utils::cast<CNodePtr>(node);
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
auto prim = std::make_shared<Primitive>("Print");
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
@ -157,6 +168,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
[](const std::pair<int64_t, int64_t> &value) { return value.second; });
// create new cnode
auto print_fused = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(print_fused);
// hand over the attrs to new print
AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused);
AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused);

View File

@ -29,10 +29,13 @@ namespace opt {
namespace {
void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, const TypeId &src_type,
const TypeId &cast_type) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
auto cast = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cast);
auto cast_shape = {AnfAlgo::GetInputDeviceShape(node, i)};
AnfAlgo::SetOutputInferTypeAndShape({cast_type}, cast_shape, cast.get());
FuncGraphManagerPtr manager = graph->manager();
@ -49,7 +52,10 @@ void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i
}
void ProcessTupleGetItem(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t node_index, const TypeId &src_type,
const TypeId &cast_type) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
MS_EXCEPTION_IF_NULL(used_node_list);
for (size_t i = 0; i < used_node_list->size(); i++) {
auto used_node = used_node_list->at(i).first;
auto used_node_index = used_node_list->at(i).second - 1;
@ -64,6 +70,7 @@ bool ReducePrecisionFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
for (auto node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
@ -83,6 +90,7 @@ bool ReducePrecisionFusion::Run(const FuncGraphPtr &graph) {
continue;
}
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
MS_EXCEPTION_IF_NULL(used_node_list);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
auto used_node_index = used_node_list->at(j).second - 1;

View File

@ -40,6 +40,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
}
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_type;

View File

@ -34,10 +34,9 @@ const BaseRef RemoveFormatTransformPair::DefinePattern() const {
}
const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(input_node);

View File

@ -32,15 +32,15 @@ const BaseRef RemoveRedundantFormatTransform::DefinePattern() const {
}
const AnfNodePtr RemoveRedundantFormatTransform::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(input_node);
AnfNodePtr first_transpose = nullptr;
auto used_node_list = GetRealNodeUsedList(graph, input_node);
MS_EXCEPTION_IF_NULL(used_node_list);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTranspose->name()) {

View File

@ -26,11 +26,9 @@ const BaseRef ReplaceAddNFusion::DefinePattern() const {
return addn;
}
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto A = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto B = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
MS_EXCEPTION_IF_NULL(A);
@ -41,6 +39,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
auto add_new = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add_new);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(A, 0));

View File

@ -28,10 +28,9 @@ const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
}
const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradIndex);
MS_EXCEPTION_IF_NULL(grad_cast);