[MSLITE] CodeCheck: converter

This commit is contained in:
wang_shaocong 2021-10-19 16:33:42 +08:00
parent 12e71fa9d2
commit f4dc2dcd5b
49 changed files with 364 additions and 260 deletions

View File

@ -58,11 +58,19 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap
MS_LOG(ERROR) << "string tensor's dim size not found.";
return RET_ERROR;
}
size_t shape_size = std::stoi(shape_size_str);
size_t shape_size = std::atoi(shape_size_str.c_str());
MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
for (; *offset < tensor_info->Size(); (*offset)++) {
if (tensor_data[*offset] == ',') {
cnt++;
shape_vector->push_back(std::stoi(shape_str));
int64_t shape = 0;
try {
shape = std::stoi(shape_str);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get shape failed: " << e.what();
return RET_ERROR;
}
shape_vector->push_back(shape);
shape_str.clear();
} else {
shape_str.push_back(tensor_data[*offset]);

View File

@ -197,11 +197,18 @@ int Flags::InitInTensorShape() {
return lite::RET_ERROR;
}
for (const auto &dim : dims) {
if (std::stoi(dim) < 0) {
auto dim_value = -1;
try {
dim_value = std::stoi(dim);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get dim failed: " << e.what();
return lite::RET_ERROR;
}
if (dim_value < 0) {
MS_LOG(ERROR) << "Unsupported dim < 0.";
return lite::RET_ERROR;
} else {
shape.push_back(std::stoi(dim));
shape.push_back(dim_value);
}
}
lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
@ -419,7 +426,12 @@ bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *pa
const char *colon = ":";
for (const auto &device : device_rates) {
std::vector<std::string> rate = lite::SplitStringToVector(device, *colon);
parallel_split_config->parallel_compute_rates_.push_back(std::stoi(rate.back()));
auto compute_rate = std::atoi(rate.back().c_str());
if (compute_rate == 0) {
MS_LOG(ERROR) << "The compute rate is invalid.";
return false;
}
parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
}
if (parallel_split_config->parallel_compute_rates_.size() != 2) {
return false;

View File

@ -67,7 +67,7 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
auto while_prim = primitive->cast<PrimWhilePtr>();
MS_CHECK_TRUE_RET(while_prim != nullptr, nullptr);
AbstractBasePtrList output;
for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
for (size_t i = 0; i < input_args.size(); i++) {
auto build_shape_ptr = input_args[i]->BuildShape();
MS_CHECK_TRUE_RET(build_shape_ptr != nullptr, nullptr);
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(build_shape_ptr)[kShape];

View File

@ -116,7 +116,7 @@ int CenterCrop(cv::Mat *image, int width, int height) {
}
int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, const std::string &input_name,
int image_index, mindspore::tensor::MSTensor *tensor) {
size_t image_index, mindspore::tensor::MSTensor *tensor) {
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is nullptr.";
return RET_NULL_PTR;
@ -155,7 +155,7 @@ int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, co
return RET_OK;
}
int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, int image_index,
int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, size_t image_index,
void **data, size_t *size) {
if (data == nullptr || size == nullptr) {
MS_LOG(ERROR) << "data or size is nullptr.";

View File

@ -35,11 +35,11 @@ int Resize(cv::Mat *image, int width, int height, cv::InterpolationFlags resize_
int CenterCrop(cv::Mat *image, int width, int height);
// NOTE:`data` must be use delete[] to free buffer.
int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, int image_index,
int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, size_t image_index,
void **data, size_t *size);
int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, const std::string &input_name,
int image_index, mindspore::tensor::MSTensor *tensor);
size_t image_index, mindspore::tensor::MSTensor *tensor);
int ImagePreProcess(const ImagePreProcessParam &image_preprocess_param, cv::Mat *image, void **data, size_t *size);

View File

@ -64,7 +64,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
Spliter::GetInstance()->graph_node_outputs();
auto finder = graph_node_outputs.find(pre_cnode->fullname_with_scope());
if (finder == graph_node_outputs.end()) {
return RET_OK;
return RET_ERROR;
}
if (finder->second.size() > 1) {
return RET_OK;
@ -89,7 +89,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
// get inputs node
auto it = graph_node_outputs.find(cnode->fullname_with_scope());
if (it == graph_node_outputs.end()) {
return RET_OK;
return RET_ERROR;
}
int out_num = it->second.size();
if (out_num != prim->get_number_split()) {
@ -101,14 +101,14 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
auto tmp = it->second[i];
auto tmp_cnode = tmp->cast<CNodePtr>();
if (tmp_cnode == nullptr) {
return RET_OK;
return RET_ERROR;
}
if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) {
return RET_OK;
}
auto tmp_it = graph_node_outputs.find(tmp_cnode->fullname_with_scope());
if (tmp_it == graph_node_outputs.end()) {
return RET_OK;
return RET_ERROR;
}
if (tmp_it->second.size() != 1) {
return RET_OK;
@ -116,7 +116,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
auto next = tmp_it->second[0];
auto next_cnode = next->cast<CNodePtr>();
MS_ASSERT(next_cnode != nullptr);
inputs_node.push_back(next_cnode);
}
// replace inputs

View File

@ -48,14 +48,21 @@ std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &o
int64_t output_w = static_cast<int64_t>(
std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
MS_LOG(ERROR) << "int mul overflow";
return {};
}
std::vector<int64_t> new_pad_list;
int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0;
int64_t pad_h_all = (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) +
(ori_conv_prim->get_kernel_size().at(kIndexH) - 1) * ori_conv_prim->get_dilation().at(kIndexH) +
1 - input_h;
int64_t pad_w_all = (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) +
(ori_conv_prim->get_kernel_size().at(kIndexW) - 1) * ori_conv_prim->get_dilation().at(kIndexW) +
1 - input_w;
int64_t pad_h_all =
(output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
int64_t pad_w_all =
(output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
// only check pad_up and pad_down is positive
// if compute overflowed, we will get abnormal it in infer_shape
if (pad_h_all >= 0) {
@ -110,9 +117,11 @@ bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_sh
const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t index_node,
std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
MS_ASSERT(split_info != nullptr && split_axis_inputs_shape != nullptr && split_axis_reduce_inputs_shape != nullptr);
MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
split_axis_reduce_inputs_shape != nullptr);
MS_ASSERT(node_in_out_shapes.size() > index_node);
auto in_out_shape = node_in_out_shapes.at(index_node);
MS_ASSERT(!in_out_shape.empty());
auto in_shape = in_out_shape.front();
if (in_shape.size() < kAxisW) {
MS_LOG(DEBUG) << "out of in_shape range";
@ -129,29 +138,34 @@ bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_sh
// iter splited_num
for (int64_t index = 0; index < split_num; index++) {
// shape
auto stride_h = ori_conv_prim->get_stride()[kIndexH];
auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
MS_LOG(ERROR) << "int mul overflow";
return false;
}
if (split_info->axis == CuttingStragedy::CUT_H) { // H
if (index == 0) {
tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) -
ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
tmp =
stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
} else if (index == split_num - 1) {
tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) -
ori_conv_prim->get_pad_list()[kPadDown] + ori_conv_prim->get_kernel_size()[kIndexH];
} else {
tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) - 0 +
tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
ori_conv_prim->get_kernel_size()[kIndexH];
} else {
tmp = stride_h * split_axis_dim - 0 + ori_conv_prim->get_kernel_size()[kIndexH];
}
}
split_axis_shape.push_back(tmp);
// reduce shape
auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
if (split_info->axis == CuttingStragedy::CUT_H) { // H
if (index == split_num - 1) {
tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[index_node][index] - 1) -
ori_conv_prim->get_pad_list()[kPadDown] - ori_conv_prim->get_pad_list()[kPadUp] +
ori_conv_prim->get_kernel_size()[kIndexH];
} else {
tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[index_node][index] - 1) -
tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
} else {
tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
ori_conv_prim->get_kernel_size()[kIndexH];
}
}
split_axis_reduce_shape.push_back(tmp);
@ -186,9 +200,8 @@ std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2
new_prim->set_stride(ori_conv_prim->get_stride());
new_prim->set_activation_type(ori_conv_prim->get_activation_type());
new_prim->set_pad_list(ori_conv_prim->get_pad_list());
if (ori_conv_prim->GetAttr(ops::kIsDepthWise) != nullptr) {
auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
MS_CHECK_TRUE_MSG(is_depth_value != nullptr, nullptr, "conv has no kIsDepthWise attribute");
auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
if (is_depth_value != nullptr) {
bool is_depth_wise = GetValue<bool>(is_depth_value);
new_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(is_depth_wise));
}
@ -218,6 +231,7 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
// 0-> in-shape 1->out-shape
// only one in and one output
MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
node_in_out_shapes.push_back({input_shapes.front(), output_shapes.front()});
index_node++;
}
@ -244,7 +258,9 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
// iter node
while (index_node < node_size) {
auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
MS_ASSERT(conv_cnode != nullptr);
auto ori_conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
&split_axis_reduce_inputs_shape)) {
MS_LOG(ERROR) << "CalSplitInShape failed";
@ -405,6 +421,10 @@ bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_siz
int visited_block = 0;
for (size_t i = 0; i < split_size - 1; i++) {
visited_block += ratio[i];
if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
MS_LOG(ERROR) << "int mul overflow";
return false;
}
int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
new_ratio[i + 1] = cur_border;
}

