forked from mindspore-Ecosystem/mindspore

[MSLITE] CodeCheck: converter

parent 12e71fa9d2
commit f4dc2dcd5b
@@ -58,11 +58,19 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap
     MS_LOG(ERROR) << "string tensor's dim size not found.";
     return RET_ERROR;
   }
-  size_t shape_size = std::stoi(shape_size_str);
+  size_t shape_size = std::atoi(shape_size_str.c_str());
+  MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR);
   for (; *offset < tensor_info->Size(); (*offset)++) {
     if (tensor_data[*offset] == ',') {
       cnt++;
-      shape_vector->push_back(std::stoi(shape_str));
+      int64_t shape = 0;
+      try {
+        shape = std::stoi(shape_str);
+      } catch (const std::exception &e) {
+        MS_LOG(ERROR) << "Get shape failed: " << e.what();
+        return RET_ERROR;
+      }
+      shape_vector->push_back(shape);
       shape_str.clear();
     } else {
       shape_str.push_back(tensor_data[*offset]);
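Note on the pattern above: std::stoi throws std::invalid_argument or std::out_of_range on malformed input, while std::atoi silently returns 0, which is why the atoi result is immediately checked against 0. A minimal sketch of the guarded conversion as a reusable helper (hypothetical, not part of this patch):

#include <cstdint>
#include <stdexcept>
#include <string>

// Returns false instead of throwing when `s` is not a valid integer.
bool SafeStrToInt64(const std::string &s, int64_t *out) {
  try {
    *out = std::stoll(s);
  } catch (const std::exception &) {
    return false;
  }
  return true;
}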
@@ -197,11 +197,18 @@ int Flags::InitInTensorShape() {
       return lite::RET_ERROR;
     }
     for (const auto &dim : dims) {
-      if (std::stoi(dim) < 0) {
+      auto dim_value = -1;
+      try {
+        dim_value = std::stoi(dim);
+      } catch (const std::exception &e) {
+        MS_LOG(ERROR) << "Get dim failed: " << e.what();
+        return lite::RET_ERROR;
+      }
+      if (dim_value < 0) {
         MS_LOG(ERROR) << "Unsupported dim < 0.";
         return lite::RET_ERROR;
       } else {
-        shape.push_back(std::stoi(dim));
+        shape.push_back(dim_value);
       }
     }
     lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
@@ -419,7 +426,12 @@ bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *pa
   const char *colon = ":";
   for (const auto &device : device_rates) {
     std::vector<std::string> rate = lite::SplitStringToVector(device, *colon);
-    parallel_split_config->parallel_compute_rates_.push_back(std::stoi(rate.back()));
+    auto compute_rate = std::atoi(rate.back().c_str());
+    if (compute_rate == 0) {
+      MS_LOG(ERROR) << "The compute rate is invalid.";
+      return false;
+    }
+    parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
   }
   if (parallel_split_config->parallel_compute_rates_.size() != 2) {
     return false;
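One caveat of the atoi-based check above: atoi returns 0 both for the string "0" and for unparseable input, so a legitimate zero rate and a typo are rejected identically. A stricter alternative (illustrative only, not what this patch uses) is C++17 std::from_chars, which reports parse failure separately from the parsed value:

#include <charconv>
#include <string>

// Returns true only if the whole string is a valid int.
bool ParseRate(const std::string &s, int *rate) {
  auto res = std::from_chars(s.data(), s.data() + s.size(), *rate);
  return res.ec == std::errc() && res.ptr == s.data() + s.size();
}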
@@ -67,7 +67,7 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
   auto while_prim = primitive->cast<PrimWhilePtr>();
   MS_CHECK_TRUE_RET(while_prim != nullptr, nullptr);
   AbstractBasePtrList output;
-  for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
+  for (size_t i = 0; i < input_args.size(); i++) {
     auto build_shape_ptr = input_args[i]->BuildShape();
     MS_CHECK_TRUE_RET(build_shape_ptr != nullptr, nullptr);
     auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(build_shape_ptr)[kShape];
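The loop rewrite also drops a C-style cast: indexing with int64_t against input_args.size() (a size_t) mixes signed and unsigned arithmetic and draws -Wsign-compare; using size_t matches the container's own index type. A standalone illustration of the preferred shape:

#include <cstddef>
#include <vector>

int SumElements(const std::vector<int> &v) {
  int sum = 0;
  for (size_t i = 0; i < v.size(); i++) {  // index type matches size()
    sum += v[i];
  }
  return sum;
}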
@@ -116,7 +116,7 @@ int CenterCrop(cv::Mat *image, int width, int height) {
 }
 
 int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, const std::string &input_name,
-               int image_index, mindspore::tensor::MSTensor *tensor) {
+               size_t image_index, mindspore::tensor::MSTensor *tensor) {
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "tensor is nullptr.";
     return RET_NULL_PTR;
@@ -155,7 +155,7 @@ int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, co
   return RET_OK;
 }
 
-int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, int image_index,
+int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, size_t image_index,
                void **data, size_t *size) {
   if (data == nullptr || size == nullptr) {
     MS_LOG(ERROR) << "data or size is nullptr.";
@@ -35,11 +35,11 @@ int Resize(cv::Mat *image, int width, int height, cv::InterpolationFlags resize_
 int CenterCrop(cv::Mat *image, int width, int height);
 
 // NOTE:`data` must be use delete[] to free buffer.
-int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, int image_index,
+int PreProcess(const DataPreProcessParam &data_pre_process_param, const std::string &input_name, size_t image_index,
                void **data, size_t *size);
 
 int PreProcess(const preprocess::DataPreProcessParam &data_pre_process_param, const std::string &input_name,
-               int image_index, mindspore::tensor::MSTensor *tensor);
+               size_t image_index, mindspore::tensor::MSTensor *tensor);
 
 int ImagePreProcess(const ImagePreProcessParam &image_preprocess_param, cv::Mat *image, void **data, size_t *size);
 
@@ -64,7 +64,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
     Spliter::GetInstance()->graph_node_outputs();
   auto finder = graph_node_outputs.find(pre_cnode->fullname_with_scope());
   if (finder == graph_node_outputs.end()) {
-    return RET_OK;
+    return RET_ERROR;
   }
   if (finder->second.size() > 1) {
     return RET_OK;
@@ -89,7 +89,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
   // get inputs node
   auto it = graph_node_outputs.find(cnode->fullname_with_scope());
   if (it == graph_node_outputs.end()) {
-    return RET_OK;
+    return RET_ERROR;
   }
   int out_num = it->second.size();
   if (out_num != prim->get_number_split()) {
@@ -101,14 +101,14 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
     auto tmp = it->second[i];
     auto tmp_cnode = tmp->cast<CNodePtr>();
     if (tmp_cnode == nullptr) {
-      return RET_OK;
+      return RET_ERROR;
     }
     if (!CheckPrimitiveType(tmp_cnode, prim::kPrimTupleGetItem)) {
       return RET_OK;
     }
     auto tmp_it = graph_node_outputs.find(tmp_cnode->fullname_with_scope());
     if (tmp_it == graph_node_outputs.end()) {
-      return RET_OK;
+      return RET_ERROR;
     }
     if (tmp_it->second.size() != 1) {
       return RET_OK;
@@ -116,7 +116,7 @@ int ConcatSplitEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
 
     auto next = tmp_it->second[0];
     auto next_cnode = next->cast<CNodePtr>();
-
+    MS_ASSERT(next_cnode != nullptr);
     inputs_node.push_back(next_cnode);
   }
   // replace inputs
@@ -48,14 +48,21 @@ std::vector<int64_t> GetSplitPadList(const std::shared_ptr<ops::Conv2DFusion> &o
   int64_t output_w = static_cast<int64_t>(
     std::ceil(static_cast<float>(input_w) / static_cast<float>(ori_conv_prim->get_stride().at(kIndexW))));
 
+  auto kernel_h = ori_conv_prim->get_kernel_size().at(kIndexH);
+  auto dilation_h = ori_conv_prim->get_dilation().at(kIndexH);
+  auto kernel_w = ori_conv_prim->get_kernel_size().at(kIndexW);
+  auto dilation_w = ori_conv_prim->get_dilation().at(kIndexW);
+  if (INT_MUL_OVERFLOW_THRESHOLD((kernel_h - 1), dilation_h, INT64_MAX) ||
+      INT_MUL_OVERFLOW_THRESHOLD((kernel_w - 1), dilation_w, INT64_MAX)) {
+    MS_LOG(ERROR) << "int mul overflow";
+    return {};
+  }
   std::vector<int64_t> new_pad_list;
   int64_t pad_up = 0, pad_down = 0, pad_left = 0, pad_right = 0;
-  int64_t pad_h_all = (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) +
-                      (ori_conv_prim->get_kernel_size().at(kIndexH) - 1) * ori_conv_prim->get_dilation().at(kIndexH) +
-                      1 - input_h;
-  int64_t pad_w_all = (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) +
-                      (ori_conv_prim->get_kernel_size().at(kIndexW) - 1) * ori_conv_prim->get_dilation().at(kIndexW) +
-                      1 - input_w;
+  int64_t pad_h_all =
+    (output_h - 1) * ori_conv_prim->get_stride().at(kIndexH) + (kernel_h - 1) * dilation_h + 1 - input_h;
+  int64_t pad_w_all =
+    (output_w - 1) * ori_conv_prim->get_stride().at(kIndexW) + (kernel_w - 1) * dilation_w + 1 - input_w;
   // only check pad_up and pad_down is positive
   // if compute overflowed, we will get abnormal it in infer_shape
   if (pad_h_all >= 0) {
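The INT_MUL_OVERFLOW_THRESHOLD guard rejects (kernel - 1) * dilation products that would exceed INT64_MAX before they feed the pad arithmetic. A minimal stand-in showing how such a check can work for non-negative operands (the real macro lives elsewhere in the tree and may differ):

#include <cstdint>

// True if a * b would exceed `threshold`; assumes a, b, threshold >= 0.
bool MulExceeds(int64_t a, int64_t b, int64_t threshold) {
  if (a == 0 || b == 0) {
    return false;
  }
  return a > threshold / b;  // equivalent to a * b > threshold, without overflowing
}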
@@ -110,9 +117,11 @@ bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_sh
                      const std::shared_ptr<ops::Conv2DFusion> &ori_conv_prim, int64_t index_node,
                      std::vector<std::vector<int64_t>> *split_axis_inputs_shape,
                      std::vector<std::vector<int64_t>> *split_axis_reduce_inputs_shape) {
-  MS_ASSERT(split_info != nullptr && split_axis_inputs_shape != nullptr && split_axis_reduce_inputs_shape != nullptr);
+  MS_ASSERT(split_info != nullptr && ori_conv_prim != nullptr && split_axis_inputs_shape != nullptr &&
+            split_axis_reduce_inputs_shape != nullptr);
+  MS_ASSERT(node_in_out_shapes.size() > index_node);
   auto in_out_shape = node_in_out_shapes.at(index_node);
+  MS_ASSERT(!in_out_shape.empty());
   auto in_shape = in_out_shape.front();
   if (in_shape.size() < kAxisW) {
     MS_LOG(DEBUG) << "out of in_shape range";
@@ -129,29 +138,34 @@ bool CalSplitInShape(const std::vector<std::vector<ShapeVector>> &node_in_out_sh
   // iter splited_num
   for (int64_t index = 0; index < split_num; index++) {
     // shape
+    auto stride_h = ori_conv_prim->get_stride()[kIndexH];
+    auto split_axis_dim = (*split_axis_inputs_shape)[index_node][index] - 1;
+    if (INT_MUL_OVERFLOW_THRESHOLD(stride_h, split_axis_dim, INT64_MAX)) {
+      MS_LOG(ERROR) << "int mul overflow";
+      return false;
+    }
     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
       if (index == 0) {
-        tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) -
-              ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
+        tmp =
+          stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
       } else if (index == split_num - 1) {
-        tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) -
-              ori_conv_prim->get_pad_list()[kPadDown] + ori_conv_prim->get_kernel_size()[kIndexH];
-      } else {
-        tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_inputs_shape)[index_node][index] - 1) - 0 +
+        tmp = stride_h * split_axis_dim - ori_conv_prim->get_pad_list()[kPadDown] +
               ori_conv_prim->get_kernel_size()[kIndexH];
+      } else {
+        tmp = stride_h * split_axis_dim - 0 + ori_conv_prim->get_kernel_size()[kIndexH];
       }
     }
     split_axis_shape.push_back(tmp);
 
     // reduce shape
+    auto split_axis_reduce_dim = (*split_axis_reduce_inputs_shape)[index_node][index] - 1;
     if (split_info->axis == CuttingStragedy::CUT_H) {  // H
       if (index == split_num - 1) {
-        tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[index_node][index] - 1) -
-              ori_conv_prim->get_pad_list()[kPadDown] - ori_conv_prim->get_pad_list()[kPadUp] +
-              ori_conv_prim->get_kernel_size()[kIndexH];
-      } else {
-        tmp = ori_conv_prim->get_stride()[kIndexH] * ((*split_axis_reduce_inputs_shape)[index_node][index] - 1) -
+        tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadDown] -
               ori_conv_prim->get_pad_list()[kPadUp] + ori_conv_prim->get_kernel_size()[kIndexH];
+      } else {
+        tmp = stride_h * split_axis_reduce_dim - ori_conv_prim->get_pad_list()[kPadUp] +
+              ori_conv_prim->get_kernel_size()[kIndexH];
       }
     }
     split_axis_reduce_shape.push_back(tmp);
@@ -186,9 +200,8 @@ std::shared_ptr<ops::Conv2DFusion> CopyConvPrim(const std::shared_ptr<ops::Conv2
   new_prim->set_stride(ori_conv_prim->get_stride());
   new_prim->set_activation_type(ori_conv_prim->get_activation_type());
   new_prim->set_pad_list(ori_conv_prim->get_pad_list());
-  if (ori_conv_prim->GetAttr(ops::kIsDepthWise) != nullptr) {
-    auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
-    MS_CHECK_TRUE_MSG(is_depth_value != nullptr, nullptr, "conv has no kIsDepthWise attribute");
+  auto is_depth_value = ori_conv_prim->GetAttr(ops::kIsDepthWise);
+  if (is_depth_value != nullptr) {
     bool is_depth_wise = GetValue<bool>(is_depth_value);
     new_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(is_depth_wise));
   }
@@ -218,6 +231,7 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
     auto input_shapes = Spliter::GetInstance()->graph_node_input_shapes()[out_node_name];
     // 0-> in-shape 1->out-shape
     // only one in and one output
+    MS_ASSERT(!input_shapes.empty() && !output_shapes.empty());
     node_in_out_shapes.push_back({input_shapes.front(), output_shapes.front()});
     index_node++;
   }
@@ -244,7 +258,9 @@ bool UpdateSplitInfo(const FuncGraphPtr &func_graph, const std::vector<AnfNodePt
   // iter node
   while (index_node < node_size) {
     auto conv_cnode = conv_nodes[index_node]->cast<CNodePtr>();
+    MS_ASSERT(conv_cnode != nullptr);
     auto ori_conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(kAnfPrimitiveIndex));
+    MS_CHECK_TRUE_RET(ori_conv_prim != nullptr, false);
     if (!CalSplitInShape(node_in_out_shapes, split_info, ori_conv_prim, index_node, &split_axis_inputs_shape,
                          &split_axis_reduce_inputs_shape)) {
       MS_LOG(ERROR) << "CalSplitInShape failed";
@@ -405,6 +421,10 @@ bool UpdateRatioWithPadStride(int64_t *ratio, size_t ratio_len, size_t split_siz
   int visited_block = 0;
   for (size_t i = 0; i < split_size - 1; i++) {
     visited_block += ratio[i];
+    if (INT_MUL_OVERFLOW_THRESHOLD(split_dim_size, visited_block, INT64_MAX)) {
+      MS_LOG(ERROR) << "int mul overflow";
+      return false;
+    }
     int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
     new_ratio[i + 1] = cur_border;
   }
@@ -30,6 +30,7 @@ AnfNodePtr IterNodeOutputs::Run(const FuncGraphPtr &func_graph, const AnfNodePtr
   auto inputs = cnode->inputs();
 
   for (const auto &input_node : inputs) {
+    MS_ASSERT(input_node != nullptr);
     if (!utils::isa<CNodePtr>(input_node)) {
       continue;
     }
@@ -29,6 +29,7 @@ using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
 namespace mindspore {
 namespace opt {
 std::string MultiConvSplitPass::IsMultiParallelConvNode(const AnfNodePtr &node) const {
+  MS_ASSERT(node != nullptr);
   for (const auto &parallel_prim : kParallelOpNames) {
     if (CheckPrimitiveType(node, parallel_prim.first.first)) {
       return parallel_prim.second;
@@ -32,6 +32,7 @@ AnfNodePtr NodeOutShapes::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &
   auto cnode = node->cast<CNodePtr>();
   // assume multi inputs
   for (const auto &input_node : cnode->inputs()) {
+    MS_ASSERT(input_node != nullptr);
     if (utils::isa<CNodePtr>(input_node) || utils::isa<ParameterPtr>(input_node)) {
       auto in_shape = input_node->Shape();
       if (in_shape == nullptr) {
@@ -56,10 +56,10 @@ bool FuseBias(const lite::DataInfo &add_bias, const lite::DataInfo &conv_bias, s
                add_bias.data_.size()) != EOK) {
     return false;
   }
-  fusion_bias->resize(out_channel, 0);
+  fusion_bias->resize(static_cast<size_t>(out_channel), 0);
   if (!conv_bias.data_.empty()) {
     if (conv_bias.data_type_ != TypeId::kNumberTypeFloat32 && conv_bias.data_type_ != TypeId::kNumberTypeFloat &&
-        conv_bias.data_.size() != out_channel * sizeof(float)) {
+        conv_bias.data_.size() != static_cast<size_t>(out_channel) * sizeof(float)) {
       return false;
     }
     if (memcpy_s(fusion_bias->data(), fusion_bias->size() * sizeof(float), conv_bias.data_.data(),
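The static_cast<size_t> here makes both sides of the size comparison unsigned, quieting -Wsign-compare. A small illustration of the pattern (assumes out_channel has already been validated as non-negative):

#include <cstddef>
#include <vector>

bool BiasSizeMatches(const std::vector<char> &bias_data, int out_channel) {
  // Cast once the value is known to be >= 0, then compare size_t to size_t.
  return bias_data.size() == static_cast<size_t>(out_channel) * sizeof(float);
}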
@@ -110,8 +110,7 @@ void ReplaceParamsAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &conv_
   (void)manager->Replace(pad_cnode, pad_cnode->input(1));
 }
 
-bool IsPrimitiveProper(const CNodePtr &conv_cnode, const CNodePtr &pad_cnode) {
-  MS_ASSERT(conv_cnode != nullptr);
+bool IsPrimitiveProper(const CNodePtr &pad_cnode) {
   MS_ASSERT(pad_cnode != nullptr);
   if (!utils::isa<Parameter>(pad_cnode->input(kInputIndexTwo))) {
     return false;
@@ -240,7 +239,7 @@ AnfNodePtr ConvPadFusion::Process(const std::string &pattern_name, const FuncGra
     return nullptr;
   }
 
-  if (!IsPrimitiveProper(conv_cnode, pad_cnode)) {
+  if (!IsPrimitiveProper(pad_cnode)) {
     MS_LOG(WARNING) << conv_cnode->fullname_with_scope() << " is not match with previous "
                     << pad_cnode->fullname_with_scope() << " op. Fusion failed!";
     return nullptr;
@@ -267,7 +267,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const Anf
     return nullptr;
   }
   scale_primitive->set_activation_type(scale_act_type_);
-  scale_primitive->set_axis(-(bias_tensor_->shape_c().size()));
+  scale_primitive->set_axis(-(static_cast<int64_t>(bias_tensor_->shape_c().size())));
   // create scale op
   auto scale_node = func_graph->NewCNode(scale_primitive, {mul_input_anode_, mul_const_anode_, add_const_anode_});
   return scale_node;
@@ -51,7 +51,8 @@ STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
   }
   axes->resize(1);
   if (!axes_value->shape().empty()) {
-    axes->resize(axes_value->shape()[0]);
+    MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
+    axes->resize(static_cast<size_t>(axes_value->shape()[0]));
   }
   if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
     return lite::RET_OK;
@@ -285,7 +286,8 @@ int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema
 
 int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
   MS_ASSERT(in_shape_size.size() > 0);
-  auto new_axis_mask = primitive.value.AsStridedSlice()->new_axis_mask;
+  MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
+  auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
   auto add_dims = 0;
   while (new_axis_mask != 0) {
     new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
@@ -404,7 +406,7 @@ std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph)
     }
     // Cal shape size infer function
     auto shape_size_infer_func = shape_size_infer_iter->second;
-    auto shape_size = shape_size_infer_iter->second(in_shape_sizes, *prim_t);
+    auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
     // Update node shape size map
     node_shape_size[cnode->fullname_with_scope()] = shape_size;
   }
@@ -766,8 +766,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
   MS_CHECK_TRUE_RET(fw_cond_graph_pattern != nullptr, nullptr);
   auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
   MS_ASSERT(fw_cond != nullptr);
-  auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_cond_graph_pattern, fw_cond_primitive_vars,
-                                                           fw_cond, kCondCNodesNum, kCondNodesNum);
+  auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_cond_graph_pattern, fw_cond_primitive_vars, fw_cond,
+                                                           kCondCNodesNum, kCondNodesNum);
   if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) {
     return nullptr;
   }
@@ -778,8 +778,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
   MS_CHECK_TRUE_RET(bw_cond_graph_pattern != nullptr, nullptr);
   auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
   MS_ASSERT(bw_cond != nullptr);
-  auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_cond_graph_pattern, bw_cond_primitive_vars,
-                                                           bw_cond, kCondCNodesNum, kCondNodesNum);
+  auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_cond_graph_pattern, bw_cond_primitive_vars, bw_cond,
+                                                           kCondCNodesNum, kCondNodesNum);
   if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) {
     return nullptr;
   }
@@ -790,8 +790,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
   MS_CHECK_TRUE_RET(fw_body_graph_pattern != nullptr, nullptr);
   auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
   MS_ASSERT(fw_body != nullptr);
-  auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, fw_body_graph_pattern, fw_primitive_vars_body,
-                                                           fw_body, kBodyCNodesNum, kBodyNodesNum);
+  auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_body_graph_pattern, fw_primitive_vars_body, fw_body,
+                                                           kBodyCNodesNum, kBodyNodesNum);
   if (fw_body_equiv == nullptr || fw_body_equiv->empty()) {
     return nullptr;
   }
@@ -802,8 +802,8 @@ const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph,
   MS_CHECK_TRUE_RET(bw_body_graph_pattern != nullptr, nullptr);
   auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
   MS_ASSERT(bw_body != nullptr);
-  auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(func_graph, bw_body_graph_pattern, bw_primitive_vars_body,
-                                                           bw_body, kBodyCNodesNum, kBodyNodesNum);
+  auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_body_graph_pattern, bw_primitive_vars_body, bw_body,
+                                                           kBodyCNodesNum, kBodyNodesNum);
   if (bw_body_equiv == nullptr || bw_body_equiv->empty()) {
     return nullptr;
   }
@@ -191,7 +191,7 @@ STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight,
     return RET_ERROR;
   }
   const auto param_num = shape[0] * shape[1] * shape[kInputIndexTwo];
-  auto tensor_data = new (std::nothrow) float[param_num * sizeof(float)];
+  auto tensor_data = new (std::nothrow) float[static_cast<size_t>(param_num) * sizeof(float)];
   std::vector<int> data_diff{0, 3, 2, 1};
   if (tensor_data == nullptr) {
     MS_LOG(DEBUG) << "new data failed";
@@ -204,7 +204,8 @@ STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight,
       }
     }
   }
-  auto tensor_info = lite::CreateTensorInfo(tensor_data, param_num * sizeof(float), shape, kNumberTypeFloat32);
+  auto tensor_info =
+    lite::CreateTensorInfo(tensor_data, static_cast<size_t>(param_num) * sizeof(float), shape, kNumberTypeFloat32);
   delete[] tensor_data;
   if (tensor_info == nullptr) {
     MS_LOG(ERROR) << "create tensor info failed.";
@@ -359,11 +360,6 @@ CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const
   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
 
   auto &vars = while_input_vars_;
-
-  auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
-  MS_ASSERT(limit1);
-  auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
-  MS_ASSERT(limit2);
   auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
   MS_ASSERT(weight);
   auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
@@ -329,7 +329,7 @@ bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph
   auto manager = func_graph->manager();
   if (manager == nullptr) {
     MS_LOG(ERROR) << "manager is nullptr";
-    return RET_ERROR;
+    return false;
   }
   auto while_node_users = manager->node_users()[while_cnode];
   std::vector<size_t> valid_indexes{3, 4, 5};
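This one is more than cosmetic: CheckReferencedOutputs returns bool, so the old `return RET_ERROR;` relied on an implicit int-to-bool conversion. RET_ERROR is nonzero (-1 in lite's error-code convention), so the error path silently converted to true. Minimal reproduction of the conversion (error-code value assumed, function hypothetical):

constexpr int RET_ERROR = -1;  // assumed value, matching lite convention

bool Check() {
  return RET_ERROR;  // int -1 converts to true: failure looks like success
}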
@@ -352,9 +352,9 @@ bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph
   return true;
 }
 
-EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
-                                             const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
-                                             const size_t cnode_num, const size_t all_node_num) {
-  MS_ASSERT(func_graph != nullptr);
+EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
+                                             const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
+                                             const size_t all_node_num) {
   MS_ASSERT(pattern != nullptr);
   MS_ASSERT(primitive_vars != nullptr);
@@ -370,9 +370,7 @@ EquivPtr TfliteLstmCellFusion::CheckSubGraph(const FuncGraphPtr &func_graph, con
   return MatchGraph(sub_graph, primitive_vars, pattern);
 }
 
-bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
-                                          const CNodePtr &while_cnode, float *zoneout_cell,
-                                          float *zoneout_hidden) const {
-  MS_ASSERT(func_graph != nullptr);
+bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
   MS_ASSERT(equiv != nullptr);
-  MS_ASSERT(while_cnode != nullptr);
@@ -465,8 +463,9 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &par
       MS_LOG(ERROR) << "bias data shape error";
       return RET_ERROR;
     }
-    step = data_shapes[0][0];
-    data_size = 8 * step;
+    step = static_cast<int>(data_shapes[0][0]);
+    MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
+    data_size = C8NUM * step;
     new_shape = std::vector<int64_t>({1, data_size});
 
   } else {
@@ -475,8 +474,10 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &par
       return RET_ERROR;
     }
     new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
-    step = data_shapes[0][0] * data_shapes[0][1];
-    data_size = 4 * step;
+    MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
+    step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
+    MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
+    data_size = C4NUM * step;
   }
 
   auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
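Besides replacing the magic numbers 4 and 8 with the named C4NUM/C8NUM constants, each multiplication now goes through MS_CHECK_INT_MUL_NOT_OVERFLOW before the product is used as a buffer size. An illustrative stand-in for such a macro (the real definition may differ):

#include <climits>

// Returns `ret` from the enclosing function if a * b would overflow int.
// Illustration only; assumes non-negative operands.
#define CHECK_INT_MUL_NOT_OVERFLOW(a, b, ret) \
  do {                                        \
    if ((b) != 0 && (a) > INT_MAX / (b)) {    \
      return (ret);                           \
    }                                         \
  } while (0)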
@@ -528,12 +529,6 @@ CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, co
   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
 
   auto &vars = while_input_vars_;
-
-  auto limit1 = utils::cast<AnfNodePtr>((*equiv)[vars[3]]);
-  MS_ASSERT(limit1);
-  auto limit2 = utils::cast<AnfNodePtr>((*equiv)[vars[7]]);
-  MS_ASSERT(limit2);
-
   auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
   MS_ASSERT(i2i_weight);
   auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
@@ -764,8 +759,8 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
   MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
   auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
   MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
-  auto cond_equiv = CheckSubGraph(func_graph, cond_graph_pattern, primitive_vars_cond, while_cnode->input(1),
-                                  cond_cnodes_num_, cond_nodes_num_);
+  auto cond_equiv =
+    CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
   if (cond_equiv == nullptr || cond_equiv->empty()) {
     return nullptr;
   }
@@ -773,14 +768,14 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
   MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
   auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
   MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
-  auto body_equiv = CheckSubGraph(func_graph, body_graph_pattern, primitive_vars_body, while_cnode->input(2),
-                                  body_cnodes_num_, body_nodes_num_);
+  auto body_equiv =
+    CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
   if (body_equiv == nullptr || body_equiv->empty()) {
     return nullptr;
   }
   float zoneout_cell = 0.0f;
   float zoneout_hidden = 0.0f;
-  if (!CheckBodyGraph(func_graph, body_equiv, while_cnode, &zoneout_cell, &zoneout_hidden)) {
+  if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
     return nullptr;
   }
   const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
@@ -36,9 +36,8 @@ class TfliteLstmCellFusion : public PatternProcessPass {
   static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
                              const AnfNodePtr &pattern);
 
-  static EquivPtr CheckSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &pattern,
-                                const PrimitiveVarMapPtr &primitive_vars, const AnfNodePtr &anf_sub_graph,
-                                size_t cnode_num, size_t all_node_num);
+  static EquivPtr CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
+                                const AnfNodePtr &anf_sub_graph, size_t cnode_num, size_t all_node_num);
 
   static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, int output_num);
 
@@ -68,8 +67,7 @@ class TfliteLstmCellFusion : public PatternProcessPass {
 
  private:
   CNodePtr GetWhileCnode(const AnfNodePtr &cnode) const;
-  bool CheckBodyGraph(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const CNodePtr &while_cnode,
-                      float *zoneout_cell, float *zoneout_hidden) const;
+  bool CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const;
 
   static bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode);
 
@@ -138,9 +138,8 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
   return cnode;
 }
 
-AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph,
-                                             const mindspore::AnfNodePtr &node) const {
-  MS_ASSERT(func_graph != nullptr && node != nullptr);
+AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::AnfNodePtr &node) const {
+  MS_ASSERT(node != nullptr);
   auto trans_cnode_2 = node->cast<CNodePtr>();
   if (IsMarkedTrainOp(trans_cnode_2)) {
     return nullptr;
@@ -181,7 +180,7 @@ AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const minds
     return nullptr;
   }
   if (pattern_name == "TransTransPatternName") {
-    return TransTransFusion(func_graph, node);
+    return TransTransFusion(node);
   }
   if (node->cast<CNodePtr>() == nullptr) {
     return nullptr;
@@ -38,7 +38,7 @@ class TransposeFusion : public MultiplePatternProcessPass {
   VectorRef DefineActivationscalePattern() const;
   VectorRef DefineTransTransPattern() const;
   VectorRef DefineBiasAddPattern() const;
-  AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const;
+  AnfNodePtr TransTransFusion(const mindspore::AnfNodePtr &node) const;
   AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
                      const EquivPtr &) const override;
 };
@@ -24,10 +24,11 @@
 #include "ops/tensor_array_read.h"
 #include "ops/tensor_array_write.h"
 #include "tools/converter/ops/ops_def.h"
+#include "nnacl/op_base.h"
 
 namespace mindspore::opt {
 constexpr auto kDefaultIndex = 0;
-constexpr auto kInputIndex = 1;
+constexpr auto kInputNodeIndex = 1;
 constexpr auto kDefaultNumTensors = 1;
 constexpr auto kFlowInPlaceHolder = 1;
 
@@ -45,6 +46,8 @@ static bool IsSupportedNode(const BaseRef &n) {
 }
 
 static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tensor_array_write_node) {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(tensor_array_write_node != nullptr);
   // set tensor_array_write_node as graph output to keep it
   auto return_node = func_graph->get_return();
   if (!CheckPrimitiveType(return_node, prim::kPrimReturn)) {
@@ -56,7 +59,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
     MS_LOG(ERROR) << "graph return node is not cnode";
     return lite::RET_NULL_PTR;
   }
-  auto output_node = return_node->input(kInputIndex);
+  auto output_node = return_node->input(kInputNodeIndex);
   if (output_node == nullptr) {
     MS_LOG(ERROR) << "graph output node is null";
     return lite::RET_NULL_PTR;
@@ -80,8 +83,9 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
     MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr";
     return lite::RET_NULL_PTR;
   }
-  auto make_tuple_cnode =
-    func_graph->NewCNode({NewValueNode(make_tuple_prim_ptr), output_node, tensor_array_write_node});
+  auto make_tuple_vnode = NewValueNode(make_tuple_prim_ptr);
+  MS_CHECK_TRUE_RET(make_tuple_vnode != nullptr, lite::RET_NULL_PTR);
+  auto make_tuple_cnode = func_graph->NewCNode({make_tuple_vnode, output_node, tensor_array_write_node});
   if (make_tuple_cnode == nullptr) {
     MS_LOG(ERROR) << "NewCNode failed";
     return lite::RET_NULL_PTR;
@@ -94,7 +98,10 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
     MS_LOG(ERROR) << "return_prim_ptr is nullptr";
     return lite::RET_NULL_PTR;
   }
-  auto new_return_node = func_graph->NewCNode({NewValueNode(return_prim_ptr), make_tuple_cnode});
+  auto return_value_node = NewValueNode(return_prim_ptr);
+  MS_CHECK_TRUE_RET(return_value_node != nullptr, lite::RET_NULL_PTR);
+  auto new_return_node = func_graph->NewCNode({return_value_node, make_tuple_cnode});
+  MS_CHECK_TRUE_RET(new_return_node != nullptr, lite::RET_NULL_PTR);
   new_return_node->set_fullname_with_scope(return_cnode->fullname_with_scope());
-  MS_ASSERT(new_return_node != nullptr);
   func_graph->set_return(new_return_node);
@@ -147,11 +154,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
     return nullptr;
   }
   auto tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack());
-  if (tensor_info == nullptr) {
-    MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
-    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-    return nullptr;
-  }
+  MS_ASSERT(tensor_info != nullptr);
   if (tensor_info->data_type() == kObjectTypeTensorType) {
     MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT);
@@ -160,35 +163,32 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
 
   // tensor_array
   auto tensor_array = std::make_shared<ops::TensorArray>();
-  if (tensor_array == nullptr) {
-    MS_LOG(ERROR) << "tensor_array is nullptr";
-    return nullptr;
-  }
+  MS_CHECK_TRUE_RET(tensor_array != nullptr, nullptr);
   std::vector<int> element_shape;
   std::for_each(tensor_info->shape().begin(), tensor_info->shape().end(),
                 [&element_shape](int64_t v) { element_shape.push_back(static_cast<int>(v)); });
   tensor_array->set_element_shape(element_shape);
   tensor_array->set_data_type(tensor_info->data_type());
-  auto tensor_array_node = func_graph->NewCNode({
-    NewValueNode(tensor_array),
-    NewValueNode(kDefaultNumTensors),
-  });
+  auto tensor_array_vnode = NewValueNode(tensor_array);
+  MS_CHECK_TRUE_RET(tensor_array_vnode != nullptr, nullptr);
+  auto num_tensors_vnode = NewValueNode(kDefaultNumTensors);
+  MS_CHECK_TRUE_RET(num_tensors_vnode != nullptr, nullptr);
+  auto tensor_array_node = func_graph->NewCNode({tensor_array_vnode, num_tensors_vnode});
   MS_ASSERT(tensor_array_node != nullptr);
   tensor_array_node->set_abstract(abstract->Clone());
   tensor_array_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array");
 
   // {"handle", "index", "flow_in"} -> {"tensor"}
   auto tensor_array_read = std::make_shared<ops::TensorArrayRead>();
-  if (tensor_array_read == nullptr) {
-    MS_LOG(ERROR) << "tensor_array_read is nullptr";
-    return nullptr;
-  }
-  auto tensor_array_read_node = func_graph->NewCNode({
-    NewValueNode(tensor_array_read),
-    tensor_array_node,
-    NewValueNode(kDefaultIndex),
-    NewValueNode(kFlowInPlaceHolder),
-  });
+  MS_CHECK_TRUE_RET(tensor_array_read != nullptr, nullptr);
+  auto tensor_array_read_vnode = NewValueNode(tensor_array_read);
+  MS_CHECK_TRUE_RET(tensor_array_read_vnode != nullptr, nullptr);
+  auto read_index_vnode = NewValueNode(kDefaultIndex);
+  MS_CHECK_TRUE_RET(read_index_vnode != nullptr, nullptr);
+  auto read_flow_in_vnode = NewValueNode(kFlowInPlaceHolder);
+  MS_CHECK_TRUE_RET(read_flow_in_vnode != nullptr, nullptr);
+  auto tensor_array_read_node =
+    func_graph->NewCNode({tensor_array_read_vnode, tensor_array_node, read_index_vnode, read_flow_in_vnode});
   MS_ASSERT(tensor_array_read_node != nullptr);
   tensor_array_read_node->set_abstract(abstract->Clone());
   tensor_array_read_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array_read");
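The reshaping above (and in the write case below) is the same codecheck rule applied repeatedly: NewValueNode can fail under allocation pressure, and passing its result straight into NewCNode leaves no seam to check it, so each value node gets a name and a MS_CHECK_TRUE_RET before use. The shape of the pattern, reduced to a sketch with hypothetical stand-in types:

#include <memory>
#include <vector>

struct Node {};
using NodePtr = std::shared_ptr<Node>;

NodePtr NewValue(int) { return std::make_shared<Node>(); }  // stand-in factory
NodePtr NewCNode(const std::vector<NodePtr> &) { return std::make_shared<Node>(); }

// Instead of NewCNode({NewValue(0), NewValue(1)}), name and check each part:
NodePtr BuildChecked() {
  auto index = NewValue(0);
  if (index == nullptr) return nullptr;  // checked before it is ever used
  auto flow = NewValue(1);
  if (flow == nullptr) return nullptr;
  return NewCNode({index, flow});
}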
@@ -196,17 +196,15 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
 
   // {"handle", "index", "value", "flow_in"} -> {"flow_out"}
   auto tensor_array_write = std::make_shared<ops::TensorArrayWrite>();
-  if (tensor_array_write == nullptr) {
-    MS_LOG(ERROR) << "tensor_array_write is nullptr";
-    return nullptr;
-  }
-  auto tensor_array_write_node = func_graph->NewCNode({
-    NewValueNode(tensor_array_write),
-    tensor_array_node,
-    NewValueNode(kDefaultIndex),
-    cnode,
-    NewValueNode(kFlowInPlaceHolder),
-  });
+  MS_CHECK_TRUE_RET(tensor_array_write != nullptr, nullptr);
+  auto tensor_array_write_vnode = NewValueNode(tensor_array_write);
+  MS_CHECK_TRUE_RET(tensor_array_write_vnode != nullptr, nullptr);
+  auto write_index_vnode = NewValueNode(kDefaultIndex);
+  MS_CHECK_TRUE_RET(write_index_vnode != nullptr, nullptr);
+  auto write_flow_in_vnode = NewValueNode(kFlowInPlaceHolder);
+  MS_CHECK_TRUE_RET(write_flow_in_vnode != nullptr, nullptr);
+  auto tensor_array_write_node =
+    func_graph->NewCNode({tensor_array_write_vnode, tensor_array_node, write_index_vnode, cnode, write_flow_in_vnode});
   if (tensor_array_write_node == nullptr) {
     MS_LOG(ERROR) << "rensor_array_write_node is nullptr";
     return nullptr;
@@ -35,6 +35,7 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
   MS_ASSERT(graph != nullptr);
   auto node_list = TopoSort(graph->get_return());
   for (auto &node : node_list) {
+    MS_ASSERT(node != nullptr);
     if (!utils::isa<CNode>(node)) {
       continue;
     }
@@ -57,10 +58,12 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
     if ((min == -1) && (max == -1)) {
       if (clip_cnode->size() > kClipMinIndex) {
         auto min_tensor_info = GetTensorInfo(clip_cnode->input(kClipMinIndex));
+        MS_CHECK_TRUE_MSG(min_tensor_info != nullptr, false, "min_tensor_info is nullptr");
         if (min_tensor_info->data_type() != mindspore::kNumberTypeFloat32) {
           MS_LOG(ERROR) << "Clip param type invalid";
           return false;
         }
+        MS_CHECK_TRUE_MSG(min_tensor_info->data_c() != nullptr, false, "tensor data is nullptr");
         min = *reinterpret_cast<float *>(min_tensor_info->data_c());
       } else {
         min = FLT_MIN;
@@ -68,17 +71,19 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
 
       if (clip_cnode->size() > kClipMaxIndex) {
         auto max_tensor_info = GetTensorInfo(clip_cnode->input(kClipMaxIndex));
+        MS_CHECK_TRUE_MSG(max_tensor_info != nullptr, false, "max_tensor_info is nullptr");
         if (max_tensor_info->data_type() != mindspore::kNumberTypeFloat32) {
           MS_LOG(ERROR) << "Clip param type invalid";
           return false;
         }
+        MS_CHECK_TRUE_MSG(max_tensor_info->data_c() != nullptr, false, "tensor data is nullptr");
         max = *reinterpret_cast<float *>(max_tensor_info->data_c());
       } else {
         max = FLT_MAX;
       }
     }
     auto manager = graph->manager();
-
     MS_ASSERT(manager != nullptr);
     auto primitive_c = std::make_shared<mindspore::ops::Activation>();
+    MS_CHECK_TRUE_MSG(primitive_c != nullptr, false, "primitive_c is nullptr");
     primitive_c->Init(0, min, max, mindspore::RELU6);
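MS_CHECK_TRUE_MSG bundles the null check, the error log, and the early return that the surrounding code otherwise spells out by hand. An illustrative equivalent (named differently here because the project's actual macro definition may differ in detail):

// Log `msg` and return `ret` from the enclosing function when cond is false.
#define CHECK_TRUE_MSG(cond, ret, msg) \
  do {                                 \
    if (!(cond)) {                     \
      MS_LOG(ERROR) << (msg);          \
      return (ret);                    \
    }                                  \
  } while (0)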
@@ -50,7 +50,6 @@ void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &v
                                                    std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
   std::deque<AnfNodePtr> nodes{};
   std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
-  std::set<FuncGraphPtr> visited_fg_set{};
   std::set<AnfNodePtr> remain_nodes_set{};
   nodes.assign(remain_nodes.begin(), remain_nodes.end());
   while (!nodes.empty()) {
@@ -150,6 +149,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow
   // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
   auto node_list = TopoSort(fg->get_return());
   for (auto &node : node_list) {
+    MS_ASSERT(node != nullptr);
     if (utils::isa<CNodePtr>(node) &&
         (CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
       *control_flow_node = node;
@@ -194,6 +194,7 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
   *after_fg = std::make_shared<FuncGraph>();
+  MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
   auto manager = main_fg->manager();
   MS_ASSERT(manager != nullptr);
   manager->AddFuncGraph(*after_fg);
   (*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
   (*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
@@ -244,7 +245,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
   auto origin_cond_fg_inputs = cond_fg->get_inputs();
   for (auto &item : visited_nodes_used_by_after_fg) {
     bool found = false;
-    size_t input_index = -1;
+    size_t input_index = 0;
     for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
       if (cond_partial_cnode_inputs[i] == item) {
         found = true;
@@ -276,6 +277,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
   // insert call node
   std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
   *cond_call_cnode = fg->NewCNode(call_node_inputs);
+  MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
   (*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
 
   return RET_SUCCESS;
@@ -284,7 +286,7 @@ int ControlFlowPass::CreateWhileCondCallNode(
 int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
                                                 CNodePtr *body_partial_node) {
   auto body_vnode = while_cnode->input(kWhileBodyIndex);
-  MS_CHECK_TRUE_MSG(body_vnode != nullptr, lite::RET_NULL_PTR, "body_vnode is nullptr");
+  MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
   auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
   if (body_fg == nullptr) {
     MS_LOG(ERROR) << "Get value as func_graph failed.";
@@ -302,6 +304,7 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
   auto cond_fg_inputs = cond_fg->get_inputs();
   body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
   *body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
+  MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
   (*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
 
   // add after inputs for body fg to call cond fg
@@ -312,9 +315,7 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
       MS_LOG(ERROR) << "fg is not right.";
       return RET_FAILED;
     }
-    auto cond_fg_input_para = cond_fg_inputs[i]->cast<ParameterPtr>();
     auto new_parameter = body_fg->add_parameter();
-    MS_ASSERT(cond_fg_input_para != nullptr);
     MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
     new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
     new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
@@ -363,7 +364,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
   }
 
   auto after_value_node = NewValueNode(after_fg);
-  MS_CHECK_TRUE_MSG(after_value_node != nullptr, lite::RET_NULL_PTR, "after_value_node is nullptr");
+  MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
   ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
   if (partial_anf_primitive == nullptr) {
     MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
@@ -397,7 +398,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
 
     after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
     auto new_parameter = after_fg->add_parameter();
-    MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter != nullptr");
+    MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
     new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
     new_parameter->set_abstract(node->abstract());
     after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
@@ -407,7 +408,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
   for (auto &input : cond_nodes_used_by_after_partial) {
     after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
     auto new_parameter = after_fg->add_parameter();
-    MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter != nullptr");
+    MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
     new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
     new_parameter->set_abstract(input->abstract());
     visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
@@ -417,6 +418,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
   ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
   ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
   *after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
+  MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
   (*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
   return RET_SUCCESS;
 }
@@ -448,7 +450,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
     return ret;
   }
 
-  AnfNodePtr cond_fg_vnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>()->input(kCNodeFirstInputIndex);
+  auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
+  MS_ASSERT(cond_fg_cnode != nullptr);
+  AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
   MS_ASSERT(cond_fg_vnode != nullptr);
   auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
   MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
@@ -492,8 +496,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
   fg->DropNode(while_cnode);
   fg->set_output(cond_call_cnode);
 
-  auto after_fg =
-    after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
+  auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
+  MS_ASSERT(after_cnode != nullptr);
+  auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
   if (after_fg == nullptr) {
     MS_LOG(ERROR) << "after_fg is nullptr.";
     return RET_FAILED;
@@ -506,7 +511,9 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNo
 int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
                                                        std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
   auto if_inputs = if_cnode->inputs();
-  auto partial_fg_name = partial_fg->get_attr("graph_name")->ToString();
+  auto fg_name_attr = partial_fg->get_attr("graph_name");
+  MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
+  auto partial_fg_name = fg_name_attr->ToString();
   std::vector<AnfNodePtr> if_external_inputs{};
   if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
   auto origin_then_fg_inputs = partial_fg->get_inputs();
@@ -523,7 +530,13 @@ int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode,
       auto pos = partial_fg_name.size() + sizeof("_input_");
       auto pos2 = fg_input_name.find('_', pos);
       auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
-      auto partial_idx = std::stoi(idx_str);
+      auto partial_idx = 0;
+      try {
+        partial_idx = std::stoi(idx_str);
+      } catch (const std::exception &e) {
+        MS_LOG(ERROR) << "Get index failed: " << e.what();
+        return RET_FAILED;
+      }
       then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
     }
   }
@@ -552,7 +565,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
   auto origin_then_fg_inputs = then_fg->get_inputs();
   for (auto &item : *visited_nodes_used_by_after_fg) {
     bool found = false;
-    size_t input_index = -1;
+    size_t input_index = 0;
     for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
       if (then_partial_cnode_inputs[i] == item) {
         found = true;
@@ -580,7 +593,10 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i
     then_nodes_used_by_after_partial.push_back(new_parameter);
   }
   *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
-  auto then_fg_name = then_fg->get_attr("graph_name")->ToString();
+  MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
+  auto fg_name_attr = then_fg->get_attr("graph_name");
+  MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
+  auto then_fg_name = fg_name_attr->ToString();
   (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
 
   // create after partial node
@@ -764,6 +780,7 @@ int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
 }
 
 bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
+  MS_ASSERT(fg != nullptr);
   to_process_q.push_back(fg);
   while (!to_process_q.empty()) {
     auto cur_fg = to_process_q.front();
@@ -79,7 +79,7 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
     MS_ASSERT(weight_node != nullptr);
 
     auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
-    MS_CHECK_TRUE_MSG(prim != nullptr, RET_FAILED, "GetValueNode failed");
+    MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode failed");
     schema::Format schema_format = schema::Format::Format_KCHW;
     if (prim->GetAttr(ops::kFormat) != nullptr) {
       schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat)));
@@ -91,6 +91,6 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
       return false;
     }
   }
-  return RET_OK;
+  return true;
 }
 }  // namespace mindspore::opt
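This fix (together with the RET_FAILED-to-false change in the same pass above) closes an int/bool mismatch: Run() is declared bool, and lite::RET_OK is 0, so `return RET_OK;` converted to false and reported failure on the success path. The conversion in isolation (error-code value assumed, function hypothetical):

constexpr int RET_OK = 0;  // assumed value, matching lite convention

bool Run() {
  return RET_OK;  // int 0 converts to false: success reads as failure
}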
@@ -233,8 +233,8 @@ STATUS DecreaseTransposeAlgo::PostTransposeFusion(const FuncGraphPtr &func_graph
   return lite::RET_OK;
 }
 
-STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
-                                          bool before, size_t index) {
+STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                          const std::vector<int> perm, bool before, size_t index) {
   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
   AnfNodePtr new_input = nullptr;
   new_input = transpose_strategy_.TransposePairFuseWhenInsert(func_graph, cnode, perm, before, index);
@@ -388,6 +388,7 @@ STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph
     return lite::RET_ERROR;
   } else {
     tuple_get_item = GenTupleGetItemNode(func_graph, cnode, 0);
+    MS_CHECK_TRUE_RET(tuple_get_item != nullptr, lite::RET_ERROR);
     post_node = tuple_get_item;
     func_graph->manager()->Replace(cnode, tuple_get_item);
   }
@@ -489,7 +490,13 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
     auto last_underline = node_name.find_last_of("_");
     node_name = node_name.substr(0, last_underline);
     last_underline = node_name.find_last_of("_");
-    auto index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
+    auto index = 0;
+    try {
+      index = std::stoi(node_name.substr(last_underline + 1)) + static_cast<int>(kInputSizeThree);
+    } catch (const std::exception &e) {
+      MS_LOG(ERROR) << "Get index failed: " << e.what();
+      return lite::RET_ERROR;
+    }
     param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
     if (utils::isa<CNodePtr>(cnode->input(index))) {
       ShapeVector shape_vec = {-1};
@@ -41,7 +41,7 @@ class DecreaseTransposeAlgo : public Pass {
 
  private:
   STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
-  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
+  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> perm, bool before,
                      size_t index = 0);
   bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
   bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
@@ -62,7 +62,9 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
       MS_LOG(ERROR) << "the node input is invalid.";
       return false;
     }
-    auto data_shape = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack())->shape();
+    auto data_shape_ptr = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack());
+    MS_ASSERT(data_shape_ptr != nullptr);
+    auto data_shape = data_shape_ptr->shape();
     if (data_shape.empty()) {
       MS_LOG(DEBUG) << "the tensor's shape is dynamic.";
       return true;
@@ -72,11 +74,15 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << "the weight node input is invalid.";
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack())->shape();
auto weight_shape_ptr = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack());
MS_ASSERT(weight_shape_ptr != nullptr);
auto weight_shape = weight_shape_ptr->shape();
if (weight_shape.empty()) {
MS_LOG(DEBUG) << "the weight's shape is dynamic.";
return true;
}
MS_CHECK_TRUE_RET(data_shape.size() == DIMENSION_4D, false);
MS_CHECK_TRUE_RET(weight_shape.size() == DIMENSION_4D, false);
if (data_shape[3] == 1 || data_shape[3] != weight_shape[3]) {
conv2d_fusion->EraseAttr(ops::kIsDepthWise);
conv2d_fusion->set_group(static_cast<int64_t>(data_shape[3]));

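Both hunks above fix the same hazard: `utils::cast<...>(x)->shape()` dereferences the cast result before anyone has checked it. A minimal sketch of the refactor, using std::dynamic_pointer_cast in place of the project's utils::cast (assumed to behave the same way: null on type mismatch):

#include <iostream>
#include <memory>
#include <vector>

struct Base { virtual ~Base() = default; };
struct Shape : Base {
  std::vector<int64_t> shape_;
  const std::vector<int64_t> &shape() const { return shape_; }
};

int main() {
  std::shared_ptr<Base> track = std::make_shared<Shape>();
  // Before: one-liner that crashes if the cast fails.
  //   auto dims = std::dynamic_pointer_cast<Shape>(track)->shape();
  // After: cast, check, then dereference.
  auto shape_ptr = std::dynamic_pointer_cast<Shape>(track);
  if (shape_ptr == nullptr) {
    std::cerr << "shape track has unexpected type\n";
    return 1;
  }
  std::cout << "rank " << shape_ptr->shape().size() << "\n";
  return 0;
}
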
@@ -50,7 +50,9 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
MS_LOG(ERROR) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_ERROR;
}
*format = static_cast<mindspore::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat)));
auto format_attr = primitive->GetAttr(ops::kFormat);
MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");
*format = static_cast<mindspore::Format>(GetValue<int64_t>(format_attr));
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
std::vector<int> perm;
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {

@@ -184,7 +186,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
auto ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
}

@@ -202,7 +204,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return lite::RET_ERROR;
}

@@ -215,7 +217,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
ret = SetSubGraphAbstract(cnode, sub_func_graph);
if (ret != RET_OK) {
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret;
return lite::RET_ERROR;
}

@@ -241,13 +243,20 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
auto index = 0;
try {
index = std::stoi(node_name.substr(last_underline + 1)) + 3;
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get index failed: " << e.what();
return RET_ERROR;
}
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(out_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
CHECK_NULL_RETURN(abstract_shape);

@@ -34,7 +34,7 @@ namespace opt {
namespace {
constexpr int kInputChannal = 3;
constexpr size_t INITIAL_SIZE = 1024;
void RectifyFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
void RectifyFormat(const std::vector<lite::Tensor *> &inputs, FmkType fmk_type) {
MS_ASSERT(cnode != nullptr);
if (fmk_type != converter::kFmkTypeOnnx) {
return;

@@ -127,7 +127,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
fbb.Clear();
return lite::RET_ERROR;
}
RectifyFormat(cnode, inputs, fmk_type_);
RectifyFormat(inputs, fmk_type_);
ret = KernelInferShape(inputs, outputs, parameter);
if (parameter->destroy_func_ != nullptr) {
parameter->destroy_func_(parameter);

@@ -208,7 +208,8 @@ std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t in
if (specify_tensors.front()->shape().size() != 1) {
return {};
}
tensor_data.resize(specify_tensors.front()->shape()[0]);
MS_CHECK_GE(specify_tensors.front()->shape()[0], 0, {});
tensor_data.resize(static_cast<size_t>(specify_tensors.front()->shape()[0]));
if (memcpy_s(tensor_data.data(), tensor_data.size() * sizeof(int), specify_tensors.front()->data(),
specify_tensors.front()->Size()) != EOK) {
return {};

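The MS_CHECK_GE added above matters because the tensor dimension is signed: a negative dim converted to size_t becomes enormous, and resize will throw or exhaust memory. A minimal sketch of the guard (helper name hypothetical):

#include <cstdio>
#include <vector>

// Resize only when the signed dimension is known to be non-negative.
bool SafeResize(std::vector<int> *data, int64_t dim) {
  if (data == nullptr || dim < 0) {
    return false;  // a negative dim cast to size_t would be near 2^64
  }
  data->resize(static_cast<size_t>(dim));
  return true;
}

int main() {
  std::vector<int> data;
  std::printf("%d\n", SafeResize(&data, 4));   // 1
  std::printf("%d\n", SafeResize(&data, -1));  // 0, instead of a bad_alloc
  return 0;
}
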
@@ -264,6 +264,27 @@ int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const Fun
return lite::RET_OK;
}

int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(data_info != nullptr);
auto padding_node = cnode->input(kInputIndexTwo);
MS_ASSERT(padding_node != nullptr);
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
if (!utils::isa<CNodePtr>(anf_node)) {
MS_LOG(DEBUG) << "anf node is not a cnode.";

@@ -278,20 +299,10 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
}
auto is_invalid = true;
if (cnode->size() > kInputSizeTwo) {
auto padding_node = cnode->input(kInputIndexTwo);
lite::DataInfo data_info;
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, &data_info);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;
}
if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
MS_LOG(ERROR) << "Get pad data failed.";
return lite::RET_ERROR;
}
if (!data_info.data_.empty()) {
auto pad_data = reinterpret_cast<int *>(data_info.data_.data());

@@ -308,19 +319,18 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
} else {
auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
MS_ASSERT(pad_prim != nullptr);
if (pad_prim->GetAttr(ops::kPadding) != nullptr) {
auto pad_data = pad_prim->get_paddings();
for (size_t i = 0; i < pad_data.size(); i++) {
for (size_t j = 0; j < pad_data[i].size(); j++) {
if (pad_data[i][j] != 0) {
is_invalid = false;
break;
}
}
if (is_invalid == false) {
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
auto pad_data = pad_prim->get_paddings();
for (size_t i = 0; i < pad_data.size(); i++) {
for (size_t j = 0; j < pad_data[i].size(); j++) {
if (pad_data[i][j] != 0) {
is_invalid = false;
break;
}
}
if (is_invalid == false) {
break;
}
}
}
if (is_invalid) {

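The loop above answers one question: is every padding entry zero, so the Pad op is a no-op and can be removed? A minimal standalone sketch of the same predicate, assuming paddings are stored as per-axis (front, back) pairs:

#include <algorithm>
#include <cstdio>
#include <vector>

// A Pad op is removable when every (front, back) padding entry is zero.
bool IsNoOpPad(const std::vector<std::vector<int64_t>> &paddings) {
  return std::all_of(paddings.begin(), paddings.end(), [](const std::vector<int64_t> &axis) {
    return std::all_of(axis.begin(), axis.end(), [](int64_t v) { return v == 0; });
  });
}

int main() {
  std::printf("%d\n", IsNoOpPad({{0, 0}, {0, 0}}));  // 1: redundant pad
  std::printf("%d\n", IsNoOpPad({{0, 0}, {1, 0}}));  // 0: real padding
  return 0;
}
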
@@ -21,6 +21,7 @@
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/fetch_content.h"

using mindspore::converter::FmkType;
namespace mindspore::opt {

@@ -38,6 +39,7 @@ class RemoveRedundantOpPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;

private:
int GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info);
bool is_train_model_ = false;
std::set<AnfNodePtr> remove_cnode_;
};

@@ -222,7 +222,7 @@ STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const C
return RET_OK;
}

ValueNodePtr SlicePreposePass::CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int64_t> &axes) {
ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);
auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();

@@ -233,7 +233,7 @@ ValueNodePtr SlicePreposePass::CreateSliceValueNode(const FuncGraphPtr &graph, c
return value_node;
}

ValueNodePtr SlicePreposePass::CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode) {
ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);
auto slice_c = GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(slice_cnode->input(0));

@@ -400,7 +400,7 @@ CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const s
return reshape_cnode;
}

bool SlicePreposePass::SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
const std::vector<int64_t> &ref_shape) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(output_node_list != nullptr);

@@ -512,15 +512,15 @@ int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode
MS_ASSERT(shape_out_copy != nullptr);
MS_ASSERT(is_normal_mode != nullptr);
MS_ASSERT(support_abnormal_mode != nullptr);
int64_t abnormal_index_out = -1;
auto slice_node = GetSlice(slice_cnode);
if (slice_node == nullptr) {
MS_LOG(ERROR) << "slice is nullptr";
return false;
return abnormal_index_out;
}
auto slice_axes = slice_node->get_axes();
auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
int64_t abnormal_index_out = -1;
for (size_t j = 0; j < shape_out.size(); ++j) {
int index = -1;
for (size_t i = 0; i < slice_axes.size(); ++i) {

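The fix above is about return-type discipline: returning false from a function declared int64_t yields 0, which here is a plausible index rather than an error. Hoisting the -1 sentinel before the first early return makes the failure value explicit. A minimal sketch of the same idea (names hypothetical):

#include <cstdint>
#include <cstdio>

// Returns the index of the first negative entry, or -1 when there is none.
int64_t FirstNegativeIndex(const int64_t *data, int64_t size) {
  int64_t not_found = -1;  // explicit sentinel, declared before any return
  if (data == nullptr || size <= 0) {
    return not_found;  // NOT "return false": that would be index 0
  }
  for (int64_t i = 0; i < size; ++i) {
    if (data[i] < 0) {
      return i;
    }
  }
  return not_found;
}

int main() {
  int64_t values[] = {3, 5, -2};
  std::printf("%lld\n", static_cast<long long>(FirstNegativeIndex(values, 3)));   // 2
  std::printf("%lld\n", static_cast<long long>(FirstNegativeIndex(nullptr, 0)));  // -1
  return 0;
}
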
@@ -532,8 +532,8 @@ int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode
if (index == -1) {
continue;
}
MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, false, "slice_begin.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, false, "slice_size.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, abnormal_index_out, "slice_begin.size() is wrong");
MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, abnormal_index_out, "slice_size.size() is wrong");
if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
if (mapped_axe[j] == -1) {
if (is_normal_mode) {

@@ -634,7 +634,7 @@ CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &gra
} else {
new_size1[abnormal_axe_in] = static_cast<int>(shape_in[abnormal_axe_in] - count_sliced_axe_in);
}
auto new_slice1 = CreateSliceValueNode(graph, new_axes1);
auto new_slice1 = CreateSliceValueNode(new_axes1);
if (new_slice1 == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return nullptr;

@@ -660,8 +660,7 @@ CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &gra
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &new_reshape1_cnode,
const std::vector<int64_t> &new_shape1,
const int64_t abnormal_axe_in,
const int64_t count_sliced_axe_in, const int64_t count_sliced2,
const int64_t abnormal_axe_in, const int64_t count_sliced2,
const bool slice_at_front) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(slice_cnode != nullptr);

@@ -679,7 +678,7 @@ CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &gra
} else {
new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
}
auto new_slice2 = CreateSliceValueNode(graph, new_axes2);
auto new_slice2 = CreateSliceValueNode(new_axes2);
if (new_slice2 == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return nullptr;

@@ -782,9 +781,8 @@ bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, con
const int64_t count_sliced_abnormal_axe =
shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
auto new_slice2_cnode =
CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1, abnormal_axe_in,
count_sliced_axe_in, count_sliced2, slice_at_front);
auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
abnormal_axe_in, count_sliced2, slice_at_front);
if (new_slice2_cnode == nullptr) {
return false;
}

@@ -977,7 +975,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
return false;
}
auto slice_node = GetSlice(slice_cnode);
MS_CHECK_TRUE_MSG(slice_node != nullptr, RET_ERROR, "slice is nullptr");
MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
auto axes = slice_node->get_axes();
auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);

@@ -1012,7 +1010,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
left_size[i] = -1;
}
}
auto left_slice_vnode = CreateSliceValueNode(graph, left_axes);
auto left_slice_vnode = CreateSliceValueNode(left_axes);
MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
auto begin_parameter = BuildIntVecParameterNode(
graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));

@@ -1050,7 +1048,7 @@ bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodeP
MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
node_name_index += 1;
auto right_slice_vnode = CreateSliceValueNode(graph, right_axes);
auto right_slice_vnode = CreateSliceValueNode(right_axes);
MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);

@@ -1094,7 +1092,7 @@ bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, cons
auto slice_node = GetSlice(slice_cnode);
if (slice_node == nullptr) {
MS_LOG(ERROR) << "slice is nullptr";
return RET_ERROR;
return false;
}
auto axes = slice_node->get_axes();
auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);

@@ -1134,7 +1132,7 @@ bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, cons
std::vector<int> new_size(shape_in.size(), -1);
new_begin[mapped_axe[0]] = begin[0];
new_size[mapped_axe[0]] = size[0];
auto new_slice_vnode = CreateSliceValueNode(graph, new_axes);
auto new_slice_vnode = CreateSliceValueNode(new_axes);
if (new_slice_vnode == nullptr) {
MS_LOG(ERROR) << "CreateSliceValueNode failed";
return false;

@@ -1267,7 +1265,7 @@ bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CN
continue;
} else if (shape.empty()) { // infershape failed at this input
if (IsScalarNode(another_input)) { // if another input is scalar, we can process this one
auto new_slice_vnode = CopySliceValueNode(graph, slice_cnode);
auto new_slice_vnode = CopySliceValueNode(slice_cnode);
if (new_slice_vnode == nullptr) {
changed = false;
break;

@@ -1299,7 +1297,7 @@ bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CN
changed = false;
break;
}
auto new_slice_vnode = CreateSliceValueNode(graph, new_axes);
auto new_slice_vnode = CreateSliceValueNode(new_axes);
if (new_slice_vnode == nullptr) {
changed = false;
break;

@@ -1525,7 +1523,7 @@ bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
}
auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
if (output_node_list->size() > 1) { // referenced by multi nodes
if (SiblingsAreSameSlice(graph, output_node_list) && MergeParallelSlice(graph, output_node_list)) {
if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
this_time_changed = true;
break;
}

@@ -43,8 +43,8 @@ class SlicePreposePass : public Pass {
static void ClearCNodeAbstractValue(const CNodePtr &cnode);
static STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr = nullptr);
static ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int64_t> &axes);
static ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode);
static ValueNodePtr CreateSliceValueNode(const std::vector<int64_t> &axes);
static ValueNodePtr CopySliceValueNode(const CNodePtr &slice_cnode);
static CNodePtr InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr);
static STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, int dim = -1);

@@ -52,8 +52,7 @@ class SlicePreposePass : public Pass {
std::vector<int64_t> *axes, std::vector<int> *begin, std::vector<int> *size);
static CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape,
const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode);
static bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
const std::vector<int64_t> &ref_shape = {});
static bool SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list, const std::vector<int64_t> &ref_shape = {});
static int64_t GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in, const std::vector<int64_t> &shape_out,
std::vector<int64_t> *mapped_axe);
static int64_t GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int64_t> &mapped_axe,

@@ -70,8 +69,7 @@ class SlicePreposePass : public Pass {
static CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &new_reshape1_cnode,
const std::vector<int64_t> &new_shape1, int64_t abnormal_axe_in,
int64_t count_sliced_axe_in, int64_t count_sliced2,
bool slice_at_front);
int64_t count_sliced2, bool slice_at_front);
static bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &reshape_cnode, const CNodePtr &matmul_cnode,
const std::vector<int64_t> &shape_in, const std::vector<int64_t> &shape_out,

@@ -136,6 +136,7 @@ STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, For
if (prim->GetAttr(ops::kAxis) == nullptr) {
return lite::RET_NOT_SUPPORT;
}
MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
if (axis < 0) {
axis += kInputSizeFour;

@@ -160,17 +161,21 @@ STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Forma
MS_LOG(ERROR) << "cnode is invalid.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
auto axis = crop_prim->get_axis();
if (axis < 0) {
axis += kInputSizeFour;
}
MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
auto offsets = crop_prim->get_offsets();
if (trans_type == kNCHW2NHWC) {
auto new_axis = kNH2NC[axis];
if (new_axis == 0) {
MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
} else if (new_axis == kInputIndexThree) {
MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
} else {
offsets.push_back(0);

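The kNH2NC lookup above remaps a Crop axis when the tensor layout changes, and the offsets are reordered to match. A minimal sketch of such an axis remap, assuming the common convention that NCHW axis N/C/H/W lands at NHWC position 0/3/1/2 (the exact contents of kNH2NC are not shown in this diff):

#include <array>
#include <cstdio>

// Assumed mapping: position of each NCHW axis inside an NHWC layout.
constexpr std::array<int, 4> kNchwToNhwc = {0, 3, 1, 2};

int main() {
  for (int axis = 0; axis < 4; ++axis) {
    std::printf("NCHW axis %d -> NHWC axis %d\n", axis, kNchwToNhwc[axis]);
  }
  return 0;
}
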
@@ -238,7 +243,10 @@ STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Format
}
auto param_node =
BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
func_graph->manager()->Replace(cnode->input(kInputIndexTwo), param_node);
MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(kInputIndexTwo), param_node);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
if (prim->GetAttr(ops::kPaddings) != nullptr) {

@@ -280,7 +288,10 @@ STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Form
[](int64_t v) { return static_cast<int>(v); });
}
for (size_t i = 2; i < cnode->size(); ++i) {
TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape);
if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
MS_LOG(ERROR) << "Transform axes failed.";
return RET_ERROR;
}
}
auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());

@@ -312,13 +323,18 @@ STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode
if (index == kInputIndexFour) {
continue;
}
TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape);
if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
MS_LOG(ERROR) << "transform axes failed.";
return lite::RET_ERROR;
}
}
auto cur_axes = TransformOpAxesAttr(axes, trans_type);
auto param_node =
BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(kInputIndexFour), param_node);
return lite::RET_OK;
}
} // namespace

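The two hunks above apply the same rule: a helper that returns a STATUS must not be called as a bare statement, or failures vanish silently. A minimal sketch of the before/after (helper and constants hypothetical):

#include <cstdio>

constexpr int RET_OK = 0;
constexpr int RET_ERROR = -1;

// Hypothetical helper that can fail.
int TransformAttr(int axis) { return axis >= 0 ? RET_OK : RET_ERROR; }

int Process(int axis) {
  // Before: TransformAttr(axis);  // failure silently dropped
  // After: propagate the status to the caller.
  if (TransformAttr(axis) != RET_OK) {
    std::fprintf(stderr, "transform axes failed.\n");
    return RET_ERROR;
  }
  return RET_OK;
}

int main() {
  std::printf("%d %d\n", Process(1), Process(-1));  // 0 -1
  return 0;
}
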
@@ -481,6 +497,7 @@ STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_
return lite::RET_ERROR;
}
CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
MS_ASSERT(base_node != nullptr);
size_t input_index = before ? index : node_users.front().second;
auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
if (!shape.empty() && shape.size() != kNH2NC.size()) {

@@ -522,7 +539,7 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
return true;
}

void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) {
void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const {
if (trans_info->pre_ == trans_info->post_) {
return;
}

@@ -50,7 +50,7 @@ class TransposeStrategy {
STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
bool IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
FormatTransNodeType *trans_type);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info);
void DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const;
FmkType fmk_type_{converter::kFmkTypeMs};
bool train_flag_{false};
NodeInferShape node_infer_shape_;

@@ -24,7 +24,7 @@ void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::kFmkTypeMs) {
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
return RET_ERROR;
return false;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();

@@ -42,12 +42,12 @@ bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
auto abstract_base = cast_cnode->input(1)->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
return false;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, "
<< cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
return false;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto input_type = abstract_tensor->element()->GetTypeTrack();

@@ -56,12 +56,12 @@ bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {

if (cast_cnode->inputs().size() != kCastInputNum || !utils::isa<ValueNodePtr>(cast_cnode->input(2))) {
MS_LOG(ERROR) << "Second input of cast should be a ValueNode";
return RET_ERROR;
return false;
}
auto output_type = GetValueNode<NumberPtr>(cast_cnode->input(2));
if (output_type == nullptr) {
MS_LOG(ERROR) << "Second input of cast is nullptr";
return RET_ERROR;
return false;
}
auto output_type_value = output_type->type_id();
if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) ||

@@ -19,6 +19,7 @@
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"

namespace mindspore::opt {
constexpr size_t kTransposeInput = 1;

@@ -60,7 +61,7 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != converter::kFmkTypeOnnx) {
MS_LOG(ERROR) << "The framework type of model should be onnx.";
return RET_ERROR;
return false;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();

@@ -67,6 +67,7 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
return lite::RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "current conv2d's format is undefined.";
return lite::RET_ERROR;

@@ -85,6 +86,7 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
prim->AddAttr(ops::kGroup, MakeValue(is_depth_wise ? out_channel : 1));
}
MS_ASSERT(prim->GetAttr(ops::kGroup) != nullptr);
auto group = GetValue<int64_t>(prim->GetAttr(ops::kGroup));
if (CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
std::swap(in_channel, out_channel);

@@ -101,8 +103,6 @@ STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {

bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {

@@ -92,6 +92,7 @@ int Conv2DInfo::CheckStrategy(const SplitStrategy &strategy) {

int Conv2DInfo::CheckIfSplit() {
auto conv_prim = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(cnode_->input(kAnfPrimitiveIndex));
MS_ASSERT(conv_prim != nullptr);
auto strides = conv_prim->get_stride();
std::vector<int64_t> weight_shape;
std::vector<int64_t> input_shape;

@@ -99,7 +100,9 @@ int Conv2DInfo::CheckIfSplit() {
// for n, h, cin, check whether the input is bigger than the split total ratio
if (split_mode_ != SplitCOUT) {
auto input_node_abstract = GetCNodeInputAbstract(cnode_, 1);
MS_CHECK_TRUE_RET(input_node_abstract != nullptr, RET_ERROR);
auto weight_node_abstract = GetCNodeInputAbstract(cnode_, 2);
MS_CHECK_TRUE_RET(weight_node_abstract != nullptr, RET_ERROR);
if (!utils::isa<abstract::AbstractTensorPtr>(input_node_abstract)) {
MS_LOG(ERROR) << "conv_input_abstract should be abstract tensor";
return RET_ERROR;

@@ -163,8 +166,11 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
MS_CHECK_GE(conv_prim->get_stride().size(), 1, nullptr);
auto extend_bottom = conv_prim->get_kernel_size().at(kIndexH) - conv_prim->get_stride().at(kIndexH);
auto bottom_vector = std::vector<int64_t>(split_num, extend_bottom);
MS_CHECK_GE(split_num, 1, nullptr);
bottom_vector[split_num - 1] = 0;
split_prim->set_extend_bottom(bottom_vector);
MS_CHECK_GE(conv_prim->get_pad_list().size(), 1, nullptr);
MS_CHECK_TRUE_RET(input_shape.size() == DIMENSION_4D, nullptr);
if (!UpdateRatioWithPadStride(new_splits.data(), new_splits.size(), split_num, input_shape[split_dim],
conv_prim->get_pad_list().at(kPadUp), conv_prim->get_stride().at(kIndexH))) {
MS_LOG(ERROR) << "UpdateRatioWithPadStride failed";

@@ -178,7 +184,9 @@ AnfNodePtr Conv2DInfo::CreateOutputsOfSplit(const CNodePtr &orig_node, size_t in
split_prim->set_number_split(split_num);
split_prim->set_ratio(new_splits);

std::vector<AnfNodePtr> split_inputs = {NewValueNode(split_prim)};
auto split_primitive = NewValueNode(split_prim);
MS_CHECK_TRUE_MSG(split_primitive != nullptr, nullptr, "create SplitWithOverlap return nullptr");
std::vector<AnfNodePtr> split_inputs = {split_primitive};
// ori_conv_node must only have one input
split_inputs.push_back(orig_node->input(input_index + 1));
auto split_cnode = func_graph_->NewCNode(split_inputs);

@@ -225,7 +233,6 @@ int Conv2DInfo::InferParallelCNodes() {
if (CheckIfSplit() != RET_OK) {
return RET_ERROR;
}
Strategys strategys = strategy_.strategys;
size_t dev_num = strategy_.dev_num;
std::vector<AnfNodePtr> feature_split_outputs;
std::vector<AnfNodePtr> kernel_split_outputs;

@@ -361,7 +368,7 @@ int Conv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
if (split_mode_ == SplitCIN) {
MS_LOG(DEBUG) << name_ << " : Split Cin, infer Forward op.";
replace_op_ = CreateReduceNode(cnode_, parallel_output_nodes_, kAxisCIn, dev_num, true);
replace_op_ = CreateReduceNode(cnode_, parallel_output_nodes_, kAxisCIn, dev_num);
} else {
int32_t concat_dim;
if (split_mode_ == SplitN) {

@@ -372,7 +379,7 @@ int Conv2DInfo::InferReplaceOp() {
} else {
concat_dim = kAxisH;
}
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num, true);
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, concat_dim, dev_num);
}

if (replace_op_ == nullptr) {

@@ -67,6 +67,11 @@ void CreateSplitConstantTensors(const tensor::TensorPtr &constant_tensor, const
for (int64_t i = 0; i < split_num; i++) {
// init shape for [split_dim]
visited_block += splits[i];
if (total_block_count == 0) {
MS_LOG(ERROR) << "divisor is zero";
split_constant_tensors->clear();
return;
}
auto cur_shape = UP_DIV(split_dim_size * visited_block, total_block_count);
split_constant_shapes.at(i).at(split_dim) = cur_shape;
auto tensor = std::make_shared<tensor::Tensor>(weight_type_id, split_constant_shapes.at(i));

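UP_DIV here is ceiling division, and the new guard prevents the divide-by-zero when the split ratios sum to nothing. A minimal sketch of the guarded computation (UP_DIV defined locally with the usual ceiling-division form; the project's macro may differ in detail):

#include <cstdint>
#include <cstdio>

// Ceiling division; the caller must ensure the divisor is nonzero.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int64_t SplitDimSize(int64_t dim_size, int64_t visited_block, int64_t total_block_count) {
  if (total_block_count == 0) {
    std::fprintf(stderr, "divisor is zero\n");
    return -1;  // guard runs before UP_DIV, as in the hunk above
  }
  return UP_DIV(dim_size * visited_block, total_block_count);
}

int main() {
  std::printf("%lld\n", static_cast<long long>(SplitDimSize(10, 1, 3)));  // 4
  std::printf("%lld\n", static_cast<long long>(SplitDimSize(10, 1, 0)));  // -1
  return 0;
}
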
@@ -491,7 +496,7 @@ int DepthwiseConv2DInfo::InferParallelCNodes() {

int DepthwiseConv2DInfo::InferReplaceOp() {
size_t dev_num = strategy_.dev_num;
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, split_dim_, dev_num, true);
replace_op_ = CreateConcateNode(cnode_, parallel_output_nodes_, split_dim_, dev_num);
if (replace_op_ == nullptr) {
return RET_ERROR;
}

@@ -80,10 +80,8 @@ bool MultiConvSplit::CheckSplitValid() {
if (i >= static_cast<int64_t>(final_ratios.size()) || final_ratios.at(i) <= 0) {
return false;
}
MS_CHECK_INT_ADD_NOT_OVERFLOW(total_block_count, final_ratios.at(i), false);
total_block_count += final_ratios.at(i);
if (i == 0) {
MS_CHECK_INT_ADD_NOT_OVERFLOW(visited_block, final_ratios.at(i), false);
visited_block += final_ratios.at(i);
}
}

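MS_CHECK_INT_ADD_NOT_OVERFLOW rejects an addition before it wraps. A minimal sketch of the same check for int against the INT_MAX bound (the macro's actual definition is not shown in this diff):

#include <climits>
#include <cstdio>

// True when a + b would exceed INT_MAX (both assumed non-negative here).
bool AddWouldOverflow(int a, int b) { return b > INT_MAX - a; }

bool Accumulate(int *total, int ratio) {
  if (ratio <= 0 || AddWouldOverflow(*total, ratio)) {
    return false;  // mirrors MS_CHECK_INT_ADD_NOT_OVERFLOW(total, ratio, false)
  }
  *total += ratio;
  return true;
}

int main() {
  int total = 0;
  std::printf("%d\n", Accumulate(&total, 5));        // 1
  std::printf("%d\n", Accumulate(&total, INT_MAX));  // 0: would overflow
  return 0;
}
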
@@ -122,7 +120,8 @@ int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) {
while (index < split_info_.in_num_conv - 1) {
MS_CHECK_LT(index, static_cast<int32_t>(conv_nodes_.size()), RET_ERROR);
auto curr_node = conv_nodes_[index];
auto curr_cnode = conv_nodes_[index]->cast<CNodePtr>();
MS_ASSERT(curr_node != nullptr);
auto curr_cnode = curr_node->cast<CNodePtr>();
MS_CHECK_TRUE_RET(curr_cnode != nullptr, RET_ERROR);
auto tmp_node = curr_cnode->input(1);
if (!IsConv2D(tmp_node)) {

@@ -198,7 +197,7 @@ bool MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vect
// node inputs
std::vector<AnfNodePtr> conv_inputs;
conv_inputs.push_back(NewValueNode(conv_prim));
AdJustInputs(ori_node, inputs_node, weight_nodes, bias_nodes, output_conv_index, &conv_inputs);
AdJustInputs(ori_node, inputs_node, output_conv_index, &conv_inputs);
// create new conv node
if (!CreateNewConvNode(ori_node, conv_inputs, output_conv_index, outputs_node)) {
return false;

@@ -208,7 +207,6 @@ bool MultiConvSplit::SplitSingleConv(const AnfNodePtr &ori_node, const std::vect
}

void MultiConvSplit::AdJustInputs(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs) {
MS_ASSERT(ori_conv_node != nullptr && conv_inputs != nullptr);
auto ori_conv_cnode = ori_conv_node->cast<CNodePtr>();

@@ -41,7 +41,6 @@ class MultiConvSplit : public MultiNodeSplit {
virtual AnfNodePtr MultiConvNHSplit(const AnfNodePtr &node);

virtual void AdJustInputs(const AnfNodePtr &ori_node, const std::vector<AnfNodePtr> &new_inputs_node,
const std::vector<AnfNodePtr> &weight_node, const std::vector<AnfNodePtr> &bias_nodes,
int output_conv_index, std::vector<AnfNodePtr> *conv_inputs);

virtual bool CreateNewConvNode(const AnfNodePtr &ori_conv_node, const std::vector<AnfNodePtr> &conv_inputs,

@@ -42,10 +42,7 @@ int MultiNodeSplitProxy::InitResource() {
return RET_OK;
}

int MultiNodeSplitProxy::FreeResource() {
multi_node_split_ = nullptr;
return RET_OK;
}
void MultiNodeSplitProxy::FreeResource() { multi_node_split_ = nullptr; }

AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);

@@ -55,10 +52,7 @@ AnfNodePtr MultiNodeSplitProxy::DoSplit(const FuncGraphPtr &func_graph, const An
return node;
}
auto res_node = multi_node_split_->DoSplit(func_graph, node);
ret = FreeResource();
if (ret != RET_OK) {
return node;
}
FreeResource();
return res_node;
}

@@ -46,7 +46,7 @@ class MultiNodeSplitProxy : public MultiNodeSplit {

private:
int InitResource();
int FreeResource();
void FreeResource();

private:
SplitMode split_mode_{NoSplit};

@@ -35,7 +35,7 @@ bool is_any_not_none(const std::vector<int64_t> &split) {
return std::any_of(split.begin(), split.end(), [](int64_t v) { return v != static_cast<int64_t>(NoSplit); });
}

std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() {
std::shared_ptr<abstract::AbstractTensor> OperatorInfo::CreateFakeAbstractTensor() const {
auto type_ptr = TypeIdToType(operator_type_id_);
std::vector<int64_t> shape_vector;
return std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);

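Several hunks in this commit add const to member functions that read but never mutate object state; this lets them be called through const references and documents the contract. A minimal sketch (class name hypothetical):

#include <cstdio>

class OperatorLike {
 public:
  explicit OperatorLike(int type_id) : type_id_(type_id) {}
  // const: reads type_id_ only, so callers holding a const object may use it.
  int FakeTensorTypeId() const { return type_id_; }

 private:
  int type_id_;
};

int main() {
  const OperatorLike op(42);
  std::printf("%d\n", op.FakeTensorTypeId());  // compiles only because the method is const
  return 0;
}
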
@@ -135,7 +135,7 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
}

AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format) {
int32_t concat_dim, size_t input_nodes_num) {
MS_EXCEPTION_IF_NULL(orig_node);
if (input_nodes.size() != input_nodes_num) {
MS_LOG(ERROR) << name_ << " : Input nodes size of concat is not equal to input nodes number.";

@@ -162,7 +162,7 @@ AnfNodePtr OperatorInfo::CreateConcateNode(const CNodePtr &orig_node, const std:
}

AnfNodePtr OperatorInfo::CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t reduce_dim, size_t input_nodes_num, bool trans_format) {
int32_t reduce_dim, size_t input_nodes_num) {
MS_EXCEPTION_IF_NULL(orig_node);
if (input_nodes.size() != input_nodes_num) {
MS_LOG(ERROR) << name_ << " : Input nodes size of reduce is not equal to input nodes number.";

@@ -65,11 +65,11 @@ class OperatorInfo {
int CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, std::vector<AnfNodePtr> *outputs);

AnfNodePtr CreateConcateNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes,
int32_t concat_dim, size_t input_nodes_num, bool trans_format);
int32_t concat_dim, size_t input_nodes_num);
AnfNodePtr CreateReduceNode(const CNodePtr &orig_node, const std::vector<AnfNodePtr> &input_nodes, int32_t reduce_dim,
size_t input_nodes_num, bool trans_format);
size_t input_nodes_num);

std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor();
std::shared_ptr<abstract::AbstractTensor> CreateFakeAbstractTensor() const;

virtual AnfNodePtr CreateOutputsOfSplit(const CNodePtr &input_node, size_t input_index,
std::vector<AnfNodePtr> *split_outputs, size_t split_dim, size_t split_num,

@@ -37,16 +37,15 @@ OperatorInfoFactory *OperatorInfoFactory::GeInstance() {
return &factory;
}

int OperatorInfoFactory::RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func) {
void OperatorInfoFactory::RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func) {
// create a key to find the only create function
SplitOpKey op_key(operator_type, type_id, is_depth_wise);
if (operator_info_map_.find(op_key) != operator_info_map_.end()) {
MS_LOG(ERROR) << " Operator already exists " << op_key.ToString();
return lite::RET_ERROR;
return;
}
this->operator_info_map_.insert(std::pair<SplitOpKey, OperatorInfoCreatorFunc>(op_key, creator_func));
return lite::RET_OK;
}

OperatorInfoCreatorFunc OperatorInfoFactory::FindOperatorInfo(const SplitOpKey &op_key) {

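The factory above keys creator functions by (op type, data type, depthwise flag) and refuses duplicate registration. A minimal sketch of the same registry shape with std::map, the key simplified to a tuple and all names hypothetical:

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <tuple>

using Key = std::tuple<int, int, bool>;        // (op type, data type, depthwise)
using Creator = std::function<std::string()>;  // stand-in for OperatorInfoCreatorFunc

class Factory {
 public:
  void Register(const Key &key, const Creator &creator) {
    if (creators_.find(key) != creators_.end()) {
      std::fprintf(stderr, "operator already exists\n");
      return;  // duplicate registration is ignored, as in the hunk above
    }
    creators_.emplace(key, creator);
  }
  Creator Find(const Key &key) const {
    auto it = creators_.find(key);
    return it == creators_.end() ? nullptr : it->second;
  }

 private:
  std::map<Key, Creator> creators_;
};

int main() {
  Factory factory;
  factory.Register({1, 0, false}, [] { return std::string("Conv2DInfo"); });
  auto creator = factory.Find({1, 0, false});
  std::printf("%s\n", creator ? creator().c_str() : "none");  // Conv2DInfo
  return 0;
}
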
@@ -56,8 +56,8 @@ class OperatorInfoFactory {

OperatorInfoFactory &operator=(const OperatorInfoFactory &) = delete;

int RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func);
void RegisterOperatorInfo(schema::PrimitiveType operator_type, TypeId type_id, bool is_depth_wise,
const OperatorInfoCreatorFunc &creator_func);

OperatorInfoCreatorFunc FindOperatorInfo(const SplitOpKey &split_op_key);

@@ -104,6 +104,7 @@ AnfNodePtr ParallelPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &n
}
// if the current conv2d node has two output nodes, we do not split it
auto manager = func_graph->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr.");
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node : " << node->fullname_with_scope() << " has no output";