convert tool, cast op opt

greatpan 2022-06-13 10:07:25 +08:00
parent 8b490355c8
commit d11c383f3d
4 changed files with 69 additions and 17 deletions

View File

@@ -93,6 +93,7 @@
#include "src/common/log_util.h"
#include "tools/optimizer/fusion/groupnorm_fusion.h"
#include "tools/optimizer/fusion/mul_reduce_fusion.h"
#include "tools/converter/import/cast_op_adjust.h"
using std::string;
namespace mindspore::lite {
@@ -333,6 +334,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
CHECK_NULL_RETURN(convert_pm);
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
convert_pm->AddPass(std::make_shared<opt::CastOpAdjust>());
convert_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
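
For reference, the convert pipeline's passes run in the order they are registered, so CastOpAdjust is slotted in after InferShapePass (presumably so the Cast nodes it examines already carry inferred type information) and before UpdateConv2DParamPass. A rough sketch of that ordering contract, not part of the commit, with simplified stand-ins for the real opt::Pass / opt::PassManager classes:

#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct FuncGraphStub {};  // stand-in for the real FuncGraph

// Simplified stand-in for opt::Pass: just a name and a Run hook.
struct PassStub {
  explicit PassStub(std::string name) : name_(std::move(name)) {}
  virtual ~PassStub() = default;
  virtual bool Run(FuncGraphStub *graph) = 0;
  std::string name_;
};

struct InferShapePassStub : PassStub {
  InferShapePassStub() : PassStub("InferShapePass") {}
  bool Run(FuncGraphStub *) override { return true; }  // pretend shapes/dtypes get annotated here
};

struct CastOpAdjustStub : PassStub {
  CastOpAdjustStub() : PassStub("CastOpAdjust") {}
  bool Run(FuncGraphStub *) override { return true; }  // relies on dtypes inferred by the previous pass
};

// Simplified pass manager: passes run in registration order, stopping on failure.
struct PassManagerStub {
  void AddPass(std::shared_ptr<PassStub> pass) { passes_.push_back(std::move(pass)); }
  bool Optimize(FuncGraphStub *graph) {
    for (auto &pass : passes_) {
      std::printf("running %s\n", pass->name_.c_str());
      if (!pass->Run(graph)) {
        return false;
      }
    }
    return true;
  }
  std::vector<std::shared_ptr<PassStub>> passes_;
};

int main() {
  FuncGraphStub graph;
  PassManagerStub pm;
  pm.AddPass(std::make_shared<InferShapePassStub>());
  pm.AddPass(std::make_shared<CastOpAdjustStub>());  // must run after shape/type inference
  return pm.Optimize(&graph) ? 0 : 1;
}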

View File

@@ -14,10 +14,11 @@
* limitations under the License.
*/
#include "tools/converter/import/cast_op_adjust.h"
#include <vector>
#include "tools/lite_exporter/fetch_content.h"
#include "mindspore/lite/include/errorcode.h"
namespace mindspore::lite {
namespace mindspore::opt {
constexpr size_t kCastInputNum = 3;
bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, int *input_type_value) {
@@ -36,7 +37,7 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
MS_CHECK_TRUE_RET(input_type != nullptr, false);
*input_type_value = input_type->type_id();
DataInfo data_info;
lite::DataInfo data_info;
auto output_type_node = cast_cnode->input(opt::kInputIndexTwo);
if (utils::isa<ParameterPtr>(output_type_node)) {
if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, &data_info, true) !=
@@ -77,11 +78,60 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
return true;
}
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
bool MatchRemoveCastOpRule(int input_type_value, int output_type_value, bool strict_mode_flag) {
if (strict_mode_flag) {
if (input_type_value == output_type_value) {
return true;
}
} else {
bool int_input = input_type_value == kNumberTypeInt || input_type_value == kNumberTypeInt32;
bool int_output = output_type_value == kNumberTypeInt || output_type_value == kNumberTypeInt32 ||
output_type_value == kNumberTypeInt64;
bool float_input = input_type_value == kNumberTypeFloat16 || input_type_value == kNumberTypeFloat32;
bool float_output = output_type_value == kNumberTypeFloat16 || output_type_value == kNumberTypeFloat32;
if ((int_input && int_output) || (float_input && float_output) || (input_type_value == output_type_value)) {
return true;
}
}
return false;
}
bool IsControlFlow(const std::vector<AnfNodePtr> &node_list) {
static bool control_flow_flag = false;
if (control_flow_flag) {
return true;
}
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
for (size_t i = 0; i < cnode->size(); ++i) {
auto in_node = cnode->input(i);
MS_CHECK_TRUE_RET(in_node != nullptr, false);
auto sub_func = GetValueNode<FuncGraphPtr>(in_node);
if (sub_func != nullptr) {
control_flow_flag = true;
return control_flow_flag;
}
}
}
return control_flow_flag;
}
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) { return CastOpAdjust::Run(func_graph, true); }
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph, bool strict_mode_flag) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
if (IsControlFlow(node_list) && strict_mode_flag) {
return false;
}
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimCast)) {
continue;
@@ -95,18 +145,14 @@ bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "Get output type failed.";
return false;
}
bool int_input = input_type_value == kNumberTypeInt || input_type_value == kNumberTypeInt32;
bool int_output = output_type_value == kNumberTypeInt || output_type_value == kNumberTypeInt32 ||
output_type_value == kNumberTypeInt64;
if ((int_input && int_output) || (input_type_value == output_type_value) ||
(input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) ||
(input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) {
if (MatchRemoveCastOpRule(input_type_value, output_type_value, strict_mode_flag)) {
auto ret = manager->Replace(node, cast_cnode->input(1));
if (!ret) {
MS_LOG(ERROR) << "Replace node to its input failed.";
return false;
}
}
if (output_type_value == kNumberTypeInt64) {
auto parameter = opt::BuildIntValueParameterNode(func_graph, kNumberTypeInt32,
cast_cnode->input(opt::kInputIndexTwo)->fullname_with_scope());
@@ -123,4 +169,4 @@ bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
}
return true;
}
} // namespace mindspore::lite
} // namespace mindspore::opt
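
MatchRemoveCastOpRule now captures the removal policy in one place: with strict_mode_flag set, a Cast is only folded away when its input and output dtypes are identical (and strict runs bail out entirely on graphs with control-flow subgraphs, see IsControlFlow); without it, any int-to-int cast (including int64 outputs) or fp16/fp32 cast is also treated as removable, much like the pre-patch behaviour. A minimal standalone sketch of that decision table, not part of the commit, using stand-in enum values instead of the real TypeId constants:

#include <cstdio>

// Stand-ins for the real mindspore TypeId enumerators; the numeric values
// here are arbitrary and not the actual TypeId values.
enum TypeIdStub {
  kNumberTypeInt = 1,
  kNumberTypeInt32,
  kNumberTypeInt64,
  kNumberTypeFloat16,
  kNumberTypeFloat32,
};

// Mirrors MatchRemoveCastOpRule from this patch: strict mode only removes
// identity casts; relaxed mode also removes int->int (including int64
// outputs) and fp16<->fp32 casts.
bool MatchRemoveCastOpRule(int input_type_value, int output_type_value, bool strict_mode_flag) {
  if (strict_mode_flag) {
    return input_type_value == output_type_value;
  }
  bool int_input = input_type_value == kNumberTypeInt || input_type_value == kNumberTypeInt32;
  bool int_output = output_type_value == kNumberTypeInt || output_type_value == kNumberTypeInt32 ||
                    output_type_value == kNumberTypeInt64;
  bool float_input = input_type_value == kNumberTypeFloat16 || input_type_value == kNumberTypeFloat32;
  bool float_output = output_type_value == kNumberTypeFloat16 || output_type_value == kNumberTypeFloat32;
  return (int_input && int_output) || (float_input && float_output) ||
         (input_type_value == output_type_value);
}

int main() {
  // int32 -> int64: kept in strict mode, removable in relaxed mode.
  std::printf("strict  int32->int64: %d\n", MatchRemoveCastOpRule(kNumberTypeInt32, kNumberTypeInt64, true));
  std::printf("relaxed int32->int64: %d\n", MatchRemoveCastOpRule(kNumberTypeInt32, kNumberTypeInt64, false));
  // fp32 -> fp16: removable only in relaxed mode.
  std::printf("relaxed fp32->fp16:   %d\n", MatchRemoveCastOpRule(kNumberTypeFloat32, kNumberTypeFloat16, false));
  return 0;
}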

View File

@@ -17,14 +17,18 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
#include <string>
#include "backend/common/optimizer/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite {
class CastOpAdjust {
namespace mindspore {
namespace opt {
class CastOpAdjust : public Pass {
public:
CastOpAdjust() {}
CastOpAdjust() : Pass("CastOpAdjust") {}
~CastOpAdjust() = default;
bool Run(const FuncGraphPtr &graph);
bool Run(const FuncGraphPtr &graph) override;
bool Run(const FuncGraphPtr &graph, bool strict_mode_flag);
};
} // namespace mindspore::lite
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
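
In the header, CastOpAdjust becomes a normal opt::Pass (named "CastOpAdjust") while keeping a second, parameterized entry point: the overriding single-argument Run delegates to Run(graph, true), so anything that drives the pass through the generic Pass interface (such as the convert pass manager above) gets strict mode, and only callers that explicitly pass false, like the MindIR importer in the next file, get the relaxed cast removal. A small sketch of that override-plus-overload pattern, not part of the commit and using stand-in types:

#include <cstdio>

struct FuncGraphStub {};  // stand-in for the FuncGraph the real pass receives

struct PassStub {
  virtual ~PassStub() = default;
  virtual bool Run(FuncGraphStub *graph) = 0;
};

struct CastOpAdjustStub : PassStub {
  // Generic Pass entry point: always strict, so pass managers cannot
  // accidentally trigger the more aggressive cast removal.
  bool Run(FuncGraphStub *graph) override { return Run(graph, true); }

  // Parameterized entry point: strict_mode_flag == false enables the
  // relaxed int->int / fp16<->fp32 removal used by the importer.
  bool Run(FuncGraphStub *graph, bool strict_mode_flag) {
    std::printf("CastOpAdjust running, strict=%d\n", strict_mode_flag);
    return true;
  }
};

int main() {
  FuncGraphStub graph;
  CastOpAdjustStub adjust;
  adjust.Run(&graph);         // via the Pass-style interface: strict mode
  adjust.Run(&graph, false);  // explicit relaxed mode, as in Mindir2AnfAdjust
  return 0;
}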

View File

@@ -63,9 +63,9 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
return RET_ERROR;
}
if (!param->train_model) {
auto cast_op_adjust = std::make_shared<CastOpAdjust>();
auto cast_op_adjust = std::make_shared<opt::CastOpAdjust>();
MS_CHECK_TRUE_MSG(cast_op_adjust != nullptr, RET_NULL_PTR, "cast_op_adjust is nullptr.");
if (!cast_op_adjust->Run(func_graph)) {
if (!cast_op_adjust->Run(func_graph, false)) {
MS_LOG(ERROR) << "MindIr adjust cast operator failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;