View File

@ -30,6 +30,7 @@ AnfNodePtr IterNodeOutputs::Run(const FuncGraphPtr &func_graph, const AnfNodePtr
auto inputs = cnode->inputs();
for (const auto &input_node : inputs) {
MS_ASSERT(input_node != nullptr);
if (!utils::isa<CNodePtr>(input_node)) {
continue;
}

View File

@ -29,6 +29,7 @@ using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
namespace mindspore {
namespace opt {
std::string MultiConvSplitPass::IsMultiParallelConvNode(const AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
for (const auto &parallel_prim : kParallelOpNames) {
if (CheckPrimitiveType(node, parallel_prim.first.first)) {
return parallel_prim.second;

View File

@ -32,6 +32,7 @@ AnfNodePtr NodeOutShapes::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &
auto cnode = node->cast<CNodePtr>();
// assume multi inputs
for (const auto &input_node : cnode->inputs()) {
MS_ASSERT(input_node != nullptr);
if (utils::isa<CNodePtr>(input_node) || utils::isa<ParameterPtr>(input_node)) {
auto in_shape = input_node->Shape();
if (in_shape == nullptr) {

View File

@ -56,10 +56,10 @@ bool FuseBias(const lite::DataInfo &add_bias, const lite::DataInfo &conv_bias, s
add_bias.data_.size()) != EOK) {
return false;
}
fusion_bias->resize(out_channel, 0);
fusion_bias->resize(static_cast<size_t>(out_channel), 0);
if (!conv_bias.data_.empty()) {
if (conv_bias.data_type_ != TypeId::kNumberTypeFloat32 && conv_bias.data_type_ != TypeId::kNumberTypeFloat &&
conv_bias.data_.size() != out_channel * sizeof(float)) {
conv_bias.data_.size() != static_cast<size_t>(out_channel) * sizeof(float)) {
return false;
}
if (memcpy_s(fusion_bias->data(), fusion_bias->size() * sizeof(float), conv_bias.data_.data(),

View File

@ -110,8 +110,7 @@ void ReplaceParamsAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &conv_
(void)manager->Replace(pad_cnode, pad_cnode->input(1));
}
bool IsPrimitiveProper(const CNodePtr &conv_cnode, const CNodePtr &pad_cnode) {
MS_ASSERT(conv_cnode != nullptr);
bool IsPrimitiveProper(const CNodePtr &pad_cnode) {
MS_ASSERT(pad_cnode != nullptr);
if (!utils::isa<Parameter>(pad_cnode->input(kInputIndexTwo))) {
return false;
@ -240,7 +239,7 @@ AnfNodePtr ConvPadFusion::Process(const std::string &pattern_name, const FuncGra
return nullptr;
}
if (!IsPrimitiveProper(conv_cnode, pad_cnode)) {
if (!IsPrimitiveProper(pad_cnode)) {
MS_LOG(WARNING) << conv_cnode->fullname_with_scope() << " is not match with previous "
<< pad_cnode->fullname_with_scope() << " op. Fusion failed!";
return nullptr;

View File

@ -267,7 +267,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const Anf
return nullptr;
}
scale_primitive->set_activation_type(scale_act_type_);
scale_primitive->set_axis(-(bias_tensor_->shape_c().size()));
scale_primitive->set_axis(-(static_cast<int64_t>(bias_tensor_->shape_c().size())));
// create scale op
auto scale_node = func_graph->NewCNode(scale_primitive, {mul_input_anode_, mul_const_anode_, add_const_anode_});
return scale_node;

View File

@ -51,7 +51,8 @@ STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
}
axes->resize(1);
if (!axes_value->shape().empty()) {
axes->resize(axes_value->shape()[0]);
MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
axes->resize(static_cast<size_t>(axes_value->shape()[0]));
}
if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
return lite::RET_OK;
@ -285,7 +286,8 @@ int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema
int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
MS_ASSERT(in_shape_size.size() > 0);
auto new_axis_mask = primitive.value.AsStridedSlice()->new_axis_mask;
MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
auto add_dims = 0;
while (new_axis_mask != 0) {
new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
@ -404,7 +406,7 @@ std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph)
}
// Cal shape size infer function
auto shape_size_infer_func = shape_size_infer_iter->second;
auto shape_size = shape_size_infer_iter->second(in_shape_sizes, *prim_t);
auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
// Update node shape size map
node_shape_size[cnode->fullname_with_scope()] = shape_size;
}

View File

@ -766,8 +766,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
MS_CHECK_TRUE_RET(fw_cond_graph_pattern != nullptr, nullptr);
auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
MS_ASSERT(fw_cond != nullptr);
auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars,
fw_cond, kCondCNodesNum, kCondNodesNum);
auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_cond_graph_pattern, fw_cond_primitive_vars, fw_cond,
kCondCNodesNum, kCondNodesNum);
if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) {
return nullptr;
}
@ -778,8 +778,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
MS_CHECK_TRUE_RET(bw_cond_graph_pattern != nullptr, nullptr);
auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
MS_ASSERT(bw_cond != nullptr);
auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars,
bw_cond, kCondCNodesNum, kCondNodesNum);
auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_cond_graph_pattern, bw_cond_primitive_vars, bw_cond,
kCondCNodesNum, kCondNodesNum);
if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) {
return nullptr;
}
@ -790,8 +790,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
MS_CHECK_TRUE_RET(fw_body_graph_pattern != nullptr, nullptr);
auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
MS_ASSERT(fw_body != nullptr);
auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body,
fw_body, kBodyCNodesNum, kBodyNodesNum);
auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_body_graph_pattern, fw_primitive_vars_body, fw_body,
kBodyCNodesNum, kBodyNodesNum);
if (fw_body_equiv == nullptr || fw_body_equiv->empty()) {
return nullptr;
}
@ -802,8 +802,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
MS_CHECK_TRUE_RET(bw_body_graph_pattern != nullptr, nullptr);
auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
MS_ASSERT(bw_body != nullptr);
auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body,
bw_body, kBodyCNodesNum, kBodyNodesNum);
auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_body_graph_pattern, bw_primitive_vars_body, bw_body,
kBodyCNodesNum, kBodyNodesNum);
if (bw_body_equiv == nullptr || bw_body_equiv->empty()) {
return nullptr;
}

View File

@ -191,7 +191,7 @@ STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight,
return RET_ERROR;
}
const auto param_num = shape[0] * shape[1] * shape[kInputIndexTwo];
auto tensor_data = new (std::nothrow) float[param_num * sizeof(float)];
auto tensor_data = new (std::nothrow) float[static_cast<size_t>(param_num) * sizeof(float)];
std::vector<int> data_diff{0, 3, 2, 1};
if (tensor_data == nullptr) {
MS_LOG(DEBUG) << "new data failed";
@ -204,7 +204,8 @@ STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight,
}
}
}
auto tensor_info = lite::CreateTensorInfo(tensor_data, param_num * sizeof(float), shape, kNumberTypeFloat32);
auto tensor_info =
lite::CreateTensorInfo(tensor_data, static_cast<size_t>(param_num) * sizeof(float), shape, kNumberTypeFloat32);
delete[] tensor_data;
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "create tensor info failed.";
@ -359,11 +360,6 @@ CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
auto &vars = while_input_vars_;
auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
MS_ASSERT(limit1);
auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
MS_ASSERT(limit2);
auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
MS_ASSERT(weight);
auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);

