!18148 format codes
Merge pull request !18148 from kisnwang/master-codefix
commit 0f1f20165e

@@ -29,6 +29,7 @@ constexpr size_t kBatchNormRealOutputNum = 3;
 constexpr size_t kBatchNormRealInputNum = 3;
 constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
+constexpr size_t kInputIndex4 = 4;
 
 bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -117,8 +118,9 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
     MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
                       << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
   }
-  std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3],
-                                             bn_abstract_tuple->elements()[4]};
+  std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0],
+                                             bn_abstract_tuple->elements()[kInputIndex3],
+                                             bn_abstract_tuple->elements()[kInputIndex4]};
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
   bn_training_update_v2->set_abstract(abstract_tuple);
   bn_training_update_v2->set_scope(bn->scope());
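
The two hunks above set the pattern for most of this change: bare element indices such as elements()[3] become named constexpr constants. A minimal standalone sketch of the idea, in plain C++ with hypothetical slot names (not the MindSpore API):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical output slots of a fused batch-norm style node.
constexpr size_t kOutputYIndex = 0;
constexpr size_t kInputIndex3 = 3;  // assumed: running-mean slot
constexpr size_t kInputIndex4 = 4;  // assumed: running-variance slot

int main() {
  std::vector<std::string> outputs = {"y", "batch_mean", "batch_var", "running_mean", "running_var"};
  // A named constant documents the slot's meaning at the use site and
  // gives maintainers one place to change if the layout ever moves.
  assert(outputs.size() > kInputIndex4);
  std::vector<std::string> picked = {outputs[kOutputYIndex], outputs[kInputIndex3], outputs[kInputIndex4]};
  return picked.size() == 3 ? 0 : 1;
}
```
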
@@ -47,7 +47,8 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   auto value_node = index_node->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(value_node);
   auto index = GetValue<int64_t>(value_node->value());
-  if (index == kBatchNormGradInferOutputNum || index == SizeToLong(kBatchNormGradInferOutputNum + 1)) {
+  auto output_num = SizeToLong(kBatchNormGradInferOutputNum);
+  if (index == output_num || index == output_num + 1) {
     MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change";
     return false;
   }
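
This hunk hoists the repeated SizeToLong conversion into a named local before the comparison. The same move in a self-contained sketch (ToLong is a stand-in for the real helper):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

constexpr size_t kOutputNum = 3;

// Stand-in for a SizeToLong-style conversion helper.
int64_t ToLong(size_t v) { return static_cast<int64_t>(v); }

bool IsExtraOutput(int64_t index) {
  // Before: the conversion appeared twice inside the condition.
  // After: compute it once, name it, and compare against the local.
  const int64_t output_num = ToLong(kOutputNum);
  return index == output_num || index == output_num + 1;
}

int main() {
  std::cout << std::boolalpha << IsExtraOutput(3) << ' ' << IsExtraOutput(5) << '\n';  // true false
  return 0;
}
```
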
@@ -46,15 +46,17 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
   AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat);
   // infer shape
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0);
-  auto axis = AnfAlgo::GetNodeAttr<int64_t>(origin_concat_cnode, kAttrAxis);
-  if (axis < 0) {
-    axis += SizeToLong(input_shape.size());
+  auto axis_from_attr = AnfAlgo::GetNodeAttr<int64_t>(origin_concat_cnode, kAttrAxis);
+  if (axis_from_attr < 0) {
+    axis_from_attr += SizeToLong(input_shape.size());
   }
   auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0);
-  if (axis < 0 || axis >= SizeToLong(output_shape.size()) || axis >= SizeToLong(input_shape.size())) {
-    MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"
+  if (axis_from_attr < 0 || axis_from_attr >= SizeToLong(output_shape.size()) ||
+      axis_from_attr >= SizeToLong(input_shape.size())) {
+    MS_LOG(EXCEPTION) << "The concat_dim value " << axis_from_attr << "is out of range"
                       << " trace: " << trace::DumpSourceLines(origin_concat_cnode);
   }
+  auto axis = LongToSize(axis_from_attr);
   output_shape[axis] = 0;
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
     input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, i - 1);
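
The concat fix keeps the signed attribute value (axis_from_attr) through normalization and range checks, and converts to an unsigned index only once the checks pass. A self-contained sketch of that ordering, with placeholder helper names:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Normalize a possibly-negative axis attribute, validate it, and only
// then convert to size_t; converting first would hide negative values.
size_t NormalizeAxis(int64_t axis_from_attr, size_t rank) {
  if (axis_from_attr < 0) {
    axis_from_attr += static_cast<int64_t>(rank);
  }
  if (axis_from_attr < 0 || axis_from_attr >= static_cast<int64_t>(rank)) {
    throw std::out_of_range("concat_dim is out of range");
  }
  return static_cast<size_t>(axis_from_attr);  // safe: checked non-negative above
}

int main() {
  std::vector<size_t> shape = {2, 3, 4};
  size_t axis = NormalizeAxis(-1, shape.size());  // -1 normalizes to 2
  shape[axis] = 0;
  return shape[2] == 0 ? 0 : 1;
}
```
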
@@ -163,7 +163,7 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &
   size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
   size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
   size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2] / kGateNum;
-  std::vector<size_t> shape = {t_size, batch, hidden_size + hidden_size};
+  std::vector<size_t> shape = {t_size, batch, hidden_size << 1};
   std::vector<size_t> shape2 = {t_size, batch, hidden_size};
   std::vector<std::vector<size_t>> shapes = {shape, shape2};
   AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
@@ -24,6 +24,10 @@
 namespace mindspore {
 namespace opt {
 namespace {
+constexpr size_t kOriginPaddingSize = 2;
+constexpr size_t kGatherInputNum = 4;
+constexpr size_t kGatherInputIndicesIndex = 2;
+constexpr size_t kGatherInputAxisIndex = 3;
 // only pad operator can run in dynamic shape.
 CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
   MS_EXCEPTION_IF_NULL(graph);
@@ -62,7 +66,7 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
 
   std::vector<ValuePtr> elements;
   for (size_t i = 0; i < shape.size() - 1; ++i) {
-    ShapeVector padding_vector(2);
+    ShapeVector padding_vector(kOriginPaddingSize);
     auto padding_value = MakeValue(padding_vector);
     elements.push_back(padding_value);
   }
@@ -82,11 +86,12 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(origin_node);
   MS_EXCEPTION_IF_NULL(pad);
-  if (origin_node->size() != 4) {
+  if (origin_node->size() != kGatherInputNum) {
     MS_LOG(EXCEPTION) << "In dynamic shape scene, gatherv2 should have 3 inputs";
   }
   std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
-                                             origin_node->input(2), origin_node->input(3)};
+                                             origin_node->input(kGatherInputIndicesIndex),
+                                             origin_node->input(kGatherInputAxisIndex)};
   auto gather_v2 = graph->NewCNode(gatherv2_inputs);
   MS_EXCEPTION_IF_NULL(gather_v2);
   gather_v2->set_scope(origin_node->scope());
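
Besides naming the padding size, this file now checks the node's arity against kGatherInputNum before indexing its input slots. A sketch of the guard-then-index pattern, with the slot constants treated as assumptions:

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// CNode-style layout: inputs[0] is the primitive, then the operands.
constexpr size_t kGatherInputNum = 4;           // primitive + 3 operands
constexpr size_t kGatherInputIndicesIndex = 2;  // assumed slot
constexpr size_t kGatherInputAxisIndex = 3;     // assumed slot

std::vector<std::string> BuildNewInputs(const std::vector<std::string> &origin_inputs, const std::string &pad) {
  // Validate the arity against the named constant before touching any slot,
  // so a malformed node fails loudly instead of reading out of bounds.
  if (origin_inputs.size() != kGatherInputNum) {
    throw std::invalid_argument("gatherv2 should have 3 inputs");
  }
  return {"Gather", pad, origin_inputs[kGatherInputIndicesIndex], origin_inputs[kGatherInputAxisIndex]};
}

int main() {
  auto inputs = BuildNewInputs({"GatherV2", "param", "indices", "axis"}, "pad_out");
  return inputs.size() == kGatherInputNum ? 0 : 1;
}
```
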
@@ -24,7 +24,11 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kOutputNum = 2;
+constexpr size_t kSquareSumOutputNum = 2;
+constexpr size_t kLarsV2WIndex = 1;
+constexpr size_t kLarsV2GIndex = 2;
+constexpr size_t kLarsV2WeightDecayIndex = 3;
+constexpr size_t kLarsV2LearningRatIndex = 4;
 void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                                  std::vector<AnfNodePtr> *square_sum_all_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
@@ -40,26 +44,25 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
   std::vector<size_t> shape;
   auto shapes = {shape, shape};
   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get());
 
-  CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kOutputNum, square_sum_all_outputs);
+  CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kSquareSumOutputNum, square_sum_all_outputs);
 }
 
 CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
                             const std::vector<AnfNodePtr> &square_sum_all_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(lars_v2);
-  if (square_sum_all_outputs.size() != 2) {
+  if (square_sum_all_outputs.size() != kSquareSumOutputNum) {
     MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"
                       << " trace: " << trace::DumpSourceLines(lars_v2);
   }
   CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)),
-                                    lars_v2->input(1),
-                                    lars_v2->input(2),
+                                    lars_v2->input(kLarsV2WIndex),
+                                    lars_v2->input(kLarsV2GIndex),
                                     square_sum_all_outputs[0],
                                     square_sum_all_outputs[1],
-                                    lars_v2->input(3),
-                                    lars_v2->input(4)};
+                                    lars_v2->input(kLarsV2WeightDecayIndex),
+                                    lars_v2->input(kLarsV2LearningRatIndex)};
   auto lars_v2_update = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(lars_v2_update);
   lars_v2_update->set_scope(lars_v2->scope());
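
kSquareSumOutputNum is now used both where the SquareSumAll outputs are fanned out and where they are validated, so the two call sites cannot drift apart. A minimal single-source-of-truth sketch:

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

constexpr size_t kSquareSumOutputNum = 2;

// Producer: creates exactly kSquareSumOutputNum outputs.
std::vector<std::string> CreateOutputs() { return std::vector<std::string>(kSquareSumOutputNum, "out"); }

// Consumer: validates against the same constant instead of a literal 2.
void Consume(const std::vector<std::string> &outputs) {
  if (outputs.size() != kSquareSumOutputNum) {
    throw std::logic_error("square_sum_all_outputs' size not equal 2");
  }
}

int main() {
  Consume(CreateOutputs());
  return 0;
}
```
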
@@ -26,6 +26,9 @@
 
 namespace mindspore {
 namespace opt {
+constexpr size_t kLayerNormGradOutputGammaIndex = 1;
+constexpr size_t kLayerNormGradOutputBetaIndex = 2;
+constexpr size_t kLayerNormGradInputGammaIndex = 4;
 void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
   const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
   std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
@@ -64,13 +67,15 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
   layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info);
   layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope());
 
-  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1),
-                AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)};
-  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 1), AnfAlgo::GetOutputDetailShape(layer_norm_grad, 2)};
+  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputGammaIndex),
+                AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
+  auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex),
+                 AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
   AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_beta_gamma_backprop.get());
 
   // get device shape of LayerNormGrad's 5th Input, and convert it to attr
-  std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
+  std::vector<size_t> shape_gamma =
+    AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, kLayerNormGradInputGammaIndex);
   AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Long(shape_gamma)), layer_norm_beta_gamma_backprop);
 
   CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
@@ -58,7 +58,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
   auto data_ptr = assist_tensor->data_c();
   MS_EXCEPTION_IF_NULL(data_ptr);
   std::vector<float16> half_data;
-  int64_t dims = 1 * 1 * d * h * w;
+  const int64_t dims = 1 * 1 * d * h * w;
   int64_t counter = dims;
   for (int64_t i = 0; i < dims; i++) {
     half_data.emplace_back(float16(static_cast<float>(counter)));
@@ -110,20 +110,20 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
     MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kInputNum << " inputs";
     return nullptr;
   }
-  std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
-  auto assist_filter_const = CreateValueNode(cnode);
-  new_node_inputs.insert(new_node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
-  new_node_inputs.push_back(assist_filter_const);
-  CNodePtr new_max_pool3d_grad_grad_node = graph->NewCNode(new_node_inputs);
-  MS_EXCEPTION_IF_NULL(new_max_pool3d_grad_grad_node);
-  new_max_pool3d_grad_grad_node->set_abstract(cnode->abstract());
-  new_max_pool3d_grad_grad_node->set_scope(cnode->scope());
-  AnfAlgo::CopyNodeAttrs(cnode, new_max_pool3d_grad_grad_node);
+  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
+  auto assist_const = CreateValueNode(cnode);
+  (void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
+  (void)new_inputs.emplace_back(assist_const);
+  CNodePtr new_cnode = graph->NewCNode(new_inputs);
+  MS_EXCEPTION_IF_NULL(new_cnode);
+  new_cnode->set_abstract(cnode->abstract());
+  new_cnode->set_scope(cnode->scope());
+  AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
   if (kernel_graph != nullptr) {
-    kernel_graph->AddValueNodeToGraph(assist_filter_const);
+    kernel_graph->AddValueNodeToGraph(assist_const);
     MS_LOG(INFO) << "Split MaxPool3DGradGrad op success.";
   }
-  return new_max_pool3d_grad_grad_node;
+  return new_cnode;
 }
 }  // namespace opt
 }  // namespace mindspore
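
Beyond the renames, the substantive style change here is the (void) cast on insert and emplace_back, making explicit that the returned iterator/reference is deliberately discarded (a convention some lint configurations require). A tiny sketch:

```cpp
#include <vector>

int main() {
  std::vector<int> new_inputs{0};
  std::vector<int> extra{1, 2, 3};
  // insert returns an iterator and emplace_back (since C++17) a reference;
  // the (void) cast states explicitly that neither result is used.
  (void)new_inputs.insert(new_inputs.end(), extra.begin(), extra.end());
  (void)new_inputs.emplace_back(4);
  return new_inputs.size() == 5 ? 0 : 1;
}
```
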
@@ -49,17 +49,11 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
     MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"
                       << " trace: " << trace::DumpSourceLines(origin_pack_cnode);
   }
-  std::vector<size_t> new_shape;
-  for (size_t i = 0; i < output_shape.size() + 1; ++i) {
-    if (i < LongToSize(axis)) {
-      new_shape.push_back(output_shape[i]);
-    } else if (i == LongToSize(axis)) {
-      new_shape.push_back(offset);
-    } else {
-      new_shape.push_back(output_shape[i - 1]);
-    }
-  }
-  new_shape.erase(new_shape.begin() + axis + 1);
+  std::vector<size_t> new_shape = output_shape;
+  auto axis_l = LongToSize(axis);
+  if (axis_l < new_shape.size()) {
+    new_shape[axis_l] = offset;
+  }
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
                                       new_pack.get());
   return new_pack;
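
The Pack rewrite replaces a build-then-erase loop with a copy plus a single overwrite, which is shorter and harder to get off-by-one wrong. Both versions side by side in a standalone sketch (the assert checks they agree):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Old approach: rebuild element by element, then erase the leftover slot.
std::vector<size_t> OldNewShape(const std::vector<size_t> &out, size_t axis, size_t offset) {
  std::vector<size_t> s;
  for (size_t i = 0; i < out.size() + 1; ++i) {
    if (i < axis) {
      s.push_back(out[i]);
    } else if (i == axis) {
      s.push_back(offset);
    } else {
      s.push_back(out[i - 1]);
    }
  }
  s.erase(s.begin() + static_cast<std::ptrdiff_t>(axis) + 1);
  return s;
}

// New approach: copy and overwrite the axis entry.
std::vector<size_t> NewNewShape(const std::vector<size_t> &out, size_t axis, size_t offset) {
  std::vector<size_t> s = out;
  if (axis < s.size()) {
    s[axis] = offset;
  }
  return s;
}

int main() {
  std::vector<size_t> shape = {8, 4, 16};
  assert(OldNewShape(shape, 1, 2) == NewNewShape(shape, 1, 2));  // both {8, 2, 16}
  return 0;
}
```
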
@@ -32,7 +32,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
   MS_EXCEPTION_IF_NULL(bn_cnode);
   if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
     MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
-                      << kBatchNormRealInputNum + 1 << " trace: " << trace::DumpSourceLines(bn);
+                      << (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
   }
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
@@ -58,7 +58,7 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
   MS_EXCEPTION_IF_NULL(bn_cnode);
   if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
     MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
-                      << kBatchNormRealInputNum + 1 << " trace: " << trace::DumpSourceLines(bn);
+                      << (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
   }
   if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
     MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
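
Parenthesizing kBatchNormRealInputNum + 1 is a readability fix ('+' already binds tighter than '<<'), but the habit matters because lower-precedence operators inside a stream chain do change meaning. A compilable illustration with std::cout standing in for MS_LOG:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  constexpr size_t kBatchNormRealInputNum = 3;
  // With '+', the parentheses are purely for the reader: both lines print 4.
  std::cout << kBatchNormRealInputNum + 1 << '\n';
  std::cout << (kBatchNormRealInputNum + 1) << '\n';
  // With '?:', parentheses are required: without them the stream itself
  // becomes the condition and the intended message is discarded.
  bool ok = true;
  (void)(std::cout << ok ? "4" : "0");            // prints "1", not "4"
  std::cout << '\n' << (ok ? "4" : "0") << '\n';  // prints "4"
  return 0;
}
```
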
@@ -77,9 +77,11 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
   tensor::TensorPtr filter_tensor = CreateTensor(node);
   MS_EXCEPTION_IF_NULL(filter_tensor);
   auto assist_const = std::make_shared<ValueNode>(filter_tensor);
+  MS_EXCEPTION_IF_NULL(assist_const);
   auto assist_abstract = filter_tensor->ToAbstract();
   assist_const->set_abstract(assist_abstract);
   auto assist_kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(assist_kernel_info);
   assist_const->set_kernel_info(assist_kernel_info);
   kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
   op_builder.SetOutputsFormat({kOpFormat_NC1HWC0});
@@ -114,16 +116,16 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
   }
 
   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kSpaceToDepthOpName))};
-  auto assist_const = CreateValueNode(cnode);
-  new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
-  new_inputs.push_back(assist_const);
+  auto last_input_value = CreateValueNode(cnode);
+  (void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
+  (void)new_inputs.emplace_back(last_input_value);
   CNodePtr new_cnode = graph->NewCNode(new_inputs);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_abstract(cnode->abstract());
   new_cnode->set_scope(cnode->scope());
   AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
   if (kernel_graph != nullptr) {
-    kernel_graph->AddValueNodeToGraph(assist_const);
+    kernel_graph->AddValueNodeToGraph(last_input_value);
     MS_LOG(INFO) << "Split SpaceToDepth op success.";
   }
   return new_cnode;
@@ -53,6 +53,9 @@ size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_
   if (LongToSize(split_dim) >= input_shape.size()) {
     MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
   }
+  if (num_split == 0) {
+    MS_LOG(EXCEPTION) << "Divisor 'num_split' should not be 0.";
+  }
   return input_shape[LongToSize(split_dim)] / LongToSize(num_split);
 }
 
@@ -104,17 +107,23 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
   if (split_dim < 0) {
     split_dim += SizeToLong(output_shape.size());
   }
-  for (size_t i = 0; i < LongToSize(num_split); ++i) {
-    output_shape[LongToSize(split_dim)] = size_splits_base[i];
+  if (split_dim < 0) {
+    MS_LOG(EXCEPTION) << "Error split dim: " << split_dim;
+  }
+  auto split_dim_l = LongToSize(split_dim);
+  auto num_split_l = LongToSize(num_split);
+  for (size_t i = 0; i < num_split_l; ++i) {
+    output_shape[split_dim_l] = LongToSize(size_splits_base[i]);
     base_output_shapes_base.emplace_back(output_shape);
     AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
   }
   AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
 }
+}  // namespace
 
-AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) {
+AnfNodePtr SplitFission::DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split,
+                                   int64_t divisor, int64_t split_dim) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrAxis);
   CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
 
   // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
@@ -173,7 +182,6 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int6
   AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
   return make_tuple;
 }
-}  // namespace
 
 const BaseRef SplitFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -196,7 +204,7 @@ const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const Anf
   if (num_split <= outputs_divisor_) {
     return nullptr;
   }
-  return DoFission(func_graph, cnode, num_split, outputs_divisor_);
+  return DoFission(func_graph, cnode, num_split, outputs_divisor_, AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrAxis));
 }
 }  // namespace opt
 }  // namespace mindspore
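
GetSmallSplitSize gains an explicit num_split == 0 guard: integer division by zero is undefined behavior in C++, not a catchable error, so it has to be rejected up front. A minimal sketch with exceptions standing in for MS_LOG(EXCEPTION):

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

size_t GetSmallSplitSize(const std::vector<size_t> &input_shape, size_t split_dim, int64_t num_split) {
  if (split_dim >= input_shape.size()) {
    throw std::out_of_range("The split_dim value should be less than the shape size of input 0");
  }
  // Division by zero must be rejected explicitly; it cannot be caught
  // after the fact, because the division itself is undefined behavior.
  if (num_split == 0) {
    throw std::invalid_argument("Divisor 'num_split' should not be 0.");
  }
  return input_shape[split_dim] / static_cast<size_t>(num_split);
}

int main() {
  std::vector<size_t> shape = {126, 8};
  return GetSmallSplitSize(shape, 0, 63) == 2 ? 0 : 1;
}
```
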
@@ -16,20 +16,24 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_
 
+#include <string>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
-constexpr int kSplitOutputsDivisor = 63;
+constexpr int64_t kSplitOutputsDivisor = 63;
 class SplitFission : public PatternProcessPass {
  public:
-  explicit SplitFission(bool multigraph = true)
-      : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
+  explicit SplitFission(const std::string name = "split_fission", bool multigraph = true,
+                        int64_t divisor = kSplitOutputsDivisor)
+      : PatternProcessPass(name, multigraph), outputs_divisor_(divisor) {}
   ~SplitFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
- private:
+ protected:
+  AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor,
+                       int64_t split_dim) const;
   int64_t outputs_divisor_;
 };
 } // namespace opt
@@ -19,163 +19,6 @@
 #include "backend/session/anf_runtime_algorithm.h"
 
 namespace mindspore::opt {
-namespace {
-CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(input_node);
-  std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
-  CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
-  MS_EXCEPTION_IF_NULL(splitv);
-  splitv->set_scope(input_node->scope());
-  return splitv;
-}
-
-CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
-  MS_EXCEPTION_IF_NULL(origin_cnode);
-  CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
-  return CreateSplitVNode(func_graph, origin_cnode->input(1));
-}
-
-void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim,
-                          int64_t num_split) {
-  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
-  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv);
-  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv);
-}
-
-size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_t num_split) {
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
-  if (split_dim < 0) {
-    split_dim += SizeToLong(input_shape.size());
-  }
-  if (LongToSize(split_dim) >= input_shape.size()) {
-    MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
-  }
-  if (num_split == 0) {
-    MS_LOG(EXCEPTION) << "Divisor 'num_split' should not be 0.";
-  }
-  return input_shape[LongToSize(split_dim)] / LongToSize(num_split);
-}
-
-void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int64_t outputs_num,
-                   std::vector<AnfNodePtr> *inputs) {
-  MS_EXCEPTION_IF_NULL(inputs);
-  std::vector<AnfNodePtr> new_splitv_output;
-  CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, LongToSize(outputs_num), &new_splitv_output);
-  inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
-}
-
-AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, int64_t index) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto idx = NewValueNode(index);
-  MS_EXCEPTION_IF_NULL(idx);
-  auto imm = std::make_shared<Int64Imm>(index);
-  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
-  idx->set_abstract(abstract_scalar);
-  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
-  return tuple_getitem;
-}
-
-void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int64_t split_dim, int64_t split_size, int64_t num_split,
-                                std::vector<TypeId> *new_type_ids,
-                                std::vector<std::vector<size_t>> *new_output_shapes) {
-  MS_EXCEPTION_IF_NULL(new_type_ids);
-  MS_EXCEPTION_IF_NULL(new_output_shapes);
-  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
-  if (split_dim < 0) {
-    split_dim += SizeToLong(output_shape.size());
-  }
-  output_shape[LongToSize(split_dim)] = LongToSize(split_size);
-  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
-  for (int64_t i = 0; i < num_split; ++i) {
-    new_type_ids->emplace_back(type_id);
-    new_output_shapes->emplace_back(output_shape);
-  }
-}
-
-void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
-                                     const std::vector<AnfNodePtr> &base_splitv_outputs,
-                                     const std::vector<int64_t> &size_splits_base, int64_t split_dim,
-                                     int64_t num_split) {
-  SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
-  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
-  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
-  std::vector<TypeId> base_type_ids(num_split, type_id);
-  std::vector<std::vector<size_t>> base_output_shapes_base;
-  if (split_dim < 0) {
-    split_dim += SizeToLong(output_shape.size());
-  }
-  for (size_t i = 0; i < LongToSize(num_split); ++i) {
-    output_shape[LongToSize(split_dim)] = size_splits_base[i];
-    base_output_shapes_base.emplace_back(output_shape);
-    AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
-  }
-  AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
-}
-
-AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim);
-  CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
-
-  // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
-  auto small_split_size = SizeToLong(GetSmallSplitSize(cnode, split_dim, num_split));
-  std::vector<int64_t> size_splits_new(divisor, small_split_size);
-  // Create new output shape and new output type id for each new Splitv node which has full inputs.
-  std::vector<TypeId> new_type_ids;
-  std::vector<std::vector<size_t>> new_output_shapes;
-  CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
-
-  // Create make_tuple input to create a make_tuple for replacing the old Split node.
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
-  // Start to divide the outputs of Split.
-  std::vector<int64_t> size_splits_base;
-  std::vector<AnfNodePtr> base_splitv_outputs;
-  const auto base_split_size = divisor * small_split_size;
-  int64_t nodes_num = 0;
-  int64_t cur_output_index = 0;
-  while (num_split - cur_output_index > divisor) {
-    auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
-    base_splitv_outputs.push_back(tuple_getitem);
-    CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
-    SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
-    AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
-    AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
-    cur_output_index += divisor;
-    size_splits_base.emplace_back(base_split_size);
-    nodes_num++;
-  }
-  if (cur_output_index < num_split) {
-    auto last_node_num_split = num_split - cur_output_index;
-    if (last_node_num_split > 1) {
-      auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
-      base_splitv_outputs.push_back(tuple_getitem);
-      CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
-      std::vector<int64_t> size_splits_new_last(last_node_num_split, small_split_size);
-      SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
-      // Create new output shape and new output type id for the last Splitv node
-      std::vector<TypeId> last_new_type_ids;
-      std::vector<std::vector<size_t>> last_new_output_shapes;
-      CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
-                                 &last_new_output_shapes);
-      AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
-      AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
-      size_splits_base.emplace_back(last_node_num_split * small_split_size);
-    } else {
-      auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
-      base_splitv_outputs.push_back(tuple_getitem);
-      make_tuple_inputs.emplace_back(tuple_getitem);
-      size_splits_base.emplace_back(small_split_size);
-    }
-    nodes_num++;
-  }
-  // Set Attr and abstract for the base splitv
-  SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
-  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
-  return make_tuple;
-}
-}  // namespace
-
 const BaseRef SplitVFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   auto split_prim = std::make_shared<Primitive>(kSplitVOpName);
@@ -198,6 +41,6 @@ const AnfNodePtr SplitVFission::Process(const FuncGraphPtr &func_graph, const An
   if (num_split <= outputs_divisor_) {
     return nullptr;
   }
-  return DoFission(func_graph, cnode, num_split, outputs_divisor_);
+  return DoFission(func_graph, cnode, num_split, outputs_divisor_, AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim));
 }
 }  // namespace mindspore::opt
@@ -17,21 +17,17 @@
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_
 
-#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fission/split_fission.h"
 
 namespace mindspore {
 namespace opt {
-class SplitVFission : public PatternProcessPass {
-  const int kSplitOutputsDivisor = 63;
-
+constexpr int64_t kSplitVOutputsDivisor = 63;
+class SplitVFission : public SplitFission {
  public:
-  explicit SplitVFission(bool multigraph = true)
-      : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
+  explicit SplitVFission(bool multigraph = true) : SplitFission("splitv_fission", multigraph, kSplitVOutputsDivisor) {}
   ~SplitVFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  int64_t outputs_divisor_;
 };
 } // namespace opt
 } // namespace mindspore
|
|||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto index = GetValue<int64_t>(value_node->value());
|
||||
auto value_index = GetValue<int64_t>(value_node->value());
|
||||
if (value_index < 0) {
|
||||
MS_LOG(EXCEPTION) << "Error value index: " << value_index;
|
||||
}
|
||||
auto index = LongToSize(value_index);
|
||||
if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
|
||||
(void)manager->Replace(output, bn_training_update_outputs[index]);
|
||||
}
|
||||
|
|
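
As in the axis fixes above, the int64 read from the value node is now rejected if negative before conversion, since casting a negative int64_t to size_t wraps to a huge value that sails past unsigned bounds checks. A standalone sketch:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

size_t ToIndex(int64_t value_index) {
  // A negative int64_t cast to size_t wraps to roughly 2^64, so any later
  // comparison written in terms of the unsigned type would not catch it.
  if (value_index < 0) {
    throw std::out_of_range("Error value index");
  }
  return static_cast<size_t>(value_index);
}

int main() {
  std::vector<int> outputs = {10, 20, 30};
  size_t index = ToIndex(1);
  return outputs[index] == 20 ? 0 : 1;
}
```
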