forked from mindspore-Ecosystem/mindspore
!20899 Code review fix for GPU optimizer
Merge pull request !20899 from chengang/codefix
This commit is contained in:
commit
cae7f291c4
|
@ -34,6 +34,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
|
@ -56,7 +57,9 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
|
|||
// the execution order of FusedAdam and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
const size_t assign_index = 2;
|
||||
const auto &n = node->cast<CNodePtr>()->input(assign_index);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &n = cnode->input(assign_index);
|
||||
MS_EXCEPTION_IF_NULL(n);
|
||||
const auto &fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
@ -73,8 +76,10 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
|
|||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
const size_t monad_index = 1;
|
||||
const size_t adam_index = 2;
|
||||
(user.first)->cast<CNodePtr>()->set_input(monad_index, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(adam_index, adam);
|
||||
auto cnode_ptr = (user.first)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
cnode_ptr->set_input(monad_index, u_input);
|
||||
cnode_ptr->set_input(adam_index, adam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
|
@ -56,7 +57,8 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
|
|||
// the execution order of FusedAdamWeightDecay and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
const size_t assign_index = 2;
|
||||
const auto &n = node->cast<CNodePtr>()->input(assign_index);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
const auto &n = cnode->input(assign_index);
|
||||
MS_EXCEPTION_IF_NULL(n);
|
||||
const auto &fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
@ -73,8 +75,10 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
|
|||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
const size_t monad_index = 1;
|
||||
const size_t adam_weight_decay_index = 2;
|
||||
(user.first)->cast<CNodePtr>()->set_input(monad_index, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(adam_weight_decay_index, adam_weight_decay);
|
||||
auto cnode_ptr = (user.first)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
cnode_ptr->set_input(monad_index, u_input);
|
||||
cnode_ptr->set_input(adam_weight_decay_index, adam_weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,6 +65,9 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
|
|||
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
|
||||
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
|
||||
auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
|
||||
MS_EXCEPTION_IF_NULL(x1);
|
||||
MS_EXCEPTION_IF_NULL(x2);
|
||||
MS_EXCEPTION_IF_NULL(mask);
|
||||
|
||||
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(tensor_add);
|
||||
|
|
|
@ -64,6 +64,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
|
|||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
|
||||
auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
|
||||
MS_EXCEPTION_IF_NULL(x1);
|
||||
MS_EXCEPTION_IF_NULL(x2);
|
||||
|
||||
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(tensor_add);
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
*/
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -26,7 +30,9 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
|
|||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(in);
|
||||
auto shape = in->Shape()->cast<abstract::ShapePtr>();
|
||||
auto shape_ptr = in->Shape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto shape = shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->shape().size() != 0) {
|
||||
return false;
|
||||
|
@ -35,7 +41,11 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
|
|||
if (dtype->type_id() != kObjectTypeTensorType) {
|
||||
return false;
|
||||
}
|
||||
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||
auto type_ptr = dyn_cast<TensorType>(dtype);
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto element = type_ptr->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
auto element_type = element->type_id();
|
||||
if (element_type != kNumberTypeFloat32) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -22,28 +24,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(in);
|
||||
auto shape = in->Shape()->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->shape().size() != 0) {
|
||||
return false;
|
||||
}
|
||||
auto dtype = in->Type();
|
||||
if (dtype->type_id() != kObjectTypeTensorType) {
|
||||
return false;
|
||||
}
|
||||
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||
if (element_type != kNumberTypeFloat32) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
|
||||
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
|
||||
VectorRef weight_decay =
|
||||
|
|
|
@ -39,8 +39,6 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
|
|||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
static bool IsScalar(const BaseRef &n);
|
||||
|
||||
VarPtr monad_;
|
||||
VarPtr weight_decay_;
|
||||
VarPtr variable_;
|
||||
|
|
|
@ -32,16 +32,23 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
|
|||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(in);
|
||||
auto shape = in->Shape()->cast<abstract::ShapePtr>();
|
||||
auto shape_ptr = in->Shape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto shape = shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->shape().size() != 0) {
|
||||
return false;
|
||||
}
|
||||
auto dtype = in->Type();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
if (dtype->type_id() != kObjectTypeTensorType) {
|
||||
return false;
|
||||
}
|
||||
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||
auto type_ptr = dyn_cast<TensorType>(dtype);
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto element = type_ptr->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
auto element_type = element->type_id();
|
||||
if (element_type != kNumberTypeFloat32) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ const BaseRef BatchNormAddReluFusion::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ constexpr size_t kBNAddReluGradOutputNum = 4;
|
|||
|
||||
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(bn);
|
||||
MS_EXCEPTION_IF_NULL(bn_outputs);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -121,7 +122,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
|
||||
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
||||
if ((shape.back() % kBNChannelMultipleFactor) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -188,7 +189,6 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
|||
if (!GetValue<bool>(is_train)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
|
||||
|
|
|
@ -35,7 +35,7 @@ const BaseRef BatchNormReluFusion::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ const BaseRef BatchNormReluGradFusion::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto is_train = AnfAlgo::GetCNodePrimitive(node)->GetAttr("is_training");
|
||||
|
|
|
@ -23,12 +23,14 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
|
||||
MS_EXCEPTION_IF_NULL(deal_list);
|
||||
std::vector<AnfNodePtr> cast_32to16_list;
|
||||
std::vector<AnfNodePtr> cast_16to32_list;
|
||||
AnfNodePtr cast_32to16_load_monad = nullptr;
|
||||
AnfNodePtr cast_16to32_load_monad = nullptr;
|
||||
constexpr size_t second_input_index = 2;
|
||||
for (auto &cast_node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(cast_node);
|
||||
// currently, we only deal with the construct : [Param->Cast->] to avoid being a cycle.
|
||||
// { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }}
|
||||
if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
|
||||
|
@ -88,43 +90,51 @@ bool CastAllFusion::Run(const FuncGraphPtr &graph) {
|
|||
auto prim = std::make_shared<Primitive>("CastAll");
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
// set inputs for CastAll
|
||||
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
|
||||
inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_list[idx]), 0));
|
||||
}
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(cast_list[0]->debug_info()));
|
||||
auto cast_all = graph->NewCNode(inputs);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
cast_all->set_kernel_info(kernel_info);
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
|
||||
auto cnode = utils::cast<CNodePtr>(cast_list[idx]);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
inputs.push_back(AnfAlgo::GetInputNode(cnode, 0));
|
||||
}
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(cast_list);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cast_all.get());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
cast_all->set_abstract(abstract_tuple);
|
||||
AnfAlgo::SetNodeAttr("n", MakeValue(cast_list.size()), cast_all);
|
||||
// 3 replace all the cast by CastAllv tuplegetitem[castall, idx]
|
||||
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_input;
|
||||
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
tuple_getitem_input.push_back(cast_all);
|
||||
auto index = NewValueNode(SizeToLong(idx));
|
||||
auto imm = std::make_shared<Int64Imm>(idx);
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
MS_EXCEPTION_IF_NULL(abstract_scalar);
|
||||
index->set_abstract(abstract_scalar);
|
||||
tuple_getitem_input.push_back(index);
|
||||
AnfNodePtr tuple_getitem = graph->NewCNode(tuple_getitem_input);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
tuple_getitem->set_abstract(cast_list[idx]->abstract());
|
||||
if (!manager->Replace(cast_list[idx], tuple_getitem)) {
|
||||
MS_LOG(EXCEPTION) << "manager replace node failed";
|
||||
if (cast_list.size() > 0) {
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(cast_list[0]->debug_info()));
|
||||
auto cast_all = graph->NewCNode(inputs);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(cast_all);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
cast_all->set_kernel_info(kernel_info);
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
|
||||
auto cnode = utils::cast<CNodePtr>(cast_list[idx]);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
}
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(cast_list);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cast_all.get());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
cast_all->set_abstract(abstract_tuple);
|
||||
AnfAlgo::SetNodeAttr("n", MakeValue(cast_list.size()), cast_all);
|
||||
// 3 replace all the cast by CastAllv tuplegetitem[castall, idx]
|
||||
for (size_t idx = 0; idx < cast_list.size(); ++idx) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_input;
|
||||
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
tuple_getitem_input.push_back(cast_all);
|
||||
auto index = NewValueNode(SizeToLong(idx));
|
||||
auto imm = std::make_shared<Int64Imm>(idx);
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
MS_EXCEPTION_IF_NULL(index);
|
||||
MS_EXCEPTION_IF_NULL(abstract_scalar);
|
||||
index->set_abstract(abstract_scalar);
|
||||
tuple_getitem_input.push_back(index);
|
||||
AnfNodePtr tuple_getitem = graph->NewCNode(tuple_getitem_input);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
tuple_getitem->set_abstract(cast_list[idx]->abstract());
|
||||
if (!manager->Replace(cast_list[idx], tuple_getitem)) {
|
||||
MS_LOG(EXCEPTION) << "manager replace node failed";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The size of cast_list is zero.";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
|
||||
MS_EXCEPTION_IF_NULL(deal_list);
|
||||
std::vector<AnfNodePtr> momentum;
|
||||
std::vector<AnfNodePtr> momentum_decay;
|
||||
for (auto &momentum_node : node_list) {
|
||||
|
@ -55,6 +56,9 @@ bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) {
|
|||
return false;
|
||||
}
|
||||
for (auto momentums : deal_list) {
|
||||
if (momentums.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The size of momentums is zero.";
|
||||
}
|
||||
// 2 create node momentum
|
||||
std::vector<AnfNodePtr> inputs = {};
|
||||
if (AnfAlgo::GetCNodeName(momentums[0]) == kFusedScaleApplyMomentum) {
|
||||
|
@ -70,12 +74,15 @@ bool CombineMomentumFusion::Run(const FuncGraphPtr &graph) {
|
|||
size_t input_num = AnfAlgo::GetInputTensorNum(momentums[0]);
|
||||
for (auto mom : momentums) {
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(mom), i));
|
||||
auto cnode = utils::cast<CNodePtr>(mom);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
inputs.push_back(AnfAlgo::GetInputNode(cnode, i));
|
||||
}
|
||||
}
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(momentums[0]->debug_info()));
|
||||
auto combine_mom = graph->NewCNode(inputs);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(combine_mom);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
combine_mom->set_kernel_info(kernel_info);
|
||||
AbstractBasePtrList abstract_list;
|
||||
|
|
|
@ -123,13 +123,19 @@ std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_n
|
|||
}
|
||||
|
||||
auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
|
||||
MS_EXCEPTION_IF_NULL(first_node_prim);
|
||||
auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
|
||||
MS_EXCEPTION_IF_NULL(first_node_channel);
|
||||
size_t first_channel = first_node_channel->cast<Int64ImmPtr>()->value();
|
||||
auto first_imm_ptr = first_node_channel->cast<Int64ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(first_imm_ptr);
|
||||
size_t first_channel = first_imm_ptr->value();
|
||||
auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node_prim);
|
||||
auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
|
||||
MS_EXCEPTION_IF_NULL(second_node_channel);
|
||||
size_t second_channel = second_node_channel->cast<Int64ImmPtr>()->value();
|
||||
auto second_imm_ptr = second_node_channel->cast<Int64ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(second_imm_ptr);
|
||||
size_t second_channel = second_imm_ptr->value();
|
||||
size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
|
||||
bool ret = ExistDependencyFromAcc2Cover(inplace_node, cover_index);
|
||||
if (ret) {
|
||||
|
@ -165,6 +171,8 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
|
|||
// | | | | |
|
||||
// Cover Acc | Acc |
|
||||
// Cover---------------+
|
||||
MS_EXCEPTION_IF_NULL(inplace_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
size_t acc_index = cover_index == 1 ? 0 : 1;
|
||||
const CNodePtr &cover_node = inplace_node->at(cover_index).node->cast<CNodePtr>();
|
||||
const CNodePtr &acc_node = inplace_node->at(acc_index).node->cast<CNodePtr>();
|
||||
|
@ -177,9 +185,11 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
|
|||
bool ret = ExistRoute(acc_input, cover_node);
|
||||
if (ret) {
|
||||
auto new_input = graph->NewCNode(acc_input->inputs());
|
||||
MS_EXCEPTION_IF_NULL(new_input);
|
||||
new_input->set_abstract(acc_input->abstract());
|
||||
CopyKernelInfo(acc_input, new_input);
|
||||
auto new_inplace_node = graph->NewCNode({acc_node->input(0), new_input, acc_node->input(2)});
|
||||
MS_EXCEPTION_IF_NULL(new_inplace_node);
|
||||
new_inplace_node->set_abstract(acc_node->abstract());
|
||||
CopyKernelInfo(acc_node, new_inplace_node);
|
||||
auto manager = graph->manager();
|
||||
|
@ -191,6 +201,10 @@ void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, size_t cove
|
|||
|
||||
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
|
||||
const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(skip_node);
|
||||
MS_EXCEPTION_IF_NULL(inplace_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
SetPrimAttr(aggregate_node.node, "aggregate", true);
|
||||
SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
|
||||
SetPrimAttr(skip_node, "skip", true);
|
||||
|
@ -202,6 +216,7 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
|
|||
for (size_t i = 0; i < inplace_node->size(); i++) {
|
||||
auto algo = (i == cover_index) ? "cover" : "accumulation";
|
||||
auto node = (*inplace_node)[i].node;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
SetPrimAttr(node, "inplace_algo", algo);
|
||||
SetPrimAttr(node, "inplace_group", group);
|
||||
SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
|
||||
|
@ -209,10 +224,13 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
|
|||
if (order_required && i != cover_index) {
|
||||
auto acc_node = node;
|
||||
auto cover_node = (*inplace_node)[cover_index].node;
|
||||
auto acc_node_input = acc_node->cast<CNodePtr>()->input(1);
|
||||
auto cnode = acc_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto acc_node_input = cnode->input(1);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
acc_node_input, cover_node};
|
||||
auto depend_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
depend_node->set_abstract(acc_node_input->abstract());
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -224,6 +242,9 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
|
|||
|
||||
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
|
||||
std::vector<AnfNodeIndex> *inplace) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inplace);
|
||||
MS_EXCEPTION_IF_NULL(skip_node);
|
||||
MS_EXCEPTION_IF_NULL(aggregate);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
|
|
@ -77,9 +77,12 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
|
|||
// Insert transpose op between node and used_node whose position is used_node_index.
|
||||
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
||||
int used_node_index, const std::vector<int64_t> &transpose_perm) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(used_node);
|
||||
|
||||
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
||||
<< ", index: " << used_node_index;
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// 0.Judge whether it is a fake transpose
|
||||
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
|
||||
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
|
||||
|
@ -94,6 +97,7 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
|
|||
// 2.Set the input of transpose.
|
||||
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
|
||||
auto transpose_op = graph->NewCNode(transpose_input);
|
||||
MS_EXCEPTION_IF_NULL(transpose_op);
|
||||
// 3.Set the output info of transpose.
|
||||
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
|
||||
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
|
||||
|
@ -144,6 +148,7 @@ const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, con
|
|||
if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) {
|
||||
// Find all nodes connected with node output, and change their inputs to transpose.
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
|
||||
MS_EXCEPTION_IF_NULL(used_node_list);
|
||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||
auto used_node = used_node_list->at(j).first;
|
||||
auto used_node_index = used_node_list->at(j).second - 1;
|
||||
|
@ -166,6 +171,7 @@ void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, con
|
|||
const std::vector<int64_t> &transpose_perm,
|
||||
const std::string &transpose_format) const {
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
|
||||
MS_EXCEPTION_IF_NULL(used_node_list);
|
||||
for (size_t i = 0; i < used_node_list->size(); i++) {
|
||||
auto used_node = used_node_list->at(i).first;
|
||||
auto used_node_index = used_node_list->at(i).second - 1;
|
||||
|
|
|
@ -36,7 +36,7 @@ const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
|
|
|
@ -55,8 +55,13 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
|
|||
std::vector<std::vector<int64_t>> *string_pos_vec,
|
||||
std::vector<std::vector<std::string>> *string_value_vec,
|
||||
std::vector<std::vector<std::pair<int64_t, int64_t>>> *not_tensor_pos_vec) {
|
||||
MS_EXCEPTION_IF_NULL(opt_list);
|
||||
MS_EXCEPTION_IF_NULL(string_pos_vec);
|
||||
MS_EXCEPTION_IF_NULL(string_value_vec);
|
||||
|
||||
for (auto &node : node_list) {
|
||||
// {prim::kPrimPrint} reduction only applies on print with string, tensor(scalar or tuple)
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<int64_t> string_pos;
|
||||
std::vector<std::string> string_value;
|
||||
std::vector<std::pair<int64_t, int64_t>> value_type;
|
||||
|
@ -69,7 +74,10 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
|
|||
continue;
|
||||
}
|
||||
auto value_node = current_node->cast<ValueNodePtr>();
|
||||
auto shape_node = dyn_cast<abstract::Shape>(value_node->abstract()->GetShapeTrack());
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto shape = value_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto shape_node = dyn_cast<abstract::Shape>(shape->GetShapeTrack());
|
||||
if (shape_node != nullptr) {
|
||||
// a scalar or tuple
|
||||
auto shape_size = shape_node->shape().size();
|
||||
|
@ -84,7 +92,9 @@ bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr
|
|||
// not a string
|
||||
continue;
|
||||
}
|
||||
if (node_value->type()->generic_type_id() == kObjectTypeString) {
|
||||
auto type = node_value->type();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->generic_type_id() == kObjectTypeString) {
|
||||
auto current_string_value = GetValue<std::string>(node_value);
|
||||
string_pos.push_back(i);
|
||||
string_value.push_back(std::string(current_string_value));
|
||||
|
@ -122,6 +132,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
|
|||
for (size_t idx = 0; idx < opt_list.size(); idx++) {
|
||||
auto node = opt_list[idx];
|
||||
CNodePtr cnode = utils::cast<CNodePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
auto prim = std::make_shared<Primitive>("Print");
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
|
@ -157,6 +168,7 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
|
|||
[](const std::pair<int64_t, int64_t> &value) { return value.second; });
|
||||
// create new cnode
|
||||
auto print_fused = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(print_fused);
|
||||
// hand over the attrs to new print
|
||||
AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused);
|
||||
AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused);
|
||||
|
|
|
@ -29,10 +29,13 @@ namespace opt {
|
|||
namespace {
|
||||
void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i, const TypeId &src_type,
|
||||
const TypeId &cast_type) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
|
||||
auto cast = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
auto cast_shape = {AnfAlgo::GetInputDeviceShape(node, i)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape({cast_type}, cast_shape, cast.get());
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
|
@ -49,7 +52,10 @@ void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i
|
|||
}
|
||||
void ProcessTupleGetItem(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t node_index, const TypeId &src_type,
|
||||
const TypeId &cast_type) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
|
||||
MS_EXCEPTION_IF_NULL(used_node_list);
|
||||
for (size_t i = 0; i < used_node_list->size(); i++) {
|
||||
auto used_node = used_node_list->at(i).first;
|
||||
auto used_node_index = used_node_list->at(i).second - 1;
|
||||
|
@ -64,6 +70,7 @@ bool ReducePrecisionFusion::Run(const FuncGraphPtr &graph) {
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
|
||||
for (auto node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
|
@ -83,6 +90,7 @@ bool ReducePrecisionFusion::Run(const FuncGraphPtr &graph) {
|
|||
continue;
|
||||
}
|
||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
|
||||
MS_EXCEPTION_IF_NULL(used_node_list);
|
||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||
auto used_node = used_node_list->at(j).first;
|
||||
auto used_node_index = used_node_list->at(j).second - 1;
|
||||
|
|
|
@ -40,6 +40,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
|
|||
}
|
||||
|
||||
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
|
|
|
@ -34,10 +34,9 @@ const BaseRef RemoveFormatTransformPair::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
|
||||
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
|
|
@ -32,15 +32,15 @@ const BaseRef RemoveRedundantFormatTransform::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr RemoveRedundantFormatTransform::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
|
||||
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
AnfNodePtr first_transpose = nullptr;
|
||||
auto used_node_list = GetRealNodeUsedList(graph, input_node);
|
||||
MS_EXCEPTION_IF_NULL(used_node_list);
|
||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||
auto used_node = used_node_list->at(j).first;
|
||||
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTranspose->name()) {
|
||||
|
|
|
@ -26,11 +26,9 @@ const BaseRef ReplaceAddNFusion::DefinePattern() const {
|
|||
return addn;
|
||||
}
|
||||
|
||||
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto A = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
auto B = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
|
||||
MS_EXCEPTION_IF_NULL(A);
|
||||
|
@ -41,6 +39,7 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), A, B};
|
||||
auto add_new = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(add_new);
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(A, 0));
|
||||
|
|
|
@ -28,10 +28,9 @@ const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradIndex);
|
||||
MS_EXCEPTION_IF_NULL(grad_cast);
|
||||
|
|
Loading…
Reference in New Issue