View File

@ -329,7 +329,7 @@ bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph
auto manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return RET_ERROR;
return false;
}
auto while_node_users = manager->node_users()[while_cnode];
std::vector<size_t> valid_indexes{3, 4, 5};
@ -352,9 +352,9 @@ bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph
return true;
}
EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
const size_t cnode_num, const size_t all_node_num) {
EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
const size_t all_node_num) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(pattern != nullptr);
MS_ASSERT(primitive_vars != nullptr);
@ -370,9 +370,7 @@ EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, con
return MatchGraph(sub_graph, primitive_vars, pattern);
}
bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const CNodePtr &while_cnode, float *zoneout_cell,
float *zoneout_hidden) const {
bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(while_cnode != nullptr);
@ -465,8 +463,9 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &par
MS_LOG(ERROR) << "bias data shape error";
return RET_ERROR;
}
step = data_shapes[0][0];
data_size = 8 * step;
step = static_cast<int>(data_shapes[0][0]);
MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
data_size = C8NUM * step;
new_shape = std::vector<int64_t>({1, data_size});
} else {
@ -475,8 +474,10 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &par
return RET_ERROR;
}
new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
step = data_shapes[0][0] * data_shapes[0][1];
data_size = 4 * step;
MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
data_size = C4NUM * step;
}
auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
@ -528,12 +529,6 @@ CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, co
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
auto &vars = while_input_vars_;
auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
MS_ASSERT(limit1);
auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
MS_ASSERT(limit2);
auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
MS_ASSERT(i2i_weight);
auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
@ -764,8 +759,8 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1),
cond_cnodes_num_, cond_nodes_num_);
auto cond_equiv =
CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
if (cond_equiv == nullptr || cond_equiv->empty()) {
return nullptr;
}
@ -773,14 +768,14 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2),
body_cnodes_num_, body_nodes_num_);
auto body_equiv =
CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
if (body_equiv == nullptr || body_equiv->empty()) {
return nullptr;
}
float zoneout_cell = 0.0f;
float zoneout_hidden = 0.0f;
if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &zoneout_cell, &zoneout_hidden)) {
if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
return nullptr;
}
const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();

View File

@ -36,9 +36,8 @@ class TfliteLstmCellFusion : public PatternProcessPass {
static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
const AnfNodePtr &pattern);
static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
size_t cnode_num, size_t all_node_num);
static EquivPtr CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
const AnfNodePtr &anf_sub_graph, size_t cnode_num, size_t all_node_num);
static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, int output_num);
@ -68,8 +67,7 @@ class TfliteLstmCellFusion : public PatternProcessPass {
private:
CNodePtr GetWhileCnode(const AnfNodePtr &cnode) const;
bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode,
float *zoneout_cell, float *zoneout_hidden) const;
bool CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const;
static bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode);

View File

