format codes

kswang 2021-06-10 16:29:40 +08:00
parent 1ccb1f5946
commit 39f6d057de
16 changed files with 98 additions and 229 deletions

View File

@@ -29,6 +29,7 @@ constexpr size_t kBatchNormRealOutputNum = 3;
constexpr size_t kBatchNormRealInputNum = 3;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
@@ -117,8 +118,9 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
}
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3],
bn_abstract_tuple->elements()[4]};
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0],
bn_abstract_tuple->elements()[kInputIndex3],
bn_abstract_tuple->elements()[kInputIndex4]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
bn_training_update_v2->set_abstract(abstract_tuple);
bn_training_update_v2->set_scope(bn->scope());
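
A pattern that repeats throughout this commit: bare element indices such as 3 and 4 become named constexpr constants, so the selected outputs are self-documenting. A minimal standalone sketch of the idea in plain C++ (hypothetical names, not the MindSpore API):

#include <cstddef>
#include <vector>

// Named indices document *which* tuple elements are being selected,
// where a bare 3 or 4 would not survive the next refactoring intact.
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;

// Assumes elements has at least five entries, as the guarded code above ensures.
std::vector<int> SelectAbstracts(const std::vector<int> &elements) {
  return {elements[0], elements[kInputIndex3], elements[kInputIndex4]};
}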

View File

@@ -47,7 +47,8 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto index = GetValue<int64_t>(value_node->value());
if (index == kBatchNormGradInferOutputNum || index == SizeToLong(kBatchNormGradInferOutputNum + 1)) {
auto output_num = SizeToLong(kBatchNormGradInferOutputNum);
if (index == output_num || index == output_num + 1) {
MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change";
return false;
}
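
The rewrite hoists the size_t-to-int64_t conversion out of the comparison so the signed/unsigned mixing happens exactly once. A sketch of the same discipline, with a hypothetical stand-in for the framework's SizeToLong helper:

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for SizeToLong (the real helper may range-check).
inline int64_t SizeToLong(size_t v) { return static_cast<int64_t>(v); }

constexpr size_t kOutputNum = 3;

// Convert the unsigned count into the signed domain once, then compare;
// this avoids repeating SizeToLong(...) inside each operand.
bool IsTrailingOutput(int64_t index) {
  const int64_t output_num = SizeToLong(kOutputNum);
  return index == output_num || index == output_num + 1;
}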

View File

@@ -46,15 +46,17 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat);
// infer shape
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0);
auto axis = AnfAlgo::GetNodeAttr<int64_t>(origin_concat_cnode, kAttrAxis);
if (axis < 0) {
axis += SizeToLong(input_shape.size());
auto axis_from_attr = AnfAlgo::GetNodeAttr<int64_t>(origin_concat_cnode, kAttrAxis);
if (axis_from_attr < 0) {
axis_from_attr += SizeToLong(input_shape.size());
}
auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0);
if (axis < 0 || axis >= SizeToLong(output_shape.size()) || axis >= SizeToLong(input_shape.size())) {
MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"
if (axis_from_attr < 0 || axis_from_attr >= SizeToLong(output_shape.size()) ||
axis_from_attr >= SizeToLong(input_shape.size())) {
MS_LOG(EXCEPTION) << "The concat_dim value " << axis_from_attr << "is out of range"
<< " trace: " << trace::DumpSourceLines(origin_concat_cnode);
}
auto axis = LongToSize(axis_from_attr);
output_shape[axis] = 0;
for (size_t i = begin_index; i < begin_index + offset; ++i) {
input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, i - 1);
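
Renaming the raw attribute to axis_from_attr keeps the signed, possibly negative value distinct from the validated index, and the LongToSize narrowing happens only after the range check passes. The same flow, sketched standalone with illustrative names:

#include <cstdint>
#include <stdexcept>

// Normalize a possibly negative axis into [0, rank), narrowing to size_t
// only once the range check has passed (sketch, not the framework code).
size_t NormalizeAxis(int64_t axis_from_attr, size_t rank) {
  if (axis_from_attr < 0) {
    axis_from_attr += static_cast<int64_t>(rank);
  }
  if (axis_from_attr < 0 || axis_from_attr >= static_cast<int64_t>(rank)) {
    throw std::out_of_range("The concat_dim value is out of range");
  }
  return static_cast<size_t>(axis_from_attr);
}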

View File

@@ -163,7 +163,7 @@ AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &
size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM0];
size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[DIM1];
size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[DIM2] / kGateNum;
std::vector<size_t> shape = {t_size, batch, hidden_size + hidden_size};
std::vector<size_t> shape = {t_size, batch, hidden_size << 1};
std::vector<size_t> shape2 = {t_size, batch, hidden_size};
std::vector<std::vector<size_t>> shapes = {shape, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
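
The one-line change swaps an addition for a shift; for unsigned operands the two are exactly equivalent, which a compile-time check can confirm:

#include <cstddef>

// hidden_size << 1 and hidden_size + hidden_size agree for unsigned
// values (both wrap the same way on overflow).
static_assert((static_cast<size_t>(21) << 1) == 21 + 21, "shift-by-one doubles");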

View File

@@ -24,6 +24,10 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kOriginPaddingSize = 2;
constexpr size_t kGatherInputNum = 4;
constexpr size_t kGatherInputIndicesIndex = 2;
constexpr size_t kGatherInputAxisIndex = 3;
// only pad operator can run in dynamic shape.
CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) {
MS_EXCEPTION_IF_NULL(graph);
@@ -62,7 +66,7 @@ CNodePtr CreatePad(const FuncGraphPtr &graph, const CNodePtr &origin_node, const
std::vector<ValuePtr> elements;
for (size_t i = 0; i < shape.size() - 1; ++i) {
ShapeVector padding_vector(2);
ShapeVector padding_vector(kOriginPaddingSize);
auto padding_value = MakeValue(padding_vector);
elements.push_back(padding_value);
}
@@ -82,11 +86,12 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(pad);
if (origin_node->size() != 4) {
if (origin_node->size() != kGatherInputNum) {
MS_LOG(EXCEPTION) << "In dynamic shape scene, gatherv2 should have 3 inputs";
}
std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
origin_node->input(2), origin_node->input(3)};
origin_node->input(kGatherInputIndicesIndex),
origin_node->input(kGatherInputAxisIndex)};
auto gather_v2 = graph->NewCNode(gatherv2_inputs);
MS_EXCEPTION_IF_NULL(gather_v2);
gather_v2->set_scope(origin_node->scope());
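
The message "should have 3 inputs" against kGatherInputNum = 4 is not a mismatch: a CNode's input 0 is the primitive itself, so tensor inputs start at index 1. A hedged sketch of that layout, with a plain string vector standing in for the node:

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

constexpr size_t kGatherInputNum = 4;  // primitive + 3 tensor inputs
constexpr size_t kGatherInputIndicesIndex = 2;
constexpr size_t kGatherInputAxisIndex = 3;

// Rebuild the input list with a new first tensor input, keeping indices
// and axis in place; node[0] is the primitive, not a tensor.
std::vector<std::string> RebuildGatherInputs(const std::vector<std::string> &node,
                                             const std::string &pad) {
  if (node.size() != kGatherInputNum) {
    throw std::invalid_argument("gatherv2 should have 3 inputs");
  }
  return {"Gather", pad, node[kGatherInputIndicesIndex], node[kGatherInputAxisIndex]};
}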

View File

@@ -24,7 +24,11 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kOutputNum = 2;
constexpr size_t kSquareSumOutputNum = 2;
constexpr size_t kLarsV2WIndex = 1;
constexpr size_t kLarsV2GIndex = 2;
constexpr size_t kLarsV2WeightDecayIndex = 3;
constexpr size_t kLarsV2LearningRatIndex = 4;
void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
std::vector<AnfNodePtr> *square_sum_all_outputs) {
MS_EXCEPTION_IF_NULL(graph);
@@ -40,26 +44,25 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
std::vector<size_t> shape;
auto shapes = {shape, shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get());
CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kOutputNum, square_sum_all_outputs);
CreateMultipleOutputsOfAnfNode(graph, square_sum_all, kSquareSumOutputNum, square_sum_all_outputs);
}
CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
const std::vector<AnfNodePtr> &square_sum_all_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lars_v2);
if (square_sum_all_outputs.size() != 2) {
if (square_sum_all_outputs.size() != kSquareSumOutputNum) {
MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"
<< " trace: " << trace::DumpSourceLines(lars_v2);
}
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)),
lars_v2->input(1),
lars_v2->input(2),
lars_v2->input(kLarsV2WIndex),
lars_v2->input(kLarsV2GIndex),
square_sum_all_outputs[0],
square_sum_all_outputs[1],
lars_v2->input(3),
lars_v2->input(4)};
lars_v2->input(kLarsV2WeightDecayIndex),
lars_v2->input(kLarsV2LearningRatIndex)};
auto lars_v2_update = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(lars_v2_update);
lars_v2_update->set_scope(lars_v2->scope());
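
Same renaming move for LarsV2: the four tensor inputs get names (weight, gradient, weight decay, learning rate), so the rebuilt input list reads as a signature instead of a digit sequence, and a silent swap of inputs 3 and 4 becomes much harder. A compact sketch:

#include <cstddef>
#include <string>
#include <vector>

constexpr size_t kLarsV2WIndex = 1;
constexpr size_t kLarsV2GIndex = 2;
constexpr size_t kLarsV2WeightDecayIndex = 3;
constexpr size_t kLarsV2LearningRatIndex = 4;  // spelled as in the source

// Assemble the LarsV2Update inputs in documented order (assumes the
// caller validated the node's input count, as CheckCNodeInputSize does).
std::vector<std::string> BuildUpdateInputs(const std::vector<std::string> &lars_v2,
                                           const std::string &sum_w,
                                           const std::string &sum_g) {
  return {"LarsV2Update",
          lars_v2[kLarsV2WIndex],
          lars_v2[kLarsV2GIndex],
          sum_w,
          sum_g,
          lars_v2[kLarsV2WeightDecayIndex],
          lars_v2[kLarsV2LearningRatIndex]};
}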

View File

@@ -26,6 +26,9 @@
namespace mindspore {
namespace opt {
constexpr size_t kLayerNormGradOutputGammaIndex = 1;
constexpr size_t kLayerNormGradOutputBetaIndex = 2;
constexpr size_t kLayerNormGradInputGammaIndex = 4;
void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
@@ -64,13 +67,15 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info);
layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope());
auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1),
AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)};
auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, 1), AnfAlgo::GetOutputDetailShape(layer_norm_grad, 2)};
auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputGammaIndex),
AnfAlgo::GetOutputInferDataType(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
auto shapes = {AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputGammaIndex),
AnfAlgo::GetOutputDetailShape(layer_norm_grad, kLayerNormGradOutputBetaIndex)};
AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, layer_norm_beta_gamma_backprop.get());
// get device shape of LayerNormGrad's 5th Input, and convert it to attr
std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
std::vector<size_t> shape_gamma =
AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, kLayerNormGradInputGammaIndex);
AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Long(shape_gamma)), layer_norm_beta_gamma_backprop);
CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
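
shape_gamma is a vector of size_t while node attributes carry int64_t values, hence the opt::Convert2Long wrap before MakeValue. A one-line sketch of such a widening helper (assumed behavior; the real helper may differ in error handling):

#include <cstddef>
#include <cstdint>
#include <vector>

// Widen an unsigned shape vector into the signed form attribute values use.
std::vector<int64_t> Convert2Long(const std::vector<size_t> &shape) {
  return std::vector<int64_t>(shape.begin(), shape.end());
}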

View File

@@ -58,7 +58,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
auto data_ptr = assist_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
std::vector<float16> half_data;
int64_t dims = 1 * 1 * d * h * w;
const int64_t dims = 1 * 1 * d * h * w;
int64_t counter = dims;
for (int64_t i = 0; i < dims; i++) {
half_data.emplace_back(float16(static_cast<float>(counter)));
@@ -110,20 +110,20 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kInputNum << " inputs";
return nullptr;
}
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
auto assist_filter_const = CreateValueNode(cnode);
new_node_inputs.insert(new_node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
new_node_inputs.push_back(assist_filter_const);
CNodePtr new_max_pool3d_grad_grad_node = graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_max_pool3d_grad_grad_node);
new_max_pool3d_grad_grad_node->set_abstract(cnode->abstract());
new_max_pool3d_grad_grad_node->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_max_pool3d_grad_grad_node);
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
auto assist_const = CreateValueNode(cnode);
(void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)new_inputs.emplace_back(assist_const);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_filter_const);
kernel_graph->AddValueNodeToGraph(assist_const);
MS_LOG(INFO) << "Split MaxPool3DGradGrad op success.";
}
return new_max_pool3d_grad_grad_node;
return new_cnode;
}
} // namespace opt
} // namespace mindspore
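
Besides the tighter names, note the (void) casts on insert and emplace_back: both return a value (an iterator and, since C++17, a reference), and the cast marks the result as deliberately unused for static analyzers. A minimal sketch using only the standard library:

#include <vector>

void AppendAll(std::vector<int> *dst, const std::vector<int> &src, int extra) {
  // insert returns an iterator and emplace_back (C++17) a reference;
  // (void) documents that the results are intentionally discarded.
  (void)dst->insert(dst->end(), src.begin(), src.end());
  (void)dst->emplace_back(extra);
}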

View File

@@ -49,17 +49,11 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
MS_LOG(EXCEPTION) << "The concat_dim value " << axis << "is out of range"
<< " trace: " << trace::DumpSourceLines(origin_pack_cnode);
}
std::vector<size_t> new_shape;
for (size_t i = 0; i < output_shape.size() + 1; ++i) {
if (i < LongToSize(axis)) {
new_shape.push_back(output_shape[i]);
} else if (i == LongToSize(axis)) {
new_shape.push_back(offset);
} else {
new_shape.push_back(output_shape[i - 1]);
}
}
new_shape.erase(new_shape.begin() + axis + 1);
std::vector<size_t> new_shape = output_shape;
auto axis_l = LongToSize(axis);
if (axis_l < new_shape.size()) {
new_shape[axis_l] = offset;
}
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
new_pack.get());
return new_pack;
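
The removed loop built output_shape with offset inserted at axis and then erased the element it displaced; net effect, "copy the shape and overwrite the axis entry", which the new code states directly. A standalone check of the equivalence on a concrete shape (hypothetical values):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::vector<size_t> output_shape = {4, 6, 8};
  const size_t axis = 1, offset = 2;

  // Old formulation: insert offset at axis, then erase the displaced entry.
  std::vector<size_t> old_way = output_shape;
  old_way.insert(old_way.begin() + axis, offset);
  old_way.erase(old_way.begin() + axis + 1);

  // New formulation: copy and overwrite in place.
  std::vector<size_t> new_way = output_shape;
  new_way[axis] = offset;

  assert(old_way == new_way);  // both are {4, 2, 8}
  return 0;
}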

View File

@@ -32,7 +32,7 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1 << " trace: " << trace::DumpSourceLines(bn);
<< (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
@@ -58,7 +58,7 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
<< kBatchNormRealInputNum + 1 << " trace: " << trace::DumpSourceLines(bn);
<< (kBatchNormRealInputNum + 1) << " trace: " << trace::DumpSourceLines(bn);
}
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum

View File

@@ -77,9 +77,11 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
tensor::TensorPtr filter_tensor = CreateTensor(node);
MS_EXCEPTION_IF_NULL(filter_tensor);
auto assist_const = std::make_shared<ValueNode>(filter_tensor);
MS_EXCEPTION_IF_NULL(assist_const);
auto assist_abstract = filter_tensor->ToAbstract();
assist_const->set_abstract(assist_abstract);
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(assist_kernel_info);
assist_const->set_kernel_info(assist_kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
op_builder.SetOutputsFormat({kOpFormat_NC1HWC0});
@@ -114,16 +116,16 @@ const AnfNodePtr SpaceToDepthSplit::Process(const FuncGraphPtr &graph, const Anf
}
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kSpaceToDepthOpName))};
auto assist_const = CreateValueNode(cnode);
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
new_inputs.push_back(assist_const);
auto last_input_value = CreateValueNode(cnode);
(void)new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)new_inputs.emplace_back(last_input_value);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_const);
kernel_graph->AddValueNodeToGraph(last_input_value);
MS_LOG(INFO) << "Split SpaceToDepth op success.";
}
return new_cnode;
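
The two added MS_EXCEPTION_IF_NULL checks follow std::make_shared calls, which throw std::bad_alloc rather than returning null, so these are codebase-convention guards more than reachable failure paths. A sketch of the check with a hypothetical helper in place of the macro:

#include <memory>
#include <stdexcept>

// Hypothetical stand-in for MS_EXCEPTION_IF_NULL: fail loudly at the
// point of creation instead of dereferencing null somewhere later.
template <typename T>
void ExceptionIfNull(const std::shared_ptr<T> &ptr, const char *what) {
  if (ptr == nullptr) {
    throw std::runtime_error(what);
  }
}

struct KernelInfo {};

std::shared_ptr<KernelInfo> MakeKernelInfo() {
  auto info = std::make_shared<KernelInfo>();
  // make_shared throws on allocation failure, so this check is defensive.
  ExceptionIfNull(info, "assist_kernel_info is null");
  return info;
}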

View File

@@ -53,6 +53,9 @@ size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_
if (LongToSize(split_dim) >= input_shape.size()) {
MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
}
if (num_split == 0) {
MS_LOG(EXCEPTION) << "Divisor 'num_split' should not be 0.";
}
return input_shape[LongToSize(split_dim)] / LongToSize(num_split);
}
@@ -104,17 +107,23 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
if (split_dim < 0) {
split_dim += SizeToLong(output_shape.size());
}
for (size_t i = 0; i < LongToSize(num_split); ++i) {
output_shape[LongToSize(split_dim)] = size_splits_base[i];
if (split_dim < 0) {
MS_LOG(EXCEPTION) << "Error split dim: " << split_dim;
}
auto split_dim_l = LongToSize(split_dim);
auto num_split_l = LongToSize(num_split);
for (size_t i = 0; i < num_split_l; ++i) {
output_shape[split_dim_l] = LongToSize(size_splits_base[i]);
base_output_shapes_base.emplace_back(output_shape);
AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
}
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
}
} // namespace
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) {
AnfNodePtr SplitFission::DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split,
int64_t divisor, int64_t split_dim) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrAxis);
CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
// Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
@@ -173,7 +182,6 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int6
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}
} // namespace
const BaseRef SplitFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
@@ -196,7 +204,7 @@ const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const Anf
if (num_split <= outputs_divisor_) {
return nullptr;
}
return DoFission(func_graph, cnode, num_split, outputs_divisor_);
return DoFission(func_graph, cnode, num_split, outputs_divisor_, AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrAxis));
}
} // namespace opt
} // namespace mindspore
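
DoFission caps each SplitV node at divisor outputs: full groups of divisor, then one remainder node (or a bare tuple-getitem when a single output is left over). The grouping arithmetic from the while-loop, sketched standalone:

#include <cassert>
#include <cstdint>
#include <vector>

// Group num_split outputs into chunks of at most divisor, mirroring
// the loop in DoFission (the real pass also builds SplitV nodes).
std::vector<int64_t> GroupOutputs(int64_t num_split, int64_t divisor) {
  assert(divisor > 0);
  std::vector<int64_t> groups;
  int64_t cur = 0;
  while (num_split - cur > divisor) {
    groups.push_back(divisor);
    cur += divisor;
  }
  if (cur < num_split) {
    groups.push_back(num_split - cur);
  }
  return groups;
}

int main() {
  // 130 outputs with divisor 63: two full SplitV nodes plus a remainder.
  assert((GroupOutputs(130, 63) == std::vector<int64_t>{63, 63, 4}));
  return 0;
}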

View File

@@ -16,20 +16,24 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLIT_FISSION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
constexpr int kSplitOutputsDivisor = 63;
constexpr int64_t kSplitOutputsDivisor = 63;
class SplitFission : public PatternProcessPass {
public:
explicit SplitFission(bool multigraph = true)
: PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
explicit SplitFission(const std::string name = "split_fission", bool multigraph = true,
int64_t divisor = kSplitOutputsDivisor)
: PatternProcessPass(name, multigraph), outputs_divisor_(divisor) {}
~SplitFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
protected:
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor,
int64_t split_dim) const;
int64_t outputs_divisor_;
};
} // namespace opt
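
The header change turns SplitFission into a reusable base: the pass name and divisor become constructor parameters, and DoFission moves from a file-local function to a protected method. A skeletal sketch of that shape with hypothetical base-class plumbing (PatternProcessPass reduced to a name-holding stub):

#include <cstdint>
#include <string>
#include <utility>

// Hypothetical stub for the PatternProcessPass machinery.
class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
 private:
  std::string name_;
};

constexpr int64_t kSplitOutputsDivisor = 63;

class SplitFission : public Pass {
 public:
  explicit SplitFission(const std::string &name = "split_fission",
                        int64_t divisor = kSplitOutputsDivisor)
      : Pass(name), outputs_divisor_(divisor) {}
 protected:
  int64_t outputs_divisor_;  // the shared DoFission would also live here
};

// A derived pass only supplies its own name (and divisor, if different).
class SplitVFission : public SplitFission {
 public:
  SplitVFission() : SplitFission("splitv_fission") {}
};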

View File

@@ -19,163 +19,6 @@
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::opt {
namespace {
CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input_node);
std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
MS_EXCEPTION_IF_NULL(splitv);
splitv->set_scope(input_node->scope());
return splitv;
}
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
MS_EXCEPTION_IF_NULL(origin_cnode);
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
return CreateSplitVNode(func_graph, origin_cnode->input(1));
}
void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int64_t> &size_splits, int64_t split_dim,
int64_t num_split) {
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv);
}
size_t GetSmallSplitSize(const AnfNodePtr &split_node, int64_t split_dim, int64_t num_split) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
if (split_dim < 0) {
split_dim += SizeToLong(input_shape.size());
}
if (LongToSize(split_dim) >= input_shape.size()) {
MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
}
if (num_split == 0) {
MS_LOG(EXCEPTION) << "Divisor 'num_split' should not be 0.";
}
return input_shape[LongToSize(split_dim)] / LongToSize(num_split);
}
void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int64_t outputs_num,
std::vector<AnfNodePtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
std::vector<AnfNodePtr> new_splitv_output;
CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, LongToSize(outputs_num), &new_splitv_output);
inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
}
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, int64_t index) {
MS_EXCEPTION_IF_NULL(func_graph);
auto idx = NewValueNode(index);
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(index);
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
return tuple_getitem;
}
void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int64_t split_dim, int64_t split_size, int64_t num_split,
std::vector<TypeId> *new_type_ids,
std::vector<std::vector<size_t>> *new_output_shapes) {
MS_EXCEPTION_IF_NULL(new_type_ids);
MS_EXCEPTION_IF_NULL(new_output_shapes);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
if (split_dim < 0) {
split_dim += SizeToLong(output_shape.size());
}
output_shape[LongToSize(split_dim)] = LongToSize(split_size);
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
for (int64_t i = 0; i < num_split; ++i) {
new_type_ids->emplace_back(type_id);
new_output_shapes->emplace_back(output_shape);
}
}
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
const std::vector<AnfNodePtr> &base_splitv_outputs,
const std::vector<int64_t> &size_splits_base, int64_t split_dim,
int64_t num_split) {
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
std::vector<TypeId> base_type_ids(num_split, type_id);
std::vector<std::vector<size_t>> base_output_shapes_base;
if (split_dim < 0) {
split_dim += SizeToLong(output_shape.size());
}
for (size_t i = 0; i < LongToSize(num_split); ++i) {
output_shape[LongToSize(split_dim)] = size_splits_base[i];
base_output_shapes_base.emplace_back(output_shape);
AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
}
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
}
AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int64_t num_split, int64_t divisor) {
MS_EXCEPTION_IF_NULL(func_graph);
auto split_dim = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim);
CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
// Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
auto small_split_size = SizeToLong(GetSmallSplitSize(cnode, split_dim, num_split));
std::vector<int64_t> size_splits_new(divisor, small_split_size);
// Create new output shape and new output type id for each new Splitv node which has full inputs.
std::vector<TypeId> new_type_ids;
std::vector<std::vector<size_t>> new_output_shapes;
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
// Create make_tuple input to create a make_tuple for replacing the old Split node.
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
// Start to divide the outputs of Split.
std::vector<int64_t> size_splits_base;
std::vector<AnfNodePtr> base_splitv_outputs;
const auto base_split_size = divisor * small_split_size;
int64_t nodes_num = 0;
int64_t cur_output_index = 0;
while (num_split - cur_output_index > divisor) {
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
cur_output_index += divisor;
size_splits_base.emplace_back(base_split_size);
nodes_num++;
}
if (cur_output_index < num_split) {
auto last_node_num_split = num_split - cur_output_index;
if (last_node_num_split > 1) {
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
std::vector<int64_t> size_splits_new_last(last_node_num_split, small_split_size);
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
// Create new output shape and new output type id for the last Splitv node
std::vector<TypeId> last_new_type_ids;
std::vector<std::vector<size_t>> last_new_output_shapes;
CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
&last_new_output_shapes);
AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
size_splits_base.emplace_back(last_node_num_split * small_split_size);
} else {
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
make_tuple_inputs.emplace_back(tuple_getitem);
size_splits_base.emplace_back(small_split_size);
}
nodes_num++;
}
// Set Attr and abstract for the base splitv
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}
} // namespace
const BaseRef SplitVFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto split_prim = std::make_shared<Primitive>(kSplitVOpName);
@@ -198,6 +41,6 @@ const AnfNodePtr SplitVFission::Process(const FuncGraphPtr &func_graph, const An
if (num_split <= outputs_divisor_) {
return nullptr;
}
return DoFission(func_graph, cnode, num_split, outputs_divisor_);
return DoFission(func_graph, cnode, num_split, outputs_divisor_, AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrSplitDim));
}
} // namespace mindspore::opt
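
The 157 removed lines were a near-verbatim copy of split_fission.cc; after the change, SplitVFission reuses the shared DoFission, and the only op-specific bit left at the call site is which attribute names the axis (kAttrSplitDim for SplitV, kAttrAxis for Split). A toy sketch of that remaining difference, with a std::map standing in for a node's attributes:

#include <cassert>
#include <cstdint>
#include <map>
#include <string>

using AttrMap = std::map<std::string, int64_t>;

// One shared body; each pass only names the attribute carrying the axis.
int64_t GetSplitDim(const AttrMap &attrs, const std::string &axis_attr) {
  return attrs.at(axis_attr);
}

int main() {
  const AttrMap split_attrs{{"axis", 1}};
  const AttrMap splitv_attrs{{"split_dim", 2}};
  assert(GetSplitDim(split_attrs, "axis") == 1);        // Split
  assert(GetSplitDim(splitv_attrs, "split_dim") == 2);  // SplitV
  return 0;
}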

View File

@@ -17,21 +17,17 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_SPLITV_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/split_fission.h"
namespace mindspore {
namespace opt {
class SplitVFission : public PatternProcessPass {
const int kSplitOutputsDivisor = 63;
constexpr int64_t kSplitVOutputsDivisor = 63;
class SplitVFission : public SplitFission {
public:
explicit SplitVFission(bool multigraph = true)
: PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
explicit SplitVFission(bool multigraph = true) : SplitFission("splitv_fission", multigraph, kSplitVOutputsDivisor) {}
~SplitVFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
int64_t outputs_divisor_;
};
} // namespace opt
} // namespace mindspore

View File

@@ -228,7 +228,11 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto index = GetValue<int64_t>(value_node->value());
auto value_index = GetValue<int64_t>(value_node->value());
if (value_index < 0) {
MS_LOG(EXCEPTION) << "Error value index: " << value_index;
}
auto index = LongToSize(value_index);
if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) {
(void)manager->Replace(output, bn_training_update_outputs[index]);
}
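
Last file, same theme as the concat and splitv hunks: GetValue<int64_t> yields a signed index, so the code now rejects negatives explicitly before LongToSize narrows it for use as a container index. A minimal sketch of that checked narrowing:

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Reject negatives before narrowing so a bad index fails loudly instead
// of wrapping around to a huge size_t (sketch of the LongToSize guard).
size_t CheckedIndex(int64_t value_index) {
  if (value_index < 0) {
    throw std::out_of_range("Error value index");
  }
  return static_cast<size_t>(value_index);
}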