Fixed the bug of cast eliminate.

This commit is contained in:
Margaret_wangrui 2023-01-03 14:21:07 +08:00
parent 378b787bf1
commit 5c56a45cce
5 changed files with 109 additions and 30 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -86,12 +86,71 @@ void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
} }
} }
bool TwoCastEliminater::CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2) {
auto type1_iter = type_map.find(type1);
auto type2_iter = type_map.find(type2);
if (type1_iter != type_map.end() && type2_iter != type_map.end()) {
return type1_iter->second <= type2_iter->second;
}
return false;
}
bool TwoCastEliminater::CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2,
TypeId type3) {
auto type1_iter = type_map.find(type1);
auto type2_iter = type_map.find(type2);
auto type3_iter = type_map.find(type3);
if (type1_iter != type_map.end() && type2_iter != type_map.end() && type3_iter != type_map.end()) {
return type1_iter->second <= type2_iter->second && type2_iter->second <= type3_iter->second;
}
return false;
}
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -> {prim::kPrimCast, X, T}
// x_type <= y_type <= t_type or x_type >= y_type >= t_type
bool TwoCastEliminater::CheckTypesIsIncrementalOrDecreasing() {
auto x_type = x_->Type();
if (x_type->isa<TensorType>()) {
x_type = x_type->cast<TensorTypePtr>()->element();
}
auto y_type = GetValueNode<TypePtr>(y_);
MS_EXCEPTION_IF_NULL(y_type);
if (y_type->isa<TensorType>()) {
y_type = y_type->cast<TensorTypePtr>()->element();
}
auto t_type = GetValueNode<TypePtr>(t_);
MS_EXCEPTION_IF_NULL(t_type);
if (t_type->isa<TensorType>()) {
t_type = t_type->cast<TensorTypePtr>()->element();
}
auto x_type_id = x_type->type_id();
auto y_type_id = y_type->type_id();
auto t_type_id = t_type->type_id();
if (y_type_id == t_type_id) {
return true;
}
// If the precision is incremental or decreasing, the cast can be eliminated.
// x_type <= y_type
bool incremental = CheckTwoTypes(int_map_, x_type_id, y_type_id) || CheckTwoTypes(uint_map_, x_type_id, y_type_id) ||
CheckTwoTypes(float_map_, x_type_id, y_type_id);
// x_type >= y_type >= t_type
bool decreasing = CheckThreeTypes(int_map_, t_type_id, y_type_id, x_type_id) ||
CheckThreeTypes(uint_map_, t_type_id, y_type_id, x_type_id) ||
CheckThreeTypes(float_map_, t_type_id, y_type_id, x_type_id);
return incremental || decreasing;
}
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} // {prim::kPrimCast, {prim::kPrimCast, X, Y}, T}
AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node);
if (x_ != nullptr && t_ != nullptr) { if (x_ == nullptr || t_ == nullptr || y_ == nullptr) {
return nullptr;
}
if (CheckTypesIsIncrementalOrDecreasing()) {
auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")();
ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); ValuePtr cast = parse::data_converter::PyDataToValue(cast_op);
auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph());
@ -106,10 +165,13 @@ void TwoCastEliminater::Visit(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
// {prim::kPrimCast, X, Y} // {prim::kPrimCast, X, Y}
constexpr size_t cast_size = 3; constexpr size_t cast_size = 3;
constexpr size_t cast_data_index = 1;
constexpr size_t cast_type_index = 2;
if (cnode->size() != cast_size) { if (cnode->size() != cast_size) {
return; return;
} }
x_ = cnode->input(1); x_ = cnode->input(cast_data_index);
y_ = cnode->input(cast_type_index);
} else { } else {
t_ = node; t_ = node;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include <map>
#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
@ -46,10 +47,26 @@ class TwoCastEliminater : public AnfVisitor {
void Reset() { void Reset() {
x_ = nullptr; x_ = nullptr;
t_ = nullptr; t_ = nullptr;
y_ = nullptr;
} }
private: private:
AnfNodePtr x_{nullptr}, t_{nullptr}; bool CheckTypesIsIncrementalOrDecreasing();
bool CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2);
bool CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2, TypeId type3);
std::map<TypeId, int> int_map_ = {
{kNumberTypeInt, 0}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 3}, {kNumberTypeInt64, 4}};
std::map<TypeId, int> uint_map_ = {{kNumberTypeUInt, 0},
{kNumberTypeUInt8, 1},
{kNumberTypeUInt16, 2},
{kNumberTypeUInt32, 3},
{kNumberTypeUInt64, 4}};
std::map<TypeId, int> float_map_ = {{kNumberTypeFloat, 0},
{kNumberTypeFloat16, 1},
{kNumberTypeFloat32, 2},
{kNumberTypeFloat64, 3},
{kNumberTypeDouble, 4}};
AnfNodePtr x_{nullptr}, t_{nullptr}, y_{nullptr};
}; };
class CastEliminater : public OptimizerCaller { class CastEliminater : public OptimizerCaller {

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2019-2023 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,6 +22,8 @@ from mindspore.common.tensor import Tensor
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
from mindspore import nn
import mindspore as ms
class Net(Cell): class Net(Cell):
@ -619,3 +621,25 @@ def test_cast32():
assert (output[0].asnumpy() == expected).all() assert (output[0].asnumpy() == expected).all()
type1 = output[1].asnumpy().dtype type1 = output[1].asnumpy().dtype
assert type1 == 'float64' assert type1 == 'float64'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_two_cast():
"""
Feature: test cast eliminate.
Description: test two cast eliminater.
Expectation: no exception.
"""
class TwoCastNet(nn.Cell):
def construct(self, x):
return ms.ops.Cast()(ms.ops.Cast()(x, ms.int32), ms.float32)
ms.set_context(mode=context.GRAPH_MODE)
input_x = ms.Tensor([1.1], ms.float32)
res = TwoCastNet()(input_x)
expected = [1.]
assert (res.asnumpy() == expected).all()
type1 = res.asnumpy().dtype
assert type1 == 'float32'

View File

@ -238,14 +238,6 @@ TEST_F(TestOptLib, elim_two_reshape) {
ASSERT_TRUE(CheckOpt(before, after, patterns)); ASSERT_TRUE(CheckOpt(before, after, patterns));
} }
TEST_F(TestOptLib, elim_two_cast) {
FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_cast", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_cast", "after");
auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_});
ASSERT_TRUE(CheckOpt(before, after, patterns));
}
TEST_F(TestOptLib, test_elim_transpose) { TEST_F(TestOptLib, test_elim_transpose) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before"); FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after");

View File

@ -488,22 +488,6 @@ def elim_two_reshape(tag):
return fns[tag] return fns[tag]
def elim_two_cast(tag):
""" elim_two_cast """
fns = FnDict()
cast = P.Cast()
@fns
def before(x, a, b):
return cast(cast(x, a), b)
@fns
def after(x, a, b):
return cast(x, b)
return fns[tag]
def test_elim_transpose(tag): def test_elim_transpose(tag):
""" test_elim_transpose """ """ test_elim_transpose """
fns = FnDict() fns = FnDict()