@ -138,9 +138,8 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
return cnode;
}
AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node) const {
MS_ASSERT(func_graph != nullptr && node != nullptr);
AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
auto trans_cnode_2 = node->cast<CNodePtr>();
if (IsMarkedTrainOp(trans_cnode_2)) {
return nullptr;
@ -181,7 +180,7 @@ AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const minds
return nullptr;
}
if (pattern_name == "TransTransPatternName") {
return TransTransFusion(func_graph, node);
return TransTransFusion(node);
}
if (node->cast<CNodePtr>() == nullptr) {
return nullptr;

View File

@ -38,7 +38,7 @@ class TransposeFusion : public MultiplePatternProcessPass {
VectorRef DefineActivationscalePattern() const;
VectorRef DefineTransTransPattern() const;
VectorRef DefineBiasAddPattern() const;
AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const;
AnfNodePtr TransTransFusion(const mindspore::AnfNodePtr &node) const;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
const EquivPtr &) const override;
};

View File

@ -24,10 +24,11 @@
#include "ops/tensor_array_read.h"
#include "ops/tensor_array_write.h"
#include "tools/converter/ops/ops_def.h"
#include "nnacl/op_base.h"
namespace mindspore::opt {
constexpr auto kDefaultIndex = 0;
constexpr auto kInputIndex = 1;
constexpr auto kInputNodeIndex = 1;
constexpr auto kDefaultNumTensors = 1;
constexpr auto kFlowInPlaceHolder = 1;
@ -45,6 +46,8 @@ static bool IsSupportedNode(const BaseRef &n) {
}
static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tensor_array_write_node) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(tensor_array_write_node != nullptr);
// set tensor_array_write_node as graph output to keep it
auto return_node = func_graph->get_return();
if (!CheckPrimitiveType(return_node, prim::kPrimReturn)) {
@ -56,7 +59,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
MS_LOG(ERROR) << "graph return node is not cnode";
return lite::RET_NULL_PTR;
}
auto output_node = return_node->input(kInputIndex);
auto output_node = return_node->input(kInputNodeIndex);
if (output_node == nullptr) {
MS_LOG(ERROR) << "graph output node is null";
return lite::RET_NULL_PTR;
@ -80,8 +83,9 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr";
return lite::RET_NULL_PTR;
}
auto make_tuple_cnode =
func_graph->NewCNode({NewValueNode(make_tuple_prim_ptr), output_node, tensor_array_write_node});
auto make_tuple_vnode = NewValueNode(make_tuple_prim_ptr);
MS_CHECK_TRUE_RET(make_tuple_vnode != nullptr, lite::RET_NULL_PTR);
auto make_tuple_cnode = func_graph->NewCNode({make_tuple_vnode, output_node, tensor_array_write_node});
if (make_tuple_cnode == nullptr) {
MS_LOG(ERROR) << "NewCNode failed";
return lite::RET_NULL_PTR;
@ -94,7 +98,10 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
MS_LOG(ERROR) << "return_prim_ptr is nullptr";
return lite::RET_NULL_PTR;
}
auto new_return_node = func_graph->NewCNode({NewValueNode(return_prim_ptr), make_tuple_cnode});
auto return_value_node = NewValueNode(return_prim_ptr);
MS_CHECK_TRUE_RET(return_value_node != nullptr, lite::RET_NULL_PTR);
auto new_return_node = func_graph->NewCNode({return_value_node, make_tuple_cnode});
MS_CHECK_TRUE_RET(new_return_node != nullptr, lite::RET_NULL_PTR);
new_return_node->set_fullname_with_scope(return_cnode->fullname_with_scope());
MS_ASSERT(new_return_node != nullptr);
func_graph->set_return(new_return_node);
@ -147,11 +154,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
return nullptr;
}
auto tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
MS_ASSERT(tensor_info != nullptr);
if (tensor_info->data_type() == kObjectTypeTensorType) {
MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
@ -160,35 +163,32 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
// tensor_array
auto tensor_array = std::make_shared<ops::TensorArray>();
if (tensor_array == nullptr) {
MS_LOG(ERROR) << "tensor_array is nullptr";
return nullptr;
}
MS_CHECK_TRUE_RET(tensor_array != nullptr, nullptr);
std::vector<int> element_shape;
std::for_each(tensor_info->shape().begin(), tensor_info->shape().end(),
[&element_shape](int64_t v) { element_shape.push_back(static_cast<int>(v)); });
tensor_array->set_element_shape(element_shape);
tensor_array->set_data_type(tensor_info->data_type());
auto tensor_array_node = func_graph->NewCNode({
NewValueNode(tensor_array),
NewValueNode(kDefaultNumTensors),
});
auto tensor_array_vnode = NewValueNode(tensor_array);
MS_CHECK_TRUE_RET(tensor_array_vnode != nullptr, nullptr);
auto num_tensors_vnode = NewValueNode(kDefaultNumTensors);
MS_CHECK_TRUE_RET(num_tensors_vnode != nullptr, nullptr);
auto tensor_array_node = func_graph->NewCNode({tensor_array_vnode, num_tensors_vnode});
MS_ASSERT(tensor_array_node != nullptr);
tensor_array_node->set_abstract(abstract->Clone());
tensor_array_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array");
// {"handle", "index", "flow_in"} -> {"tensor"}
auto tensor_array_read = std::make_shared<ops::TensorArrayRead>();
if (tensor_array_read == nullptr) {
MS_LOG(ERROR) << "tensor_array_read is nullptr";
return nullptr;
}
auto tensor_array_read_node = func_graph->NewCNode({
NewValueNode(tensor_array_read),
tensor_array_node,
NewValueNode(kDefaultIndex),
NewValueNode(kFlowInPlaceHolder),
});
MS_CHECK_TRUE_RET(tensor_array_read != nullptr, nullptr);
auto tensor_array_read_vnode = NewValueNode(tensor_array_read);
MS_CHECK_TRUE_RET(tensor_array_read_vnode != nullptr, nullptr);
auto read_index_vnode = NewValueNode(kDefaultIndex);
MS_CHECK_TRUE_RET(read_index_vnode != nullptr, nullptr);
auto read_flow_in_vnode = NewValueNode(kFlowInPlaceHolder);
MS_CHECK_TRUE_RET(read_flow_in_vnode != nullptr, nullptr);
auto tensor_array_read_node =
func_graph->NewCNode({tensor_array_read_vnode, tensor_array_node, read_index_vnode, read_flow_in_vnode});
MS_ASSERT(tensor_array_read_node != nullptr);
tensor_array_read_node->set_abstract(abstract->Clone());
tensor_array_read_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array_read");
@ -196,17 +196,15 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
// {"handle", "index", "value", "flow_in"} -> {"flow_out"}
auto tensor_array_write = std::make_shared<ops::TensorArrayWrite>();
if (tensor_array_write == nullptr) {
MS_LOG(ERROR) << "tensor_array_write is nullptr";
return nullptr;
}
auto tensor_array_write_node = func_graph->NewCNode({
NewValueNode(tensor_array_write),
tensor_array_node,
NewValueNode(kDefaultIndex),
cnode,
NewValueNode(kFlowInPlaceHolder),
});
MS_CHECK_TRUE_RET(tensor_array_write != nullptr, nullptr);
auto tensor_array_write_vnode = NewValueNode(tensor_array_write);
MS_CHECK_TRUE_RET(tensor_array_write_vnode != nullptr, nullptr);
auto write_index_vnode = NewValueNode(kDefaultIndex);
MS_CHECK_TRUE_RET(write_index_vnode != nullptr, nullptr);
auto write_flow_in_vnode = NewValueNode(kFlowInPlaceHolder);
MS_CHECK_TRUE_RET(write_flow_in_vnode != nullptr, nullptr);
auto tensor_array_write_node =
func_graph->NewCNode({tensor_array_write_vnode, tensor_array_node, write_index_vnode, cnode, write_flow_in_vnode});
if (tensor_array_write_node == nullptr) {
MS_LOG(ERROR) << "rensor_array_write_node is nullptr";
return nullptr;

View File

@ -35,6 +35,7 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
MS_ASSERT(node != nullptr);
if (!utils::isa<CNode>(node)) {
continue;
}
@ -57,10 +58,12 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
if ((min == -1) && (max == -1)) {
if (clip_cnode->size() > kClipMinIndex) {
auto min_tensor_info = GetTensorInfo(clip_cnode->input(kClipMinIndex));
MS_CHECK_TRUE_MSG(min_tensor_info != nullptr, false, "min_tensor_info is nullptr");
if (min_tensor_info->data_type() != mindspore::kNumberTypeFloat32) {
MS_LOG(ERROR) << "Clip param type invalid";
return false;
}
MS_CHECK_TRUE_MSG(min_tensor_info->data_c() != nullptr, false, "tensor data is nullptr");
min = *reinterpret_cast<float *>(min_tensor_info->data_c());
} else {
min = FLT_MIN;
@ -68,17 +71,19 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
if (clip_cnode->size() > kClipMaxIndex) {
auto max_tensor_info = GetTensorInfo(clip_cnode->input(kClipMaxIndex));
MS_CHECK_TRUE_MSG(max_tensor_info != nullptr, false, "max_tensor_info is nullptr");
if (max_tensor_info->data_type() != mindspore::kNumberTypeFloat32) {
MS_LOG(ERROR) << "Clip param type invalid";
return false;
}
MS_CHECK_TRUE_MSG(max_tensor_info->data_c() != nullptr, false, "tensor data is nullptr");
max = *reinterpret_cast<float *>(max_tensor_info->data_c());
} else {
max = FLT_MAX;
}
}
auto manager = graph->manager();
MS_ASSERT(manager != nullptr);
auto primitive_c = std::make_shared<mindspore::ops::Activation>();
MS_CHECK_TRUE_MSG(primitive_c != nullptr, false, "primitive_c is nullptr");
primitive_c->Init(0, min, max, mindspore::RELU6);

View File

@ -50,7 +50,6 @@ void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &v
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
std::deque<AnfNodePtr> nodes{};
std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
std::set<FuncGraphPtr> visited_fg_set{};
std::set<AnfNodePtr> remain_nodes_set{};
nodes.assign(remain_nodes.begin(), remain_nodes.end());
while (!nodes.empty()) {
@ -150,6 +149,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
// notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
auto node_list = TopoSort(fg->get_return());
for (auto &node : node_list) {
MS_ASSERT(node != nullptr);
if (utils::isa<CNodePtr>(node) &&
(CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
*control_flow_node = node;
@ -194,6 +194,7 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
*after_fg = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
auto manager = main_fg->manager();
MS_ASSERT(manager != nullptr);
manager->AddFuncGraph(*after_fg);
(*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
(*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
@ -244,7 +245,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
auto origin_cond_fg_inputs = cond_fg->get_inputs();
for (auto &item : visited_nodes_used_by_after_fg) {
bool found = false;
size_t input_index = -1;
size_t input_index = 0;
for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
if (cond_partial_cnode_inputs[i] == item) {
found = true;
@ -276,6 +277,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
// insert call node
std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
*cond_call_cnode = fg->NewCNode(call_node_inputs);
MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
(*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
return RET_SUCCESS;
@ -284,7 +286,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
CNodePtr *body_partial_node) {
auto body_vnode = while_cnode->input(kWhileBodyIndex);
MS_CHECK_TRUE_MSG(body_vnode != nullptr, lite::RET_NULL_PTR, "body_vnode is nullptr");
MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
if (body_fg == nullptr) {
MS_LOG(ERROR) << "Get value as func_graph failed.";
@ -302,6 +304,7 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
auto cond_fg_inputs = cond_fg->get_inputs();
body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
*body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
(*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
// add after inputs for body fg to call cond fg
@ -312,9 +315,7 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
MS_LOG(ERROR) << "fg is not right.";
return RET_FAILED;
}
auto cond_fg_input_para = cond_fg_inputs[i]->cast<ParameterPtr>();
auto new_parameter = body_fg->add_parameter();
MS_ASSERT(cond_fg_input_para != nullptr);
MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
@ -363,7 +364,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
}
auto after_value_node = NewValueNode(after_fg);
MS_CHECK_TRUE_MSG(after_value_node != nullptr, lite::RET_NULL_PTR, "after_value_node is nullptr");
MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
if (partial_anf_primitive == nullptr) {
MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
@ -397,7 +398,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
auto new_parameter = after_fg->add_parameter();
MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter != nullptr");
MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
new_parameter->set_abstract(node->abstract());
after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
@ -407,7 +408,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
for (auto &input : cond_nodes_used_by_after_partial) {
after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
auto new_parameter = after_fg->add_parameter();
MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter != nullptr");
MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
new_parameter->set_abstract(input->abstract());
visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
@ -417,6 +418,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
*after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
(*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
return RET_SUCCESS;
}
@ -448,7 +450,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
return ret;
}
AnfNodePtr cond_fg_vnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>()->input(kCNodeFirstInputIndex);
auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
MS_ASSERT(cond_fg_cnode != nullptr);
AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
MS_ASSERT(cond_fg_vnode != nullptr);
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
@ -492,8 +496,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
fg->DropNode(while_cnode);
fg->set_output(cond_call_cnode);
auto after_fg =
after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
MS_ASSERT(after_cnode != nullptr);
auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
if (after_fg == nullptr) {
MS_LOG(ERROR) << "after_fg is nullptr.";
return RET_FAILED;
@ -506,7 +511,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
auto if_inputs = if_cnode->inputs();
auto partial_fg_name = partial_fg->get_attr("graph_name")->ToString();
auto fg_name_attr = partial_fg->get_attr("graph_name");
MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
auto partial_fg_name = fg_name_attr->ToString();
std::vector<AnfNodePtr> if_external_inputs{};
if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
auto origin_then_fg_inputs = partial_fg->get_inputs();
@ -523,7 +530,13 @@ int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode,
auto pos = partial_fg_name.size() + sizeof("_input_");
auto pos2 = fg_input_name.find('_', pos);
auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
auto partial_idx = std::stoi(idx_str);
auto partial_idx = 0;
try {
partial_idx = std::stoi(idx_str);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get index failed: " << e.what();
return RET_FAILED;
}
then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
}
}
@ -552,7 +565,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
auto origin_then_fg_inputs = then_fg->get_inputs();
for (auto &item : *visited_nodes_used_by_after_fg) {
bool found = false;
size_t input_index = -1;
size_t input_index = 0;
for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
if (then_partial_cnode_inputs[i] == item) {
found = true;
@ -580,7 +593,10 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
then_nodes_used_by_after_partial.push_back(new_parameter);
}
*then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
auto then_fg_name = then_fg->get_attr("graph_name")->ToString();
MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
auto fg_name_attr = then_fg->get_attr("graph_name");
MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
auto then_fg_name = fg_name_attr->ToString();
(*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
// create after partial node
@ -764,6 +780,7 @@ int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
}
bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
MS_ASSERT(fg != nullptr);
to_process_q.push_back(fg);
while (!to_process_q.empty()) {
auto cur_fg = to_process_q.front();

View File

@ -79,7 +79,7 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(weight_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_FAILED, "GetValueNode failed");
MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode failed");
schema::Format schema_format = schema::Format::Format_KCHW;
if (prim->GetAttr(ops::kFormat) != nullptr) {
schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat)));
@ -91,6 +91,6 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
return false;
}
}
return RET_OK;
return true;
}
} // namespace mindspore::opt

View File

@ -233,8 +233,8 @@ STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> perm, bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
@ -388,6 +388,7 @@ STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph
return lite::RET_ERROR;
} else {
tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
post_node = tuple_get_item;
func_graph->manager()->Replace(cnode, tuple_get_item);
}
@ -489,7 +490,13 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
auto index = 0;
try {
index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get index failed: " << e.what();
return lite::RET_ERROR;
}
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};

View File

@ -41,7 +41,7 @@ class DecreaseTransposeAlgo : public Pass {
private:
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> perm, bool before,
size_t index = 0);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);

View File

@ -62,7 +62,9 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << "the node input is invalid.";
return false;
}
auto data_shape = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack())->shape();
auto data_shape_ptr = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack());
MS_ASSERT(data_shape_ptr != nullptr);
auto data_shape = data_shape_ptr->shape();
if (data_shape.empty()) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic.";
return true;
@ -72,11 +74,15 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << "the weight node input is invalid.";
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack())->shape();
auto weight_shape_ptr = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack());
MS_ASSERT(weight_shape_ptr != nullptr);
auto weight_shape = weight_shape_ptr->shape();
if (weight_shape.empty()) {
MS_LOG(DEBUG) << "the weight's shape is dynamic.";
return true;
}
MS_CHECK_TRUE_RET(data_shape.size() == DIMENSION_4D, false);
MS_CHECK_TRUE_RET(weight_shape.size() == DIMENSION_4D, false);
if (data_shape[3] == 1 || data_shape[3] != weight_shape[3]) {
conv2d_fusion->EraseAttr(ops::kIsDepthWise);
conv2d_fusion->set_group(static_cast<int64_t>(data_shape[3]));

View File

@ -50,7 +50,9 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_ERROR;
}
*format = static_cast<mindspore::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat)));
auto format_attr = primitive->GetAttr(ops::kFormat);
MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");
*format = static_cast<mindspore::Format>(GetValue<int64_t>(format_attr));
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
std::vector<int> perm;
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {
@ -184,7 +186,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
auto ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
}
@ -202,7 +204,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
}
@ -215,7 +217,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
ret = SetSubGraphAbstract(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
return lite::RET_ERROR;
}
@ -241,13 +243,20 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
auto index = 0;
try {
index = std::stoi(node_name.substr(last_underline + 1)) + 3;
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get index failed: " << e.what();
return RET_ERROR;
}
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
CHECK_NULL_RETURN(abstract_shape);

