convert tool, cast op opt
This commit is contained in:
parent
8b490355c8
commit
d11c383f3d
|
@ -93,6 +93,7 @@
|
|||
#include "src/common/log_util.h"
|
||||
#include "tools/optimizer/fusion/groupnorm_fusion.h"
|
||||
#include "tools/optimizer/fusion/mul_reduce_fusion.h"
|
||||
#include "tools/converter/import/cast_op_adjust.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore::lite {
|
||||
|
@ -333,6 +334,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
|
|||
CHECK_NULL_RETURN(convert_pm);
|
||||
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
|
||||
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
||||
convert_pm->AddPass(std::make_shared<opt::CastOpAdjust>());
|
||||
convert_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
|
||||
optimizer->AddPassManager(convert_pm);
|
||||
if (optimizer->Optimize(old_graph) == nullptr) {
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/import/cast_op_adjust.h"
|
||||
#include <vector>
|
||||
#include "tools/lite_exporter/fetch_content.h"
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace mindspore::opt {
|
||||
constexpr size_t kCastInputNum = 3;
|
||||
|
||||
bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, int *input_type_value) {
|
||||
|
@ -36,7 +37,7 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
|
|||
MS_CHECK_TRUE_RET(input_type != nullptr, false);
|
||||
*input_type_value = input_type->type_id();
|
||||
|
||||
DataInfo data_info;
|
||||
lite::DataInfo data_info;
|
||||
auto output_type_node = cast_cnode->input(opt::kInputIndexTwo);
|
||||
if (utils::isa<ParameterPtr>(output_type_node)) {
|
||||
if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, &data_info, true) !=
|
||||
|
@ -77,11 +78,60 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
|
||||
bool MatchRemoveCastOpRule(int input_type_value, int output_type_value, bool strict_mode_flag) {
|
||||
if (strict_mode_flag) {
|
||||
if (input_type_value == output_type_value) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
bool int_input = input_type_value == kNumberTypeInt || input_type_value == kNumberTypeInt32;
|
||||
bool int_output = output_type_value == kNumberTypeInt || output_type_value == kNumberTypeInt32 ||
|
||||
output_type_value == kNumberTypeInt64;
|
||||
|
||||
bool float_input = input_type_value == kNumberTypeFloat16 || input_type_value == kNumberTypeFloat32;
|
||||
bool float_output = output_type_value == kNumberTypeFloat16 || output_type_value == kNumberTypeFloat32;
|
||||
|
||||
if ((int_input && int_output) || (float_input && float_output) || (input_type_value == output_type_value)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsControlFlow(const std::vector<AnfNodePtr> &node_list) {
|
||||
static bool control_flow_flag = false;
|
||||
if (control_flow_flag) {
|
||||
return true;
|
||||
}
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < cnode->size(); ++i) {
|
||||
auto in_node = cnode->input(i);
|
||||
MS_CHECK_TRUE_RET(in_node != nullptr, false);
|
||||
auto sub_func = GetValueNode<FuncGraphPtr>(in_node);
|
||||
if (sub_func != nullptr) {
|
||||
control_flow_flag = true;
|
||||
return control_flow_flag;
|
||||
}
|
||||
}
|
||||
}
|
||||
return control_flow_flag;
|
||||
}
|
||||
|
||||
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) { return CastOpAdjust::Run(func_graph, true); }
|
||||
|
||||
bool CastOpAdjust::Run(const FuncGraphPtr &func_graph, bool strict_mode_flag) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
if (IsControlFlow(node_list) && strict_mode_flag) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimCast)) {
|
||||
continue;
|
||||
|
@ -95,18 +145,14 @@ bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "Get output type failed.";
|
||||
return false;
|
||||
}
|
||||
bool int_input = input_type_value == kNumberTypeInt || input_type_value == kNumberTypeInt32;
|
||||
bool int_output = output_type_value == kNumberTypeInt || output_type_value == kNumberTypeInt32 ||
|
||||
output_type_value == kNumberTypeInt64;
|
||||
if ((int_input && int_output) || (input_type_value == output_type_value) ||
|
||||
(input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) ||
|
||||
(input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) {
|
||||
if (MatchRemoveCastOpRule(input_type_value, output_type_value, strict_mode_flag)) {
|
||||
auto ret = manager->Replace(node, cast_cnode->input(1));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Replace node to its input failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (output_type_value == kNumberTypeInt64) {
|
||||
auto parameter = opt::BuildIntValueParameterNode(func_graph, kNumberTypeInt32,
|
||||
cast_cnode->input(opt::kInputIndexTwo)->fullname_with_scope());
|
||||
|
@ -123,4 +169,4 @@ bool CastOpAdjust::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
} // namespace mindspore::opt
|
||||
|
|
|
@ -17,14 +17,18 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
|
||||
#include <string>
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class CastOpAdjust {
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class CastOpAdjust : public Pass {
|
||||
public:
|
||||
CastOpAdjust() {}
|
||||
CastOpAdjust() : Pass("CastOpAdjust") {}
|
||||
~CastOpAdjust() = default;
|
||||
bool Run(const FuncGraphPtr &graph);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
bool Run(const FuncGraphPtr &graph, bool strict_mode_flag);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_CAST_OP_ADJUST_H_
|
||||
|
|
|
@ -63,9 +63,9 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
|
|||
return RET_ERROR;
|
||||
}
|
||||
if (!param->train_model) {
|
||||
auto cast_op_adjust = std::make_shared<CastOpAdjust>();
|
||||
auto cast_op_adjust = std::make_shared<opt::CastOpAdjust>();
|
||||
MS_CHECK_TRUE_MSG(cast_op_adjust != nullptr, RET_NULL_PTR, "cast_op_adjust is nullptr.");
|
||||
if (!cast_op_adjust->Run(func_graph)) {
|
||||
if (!cast_op_adjust->Run(func_graph, false)) {
|
||||
MS_LOG(ERROR) << "MindIr adjust cast operator failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
|
|
Loading…
Reference in New Issue