forked from mindspore-Ecosystem/mindspore
Fixed the bug of cast eliminate.
This commit is contained in:
parent
378b787bf1
commit
5c56a45cce
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -86,12 +86,71 @@ void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TwoCastEliminater::CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2) {
|
||||||
|
auto type1_iter = type_map.find(type1);
|
||||||
|
auto type2_iter = type_map.find(type2);
|
||||||
|
if (type1_iter != type_map.end() && type2_iter != type_map.end()) {
|
||||||
|
return type1_iter->second <= type2_iter->second;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TwoCastEliminater::CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2,
|
||||||
|
TypeId type3) {
|
||||||
|
auto type1_iter = type_map.find(type1);
|
||||||
|
auto type2_iter = type_map.find(type2);
|
||||||
|
auto type3_iter = type_map.find(type3);
|
||||||
|
if (type1_iter != type_map.end() && type2_iter != type_map.end() && type3_iter != type_map.end()) {
|
||||||
|
return type1_iter->second <= type2_iter->second && type2_iter->second <= type3_iter->second;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -> {prim::kPrimCast, X, T}
|
||||||
|
// x_type <= y_type <= t_type or x_type >= y_type >= t_type
|
||||||
|
bool TwoCastEliminater::CheckTypesIsIncrementalOrDecreasing() {
|
||||||
|
auto x_type = x_->Type();
|
||||||
|
if (x_type->isa<TensorType>()) {
|
||||||
|
x_type = x_type->cast<TensorTypePtr>()->element();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto y_type = GetValueNode<TypePtr>(y_);
|
||||||
|
MS_EXCEPTION_IF_NULL(y_type);
|
||||||
|
if (y_type->isa<TensorType>()) {
|
||||||
|
y_type = y_type->cast<TensorTypePtr>()->element();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto t_type = GetValueNode<TypePtr>(t_);
|
||||||
|
MS_EXCEPTION_IF_NULL(t_type);
|
||||||
|
if (t_type->isa<TensorType>()) {
|
||||||
|
t_type = t_type->cast<TensorTypePtr>()->element();
|
||||||
|
}
|
||||||
|
auto x_type_id = x_type->type_id();
|
||||||
|
auto y_type_id = y_type->type_id();
|
||||||
|
auto t_type_id = t_type->type_id();
|
||||||
|
if (y_type_id == t_type_id) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// If the precision is incremental or decreasing, the cast can be eliminated.
|
||||||
|
// x_type <= y_type
|
||||||
|
bool incremental = CheckTwoTypes(int_map_, x_type_id, y_type_id) || CheckTwoTypes(uint_map_, x_type_id, y_type_id) ||
|
||||||
|
CheckTwoTypes(float_map_, x_type_id, y_type_id);
|
||||||
|
// x_type >= y_type >= t_type
|
||||||
|
bool decreasing = CheckThreeTypes(int_map_, t_type_id, y_type_id, x_type_id) ||
|
||||||
|
CheckThreeTypes(uint_map_, t_type_id, y_type_id, x_type_id) ||
|
||||||
|
CheckThreeTypes(float_map_, t_type_id, y_type_id, x_type_id);
|
||||||
|
return incremental || decreasing;
|
||||||
|
}
|
||||||
|
|
||||||
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T}
|
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T}
|
||||||
AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||||
Reset();
|
Reset();
|
||||||
AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node);
|
AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node);
|
||||||
|
|
||||||
if (x_ != nullptr && t_ != nullptr) {
|
if (x_ == nullptr || t_ == nullptr || y_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (CheckTypesIsIncrementalOrDecreasing()) {
|
||||||
auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")();
|
auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")();
|
||||||
ValuePtr cast = parse::data_converter::PyDataToValue(cast_op);
|
ValuePtr cast = parse::data_converter::PyDataToValue(cast_op);
|
||||||
auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph());
|
auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph());
|
||||||
|
@ -106,10 +165,13 @@ void TwoCastEliminater::Visit(const AnfNodePtr &node) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
// {prim::kPrimCast, X, Y}
|
// {prim::kPrimCast, X, Y}
|
||||||
constexpr size_t cast_size = 3;
|
constexpr size_t cast_size = 3;
|
||||||
|
constexpr size_t cast_data_index = 1;
|
||||||
|
constexpr size_t cast_type_index = 2;
|
||||||
if (cnode->size() != cast_size) {
|
if (cnode->size() != cast_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
x_ = cnode->input(1);
|
x_ = cnode->input(cast_data_index);
|
||||||
|
y_ = cnode->input(cast_type_index);
|
||||||
} else {
|
} else {
|
||||||
t_ = node;
|
t_ = node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +17,7 @@
|
||||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include "frontend/optimizer/anf_visitor.h"
|
#include "frontend/optimizer/anf_visitor.h"
|
||||||
#include "frontend/optimizer/irpass.h"
|
#include "frontend/optimizer/irpass.h"
|
||||||
#include "frontend/optimizer/optimizer.h"
|
#include "frontend/optimizer/optimizer.h"
|
||||||
|
@ -46,10 +47,26 @@ class TwoCastEliminater : public AnfVisitor {
|
||||||
void Reset() {
|
void Reset() {
|
||||||
x_ = nullptr;
|
x_ = nullptr;
|
||||||
t_ = nullptr;
|
t_ = nullptr;
|
||||||
|
y_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AnfNodePtr x_{nullptr}, t_{nullptr};
|
bool CheckTypesIsIncrementalOrDecreasing();
|
||||||
|
bool CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2);
|
||||||
|
bool CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2, TypeId type3);
|
||||||
|
std::map<TypeId, int> int_map_ = {
|
||||||
|
{kNumberTypeInt, 0}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 3}, {kNumberTypeInt64, 4}};
|
||||||
|
std::map<TypeId, int> uint_map_ = {{kNumberTypeUInt, 0},
|
||||||
|
{kNumberTypeUInt8, 1},
|
||||||
|
{kNumberTypeUInt16, 2},
|
||||||
|
{kNumberTypeUInt32, 3},
|
||||||
|
{kNumberTypeUInt64, 4}};
|
||||||
|
std::map<TypeId, int> float_map_ = {{kNumberTypeFloat, 0},
|
||||||
|
{kNumberTypeFloat16, 1},
|
||||||
|
{kNumberTypeFloat32, 2},
|
||||||
|
{kNumberTypeFloat64, 3},
|
||||||
|
{kNumberTypeDouble, 4}};
|
||||||
|
AnfNodePtr x_{nullptr}, t_{nullptr}, y_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
class CastEliminater : public OptimizerCaller {
|
class CastEliminater : public OptimizerCaller {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
# Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -22,6 +22,8 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import _inner_ops as inner
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
from mindspore import nn
|
||||||
|
import mindspore as ms
|
||||||
|
|
||||||
|
|
||||||
class Net(Cell):
|
class Net(Cell):
|
||||||
|
@ -619,3 +621,25 @@ def test_cast32():
|
||||||
assert (output[0].asnumpy() == expected).all()
|
assert (output[0].asnumpy() == expected).all()
|
||||||
type1 = output[1].asnumpy().dtype
|
type1 = output[1].asnumpy().dtype
|
||||||
assert type1 == 'float64'
|
assert type1 == 'float64'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_two_cast():
|
||||||
|
"""
|
||||||
|
Feature: test cast eliminate.
|
||||||
|
Description: test two cast eliminater.
|
||||||
|
Expectation: no exception.
|
||||||
|
"""
|
||||||
|
class TwoCastNet(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return ms.ops.Cast()(ms.ops.Cast()(x, ms.int32), ms.float32)
|
||||||
|
|
||||||
|
ms.set_context(mode=context.GRAPH_MODE)
|
||||||
|
input_x = ms.Tensor([1.1], ms.float32)
|
||||||
|
res = TwoCastNet()(input_x)
|
||||||
|
expected = [1.]
|
||||||
|
assert (res.asnumpy() == expected).all()
|
||||||
|
type1 = res.asnumpy().dtype
|
||||||
|
assert type1 == 'float32'
|
||||||
|
|
|
@ -238,14 +238,6 @@ TEST_F(TestOptLib, elim_two_reshape) {
|
||||||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestOptLib, elim_two_cast) {
|
|
||||||
FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_cast", "before");
|
|
||||||
FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_cast", "after");
|
|
||||||
|
|
||||||
auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_});
|
|
||||||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestOptLib, test_elim_transpose) {
|
TEST_F(TestOptLib, test_elim_transpose) {
|
||||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before");
|
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before");
|
||||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after");
|
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after");
|
||||||
|
|
|
@ -488,22 +488,6 @@ def elim_two_reshape(tag):
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
def elim_two_cast(tag):
|
|
||||||
""" elim_two_cast """
|
|
||||||
fns = FnDict()
|
|
||||||
cast = P.Cast()
|
|
||||||
|
|
||||||
@fns
|
|
||||||
def before(x, a, b):
|
|
||||||
return cast(cast(x, a), b)
|
|
||||||
|
|
||||||
@fns
|
|
||||||
def after(x, a, b):
|
|
||||||
return cast(x, b)
|
|
||||||
|
|
||||||
return fns[tag]
|
|
||||||
|
|
||||||
|
|
||||||
def test_elim_transpose(tag):
|
def test_elim_transpose(tag):
|
||||||
""" test_elim_transpose """
|
""" test_elim_transpose """
|
||||||
fns = FnDict()
|
fns = FnDict()
|
||||||
|
|
Loading…
Reference in New Issue