View File

@ -34,7 +34,7 @@ namespace opt {
namespace {
constexpr int kInputChannal = 3;
constexpr size_t INITIAL_SIZE = 1024;
void RectifyFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
MS_ASSERT(cnode != nullptr);
if (fmk_type != converter::kFmkTypeOnnx) {
return;
@ -127,7 +127,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
fbb.Clear();
return lite::RET_ERROR;
}
RectifyFormat(cnode, inputs, fmk_type_);
RectifyFormat(inputs, fmk_type_);
ret = KernelInferShape(inputs, outputs, parameter);
if (parameter->destroy_func_ != nullptr) {
parameter->destroy_func_(parameter);
@ -208,7 +208,8 @@ std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t in
if (specify_tensors.front()->shape().size() != 1) {
return {};
}
tensor_data.resize(specify_tensors.front()->shape()[0]);
MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
specify_tensors.front()->Size()) != EOK) {
return {};

View File

@ -264,6 +264,27 @@ int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const Fun
return lite::RET_OK;
}
int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(data_info != nullptr);
auto padding_node = cnode->input(kInputIndexTwo);
MS_ASSERT(padding_node != nullptr);
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is node a cnode.";
@ -278,20 +299,10 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
}
auto is_invalid = true;
if (cnode->size() > kInputSizeTwo) {
auto padding_node = cnode->input(kInputIndexTwo);
lite::DataInfo data_info;
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;
}
if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
MS_LOG(ERROR) << "Get pad data failed.";
return lite::RET_ERROR;
}
if (!data_info.data_.empty()) {
auto pad_data = reinterpret_cast<int *>(data_info.data_.data());
@ -308,19 +319,18 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
} else {
auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
MS_ASSERT(pad_prim != nullptr);
if (pad_prim->GetAttr(ops::kPadding) != nullptr) {
auto pad_data = pad_prim->get_paddings();
for (size_t i = 0; i < pad_data.size(); i++) {
for (size_t j = 0; j < pad_data[i].size(); j++) {
if (pad_data[i][j] != 0) {
is_invalid = false;
break;
}
}
if (is_invalid == false) {
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
auto pad_data = pad_prim->get_paddings();
for (size_t i = 0; i < pad_data.size(); i++) {
for (size_t j = 0; j < pad_data[i].size(); j++) {
if (pad_data[i][j] != 0) {
is_invalid = false;
break;
}
}
if (is_invalid == false) {
break;
}
}
}
if (is_invalid) {

View File

@ -21,6 +21,7 @@
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/fetch_content.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {
@ -38,6 +39,7 @@ class RemoveRedundantOpPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
int GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info);
bool is_train_model_ = false;
std::set<AnfNodePtr> remove_cnode_;
};

View File

@ -222,7 +222,7 @@ STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const C
return RET_OK;
}
ValueNodePtr SlicePreposePass::CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int64_t> &axes) {
ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);
auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
@ -233,7 +233,7 @@ ValueNodePtr SlicePreposePass::CreateSliceValueNode(const FuncGraphPtr &graph, c
return value_node;
}
ValueNodePtr SlicePreposePass::CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode) {
ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);
auto slice_c = GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(slice_cnode->input(0));
@ -400,7 +400,7 @@ CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const s
return reshape_cnode;
}
bool SlicePreposePass::SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
const std::vector<int64_t> &ref_shape) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(output_node_list != nullptr);
@ -512,15 +512,15 @@ int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode
MS_ASSERT(shape_out_copy != nullptr);
MS_ASSERT(is_normal_mode != nullptr);
MS_ASSERT(support_abnormal_mode != nullptr);
int64_t abnormal_index_out = -1;
auto slice_node = GetSlice(slice_cnode);
if (slice_node == nullptr) {
MS_LOG(ERROR) << "slice is nullptr";
return false;
return abnormal_index_out;
}
auto slice_axes = slice_node->get_axes();
auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
int64_t abnormal_index_out = -1;
for (size_t j = 0; j < shape_out.size(); ++j) {
int index = -1;
for (size_t i = 0; i < slice_axes.size(); ++i) {
@ -532,8 +532,8 @@ int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode
if (index == -1) {
continue;
}
MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, false, "slice_begin.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, false, "slice_size.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, abnormal_index_out, "slice_begin.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, abnormal_index_out, "slice_size.size() is wrong");
if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
if (mapped_axe[j] == -1) {
if (is_normal_mode) {
@ -634,7 +634,7 @@ CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &gra
} else {
new_size1[abnormal_axe_in] = static_cast<int>(shape_in[abnormal_axe_in] - count_sliced_axe_in);
}
auto new_slice1 = CreateSliceValueNode(graph, new_axes1);
auto new_slice1 = CreateSliceValueNode(new_axes1);
if (new_slice1 == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return nullptr;
@ -660,8 +660,7 @@ CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &gra
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &new_reshape1_cnode,
const std::vector<int64_t> &new_shape1,
const int64_t abnormal_axe_in,
const int64_t count_sliced_axe_in, const int64_t count_sliced2,
const int64_t abnormal_axe_in, const int64_t count_sliced2,
const bool slice_at_front) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);
@ -679,7 +678,7 @@ CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &gra
} else {
new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
}
auto new_slice2 = CreateSliceValueNode(graph, new_axes2);
auto new_slice2 = CreateSliceValueNode(new_axes2);
if (new_slice2 == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return nullptr;
@ -782,9 +781,8 @@ bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, con
const int64_t count_sliced_abnormal_axe =
shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
auto new_slice2_cnode =
CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1, abnormal_axe_in,
count_sliced_axe_in, count_sliced2, slice_at_front);
auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
abnormal_axe_in, count_sliced2, slice_at_front);
if (new_slice2_cnode == nullptr) {
return false;
}
@ -977,7 +975,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
return false;
}
auto slice_node = GetSlice(slice_cnode);
MS_CHECK_TRUE_MSG(slice_node != nullptr, RET_ERROR, "slice is nullptr");
MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
auto axes = slice_node->get_axes();
auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
@ -1012,7 +1010,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
left_size[i] = -1;
}
}
auto left_slice_vnode = CreateSliceValueNode(graph, left_axes);
auto left_slice_vnode = CreateSliceValueNode(left_axes);
MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
auto begin_parameter = BuildIntVecParameterNode(
graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
@ -1050,7 +1048,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
node_name_index += 1;
auto right_slice_vnode = CreateSliceValueNode(graph, right_axes);
auto right_slice_vnode = CreateSliceValueNode(right_axes);
MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);
@ -1094,7 +1092,7 @@ bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, cons
auto slice_node = GetSlice(slice_cnode);
if (slice_node == nullptr) {
MS_LOG(ERROR) << "slice is nullptr";
return RET_ERROR;
return false;
}
auto axes = slice_node->get_axes();
auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
@ -1134,7 +1132,7 @@ bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, cons
std::vector<int> new_size(shape_in.size(), -1);
new_begin[mapped_axe[0]] = begin[0];
new_size[mapped_axe[0]] = size[0];
auto new_slice_vnode = CreateSliceValueNode(graph, new_axes);
auto new_slice_vnode = CreateSliceValueNode(new_axes);
if (new_slice_vnode == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return false;
@ -1267,7 +1265,7 @@ bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CN
continue;
} else if (shape.empty()) { // infershape failed at this input
if (IsScalarNode(another_input)) { // if another input is scalar, we can process this one
auto new_slice_vnode = CopySliceValueNode(graph, slice_cnode);
auto new_slice_vnode = CopySliceValueNode(slice_cnode);
if (new_slice_vnode == nullptr) {
changed = false;
break;
@ -1299,7 +1297,7 @@ bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CN
changed = false;
break;
}
auto new_slice_vnode = CreateSliceValueNode(graph, new_axes);
auto new_slice_vnode = CreateSliceValueNode(new_axes);
if (new_slice_vnode == nullptr) {
changed = false;
break;
@ -1525,7 +1523,7 @@ bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
}
auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
if (output_node_list->size() > 1) { // referenced by multi nodes
if (SiblingsAreSameSlice(graph, output_node_list) && MergeParallelSlice(graph, output_node_list)) {
if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
this_time_changed = true;
break;
}

View File

@ -43,8 +43,8 @@ class SlicePreposePass : public Pass {
static void ClearCNodeAbstractValue(const CNodePtr &cnode);
static STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr = nullptr);
static ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int64_t> &axes);
static ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode);
static ValueNodePtr CreateSliceValueNode(const std::vector<int64_t> &axes);
static ValueNodePtr CopySliceValueNode(const CNodePtr &slice_cnode);
static CNodePtr InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr);
static STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, int dim = -1);
@ -52,8 +52,7 @@ class SlicePreposePass : public Pass {
std::vector<int64_t> *axes, std::vector<int> *begin, std::vector<int> *size);
static CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape,
const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode);
static bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
const std::vector<int64_t> &ref_shape = {});
static bool SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list, const std::vector<int64_t> &ref_shape = {});
static int64_t GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in, const std::vector<int64_t> &shape_out,
std::vector<int64_t> *mapped_axe);
static int64_t GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int64_t> &mapped_axe,
@ -70,8 +69,7 @@ class SlicePreposePass : public Pass {
static CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &new_reshape1_cnode,
const std::vector<int64_t> &new_shape1, int64_t abnormal_axe_in,
int64_t count_sliced_axe_in, int64_t count_sliced2,
bool slice_at_front);
int64_t count_sliced2, bool slice_at_front);
static bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &reshape_cnode, const CNodePtr &matmul_cnode,
const std::vector<int64_t> &shape_in, const std::vector<int64_t> &shape_out,

View File

@ -136,6 +136,7 @@ STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, For
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
if (axis < 0) {
axis += kInputSizeFour;
@ -160,17 +161,21 @@ STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Forma
MS_LOG(ERROR) << "cnode is invalid.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
auto axis = crop_prim->get_axis();
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
auto offsets = crop_prim->get_offsets();
if (trans_type == kNCHW2NHWC) {
auto new_axis = kNH2NC[axis];
if (new_axis == 0) {
MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
} else if (new_axis == kInputIndexThree) {
MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
} else {
offsets.push_back(0);
@ -238,7 +243,10 @@ STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Format
}
auto param_node =
BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(kInputIndexTwo), param_node);
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(kInputIndexTwo), param_node);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
if (prim->GetAttr(ops::kPaddings) != nullptr) {
@ -280,7 +288,10 @@ STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Form
[](int64_t v) { return static_cast<int>(v); });
}
for (size_t i = 2; i < cnode->size(); ++i) {
TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape);
if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
MS_LOG(ERROR) << "Transform axes failed.";
return RET_ERROR;
}
}
auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
@ -312,13 +323,18 @@ STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode
if (index == kInputIndexFour) {
continue;
}
TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape);
if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
MS_LOG(ERROR) << "transform axes failed.";
return lite::RET_ERROR;
}
}
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
auto param_node =
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(kInputIndexFour), param_node);
return lite::RET_OK;
}
} // namespace
@ -481,6 +497,7 @@ STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_
return lite::RET_ERROR;
}
CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
MS_ASSERT(base_node != nullptr);
size_t input_index = before ? index : node_users.front().second;
auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
if (!shape.empty() && shape.size() != kNH2NC.size()) {
@ -522,7 +539,7 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
return true;
}
void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) {
void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const {
if (trans_info->pre_ == trans_info->post_) {
return;
}

View File

@ -50,7 +50,7 @@ class TransposeStrategy {
STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
FormatTransNodeType *trans_type);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const;
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;

View File

@ -24,7 +24,7 @@ void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::kFmkTypeMs) {
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
return RET_ERROR;
return false;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
@ -42,12 +42,12 @@ bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
auto abstract_base = cast_cnode->input(1)->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
return false;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, "
<< cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
return false;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto input_type = abstract_tensor->element()->GetTypeTrack();
@ -56,12 +56,12 @@ bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
if (cast_cnode->inputs().size() != kCastInputNum || !utils::isa<ValueNodePtr>(cast_cnode->input(2))) {
MS_LOG(ERROR) << "Second input of cast should be a ValueNode";
return RET_ERROR;
return false;
}
auto output_type = GetValueNode<NumberPtr>(cast_cnode->input(2));
if (output_type == nullptr) {
MS_LOG(ERROR) << "Second input of cast is nullptr";
return RET_ERROR;
return false;
}
auto output_type_value = output_type->type_id();
if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) ||

View File

@ -19,6 +19,7 @@
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
namespace mindspore::opt {
constexpr size_t kTransposeInput = 1;
@ -60,7 +61,7 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::kFmkTypeOnnx) {
MS_LOG(ERROR) << "The framework type of model should be onnx.";
return RET_ERROR;
return false;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();

View File

@ -67,6 +67,7 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
return lite::RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "current conv2d's format is undefined.";
return lite::RET_ERROR;
@ -85,6 +86,7 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
prim->AddAttr(ops::kGroup, MakeValue(is_depth_wise ? out_channel : 1));
}
MS_ASSERT(prim->GetAttr(ops::kGroup) != nullptr);
auto group = GetValue<int64_t>(prim->GetAttr(ops::kGroup));
if (CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
std::swap(in_channel, out_channel);
@ -101,8 +103,6 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {

View File

@ -92,6 +92,7 @@ int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {
int Conv2DInfo::CheckIfSplit() {
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
MS_ASSERT(conv_prim != nullptr);
auto strides = conv_prim->get_stride();
std::vector<int64_t> weight_shape;
std::vector<int64_t> input_shape;
@ -99,7 +100,9 @@ int Conv2DInfo::CheckIfSplit() {
// for n, h, cin, we should checkout it's input whether bigger than split total ratio
if (split_mode_ != SplitCOUT) {
auto input_node_abstract = GetCNodeInputAbstract(cnode_, 1);
MS_CHECK_TRUE_RET(input_node_abstract != nullptr, RET_ERROR);
auto weight_node_abstract = GetCNodeInputAbstract(cnode_, 2);
MS_CHECK_TRUE_RET(weight_node_abstract != nullptr, RET_ERROR);
if (!utils::isa<abstract::AbstractTensorPtr>(input_node_abstract)) {
MS_LOG(ERROR) << "conv_input_abstract of should be abstract tensor";
return RET_ERROR;
@ -163,8 +166,11 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
MS_CHECK_GE(conv_prim->get_stride().size(), 1, nullptr);
auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH);
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
MS_CHECK_GE(split_num, 1, nullptr);
bottom_vector[split_num - 1] = 0;
split_prim->set_extend_bottom(bottom_vector);
MS_CHECK_GE(conv_prim->get_pad_list().size(), 1, nullptr);
MS_CHECK_TRUE_RET(input_shape.size() == DIMENSION_4D, nullptr);
if (!UpdateRatioWithPadStride(new_splits.data(), new_splits.size(), split_num, input_shape[split_dim],
conv_prim->get_pad_list().at(kPadUp), conv_prim->get_stride().at(kIndexH))) {
MS_LOG(ERROR) << "UpdateRatioWithPadStride failed";
@ -178,7 +184,9 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
split_prim->set_number_split(split_num);
split_prim->set_ratio(new_splits);
std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
auto split_primitive = NewValueNode(split_prim);
MS_CHECK_TRUE_MSG(split_primitive != nullptr, nullptr, "create SplitWithOverlap return nullptr");
std::vector<AnfNodePtr> split_inputs = {split_primitive};
// ori_conv_node must only have one input
split_inputs.push_back(orig_node->input(input_index + 1));
auto split_cnode = func_graph_->NewCNode(split_inputs);
@ -225,7 +233,6 @@ int Conv2DInfo::InferParallelCNodes() {
if (CheckIfSplit() != RET_OK) {
return RET_ERROR;
}
Strategys strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num;
std::vector<AnfNodePtr> feature_split_outputs;
std::vector<AnfNodePtr> kernel_split_outputs;
@ -361,7 +368,7 @@ int Conv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
if (split_mode_ == SplitCIN) {
MS_LOG(DEBUG) << name_ << " : Split Cin, infer Forward op.";
replace_op_ = CreateReduceNode(cnode_, parallel_output_nodes_, kAxisCIn, dev_num, true);
replace_op_ = CreateReduceNode(cnode_, parallel_output_nodes_, kAxisCIn, dev_num);
} else {
int32_t concat_dim;
if (split_mode_ == SplitN) {
@ -372,7 +379,7 @@ int Conv2DInfo::InferReplaceOp() {
} else {
concat_dim = kAxisH;
}
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num, true);
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num);
}
if (replace_op_ == nullptr) {

View File

@ -67,6 +67,11 @@ void CreateSplitConstantTensors(const tensor::TensorPtr &constant_tensor, const
for (int64_t i = 0; i < split_num; i++) {
// init shape for [split_dim]
visited_block += splits[i];
if (total_block_count == 0) {
MS_LOG(ERROR) << "divisor is zero";
split_constant_tensors->clear();
return;
}
auto cur_shape = UP_DIV(split_dim_size * visited_block, total_block_count);
split_constant_shapes.at(i).at(split_dim) = cur_shape;
auto tensor = std::make_shared<tensor::Tensor>(weight_type_id, split_constant_shapes.at(i));
@ -491,7 +496,7 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {
int DepthwiseConv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, split_dim_, dev_num, true);
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, split_dim_, dev_num);
if (replace_op_ == nullptr) {
return RET_ERROR;
}

View File

@ -80,10 +80,8 @@ bool MultiConvSplit::CheckSplitValid() {
if (i >= static_cast<int64_t>(final_ratios.size()) || final_ratios.at(i) <= 0) {
return false;
}
MS_CHECK_INT_ADD_NOT_OVERFLOW(total_block_count, final_ratios.at(i), false);
total_block_count += final_ratios.at(i);
if (i == 0) {
MS_CHECK_INT_ADD_NOT_OVERFLOW(visited_block, final_ratios.at(i), false);
visited_block += final_ratios.at(i);
}
}
@ -122,7 +120,8 @@ int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) {
while (index < split_info_.in_num_conv - 1) {
MS_CHECK_LT(index, static_cast<int32_t>(conv_nodes_.size()), RET_ERROR);
auto curr_node = conv_nodes_[index];
auto curr_cnode = conv_nodes_[index]->cast<CNodePtr>();
MS_ASSERT(curr_node != nullptr);
auto curr_cnode = curr_node->cast<CNodePtr>();
MS_CHECK_TRUE_RET(curr_cnode != nullptr, RET_ERROR);
auto tmp_node = curr_cnode->input(1);
if (!IsConv2D(tmp_node)) {
@ -198,7 +197,7 @@ bool MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vect
// node inputs
std::vector<AnfNodePtr> conv_inputs;
conv_inputs.push_back(NewValueNode(conv_prim));
AdJustInputs(ori_node, inputs_node, weight_nodes, bias_nodes, output_conv_index, &conv_inputs);
AdJustInputs(ori_node, inputs_node, output_conv_index, &conv_inputs);
// create new conv node
if (!CreateNewConvNode(ori_node, conv_inputs, output_conv_index, outputs_node)) {
return false;
@ -208,7 +207,6 @@ bool MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vect
}
void MultiConvSplit::AdJustInputs(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs) {
MS_ASSERT(ori_conv_node != nullptr && conv_inputs != nullptr);
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();

View File

@ -41,7 +41,6 @@ class MultiConvSplit : public MultiNodeSplit {
virtual AnfNodePtr MultiConvNHSplit(const AnfNodePtr &node);
virtual void AdJustInputs(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs);
virtual bool CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,

View File

@ -42,10 +42,7 @@ int MultiNodeSplitProxy::InitResource() {
return RET_OK;
}
int MultiNodeSplitProxy::FreeResource() {
multi_node_split_ = nullptr;
return RET_OK;
}
void MultiNodeSplitProxy::FreeResource() { multi_node_split_ = nullptr; }
AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
@ -55,10 +52,7 @@ AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const An
return node;
}
auto res_node = multi_node_split_->DoSplit(func_graph, node);
ret = FreeResource();
if (ret != RET_OK) {
return node;
}
FreeResource();
return res_node;
}

View File

@ -46,7 +46,7 @@ class MultiNodeSplitProxy : public MultiNodeSplit {
private:
int InitResource();
int FreeResource();
void FreeResource();
private:
SplitMode split_mode_{NoSplit};

View File

@ -35,7 +35,7 @@ bool is_any_not_none(const std::vector<int64_t> &split) {
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() {
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
auto type_ptr = TypeIdToType(operator_type_id_);
std::vector<int64_t> shape_vector;
return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
@ -135,7 +135,7 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
}
AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
int32_t concat_dim, size_t input_nodes_num) {
MS_EXCEPTION_IF_NULL(orig_node);
if (input_nodes.size() != input_nodes_num) {
MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";
@ -162,7 +162,7 @@ AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std:
}
AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t reduce_dim, size_t input_nodes_num, bool trans_format) {
int32_t reduce_dim, size_t input_nodes_num) {
MS_EXCEPTION_IF_NULL(orig_node);
if (input_nodes.size() != input_nodes_num) {
MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";

View File

@ -65,11 +65,11 @@ class OperatorInfo {
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);
AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
int32_t concat_dim, size_t input_nodes_num);
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
size_t input_nodes_num, bool trans_format);
size_t input_nodes_num);
std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor();
std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;
virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,

View File

@ -37,16 +37,15 @@ OperatorInfoFactory *OperatorInfoFactory::GeInstance() {
return &factory;
}
int OperatorInfoFactory::RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func) {
void OperatorInfoFactory::RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func) {
// create a key to find the only create function
SplitOpKey op_key(operator_type, type_id, is_depth_wise);
if (operator_info_map_.find(op_key) != operator_info_map_.end()) {
MS_LOG(ERROR) << " Operator already exist " << op_key.ToString();
return lite::RET_ERROR;
return;
}
this->operator_info_map_.insert(std::pair<SplitOpKey, OperatorInfoCreatorFunc>(op_key, creator_func));
return lite::RET_OK;
}
OperatorInfoCreatorFunc OperatorInfoFactory::FindOperatorInfo(const SplitOpKey &op_key) {

View File

@ -56,8 +56,8 @@ class OperatorInfoFactory {
OperatorInfoFactory &operator=(const OperatorInfoFactory &) = delete;
int RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func);
void RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func);
OperatorInfoCreatorFunc FindOperatorInfo(const SplitOpKey &split_op_key);

View File

@ -104,6 +104,7 @@ AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &n
}
// if current conv2d node has two output nodes ,we do not split it;
auto manager = func_graph->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr.");
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node : " << node->fullname_with_scope() << "has no output";