forked from mindspore-Ecosystem/mindspore
Fixed the bug of cast eliminate.
This commit is contained in:
parent
378b787bf1
commit
5c56a45cce
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -86,12 +86,71 @@ void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
|
|||
}
|
||||
}
|
||||
|
||||
bool TwoCastEliminater::CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2) {
|
||||
auto type1_iter = type_map.find(type1);
|
||||
auto type2_iter = type_map.find(type2);
|
||||
if (type1_iter != type_map.end() && type2_iter != type_map.end()) {
|
||||
return type1_iter->second <= type2_iter->second;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TwoCastEliminater::CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2,
|
||||
TypeId type3) {
|
||||
auto type1_iter = type_map.find(type1);
|
||||
auto type2_iter = type_map.find(type2);
|
||||
auto type3_iter = type_map.find(type3);
|
||||
if (type1_iter != type_map.end() && type2_iter != type_map.end() && type3_iter != type_map.end()) {
|
||||
return type1_iter->second <= type2_iter->second && type2_iter->second <= type3_iter->second;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -> {prim::kPrimCast, X, T}
|
||||
// x_type <= y_type <= t_type or x_type >= y_type >= t_type
|
||||
bool TwoCastEliminater::CheckTypesIsIncrementalOrDecreasing() {
|
||||
auto x_type = x_->Type();
|
||||
if (x_type->isa<TensorType>()) {
|
||||
x_type = x_type->cast<TensorTypePtr>()->element();
|
||||
}
|
||||
|
||||
auto y_type = GetValueNode<TypePtr>(y_);
|
||||
MS_EXCEPTION_IF_NULL(y_type);
|
||||
if (y_type->isa<TensorType>()) {
|
||||
y_type = y_type->cast<TensorTypePtr>()->element();
|
||||
}
|
||||
|
||||
auto t_type = GetValueNode<TypePtr>(t_);
|
||||
MS_EXCEPTION_IF_NULL(t_type);
|
||||
if (t_type->isa<TensorType>()) {
|
||||
t_type = t_type->cast<TensorTypePtr>()->element();
|
||||
}
|
||||
auto x_type_id = x_type->type_id();
|
||||
auto y_type_id = y_type->type_id();
|
||||
auto t_type_id = t_type->type_id();
|
||||
if (y_type_id == t_type_id) {
|
||||
return true;
|
||||
}
|
||||
// If the precision is incremental or decreasing, the cast can be eliminated.
|
||||
// x_type <= y_type
|
||||
bool incremental = CheckTwoTypes(int_map_, x_type_id, y_type_id) || CheckTwoTypes(uint_map_, x_type_id, y_type_id) ||
|
||||
CheckTwoTypes(float_map_, x_type_id, y_type_id);
|
||||
// x_type >= y_type >= t_type
|
||||
bool decreasing = CheckThreeTypes(int_map_, t_type_id, y_type_id, x_type_id) ||
|
||||
CheckThreeTypes(uint_map_, t_type_id, y_type_id, x_type_id) ||
|
||||
CheckThreeTypes(float_map_, t_type_id, y_type_id, x_type_id);
|
||||
return incremental || decreasing;
|
||||
}
|
||||
|
||||
// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T}
|
||||
AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node);
|
||||
|
||||
if (x_ != nullptr && t_ != nullptr) {
|
||||
if (x_ == nullptr || t_ == nullptr || y_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (CheckTypesIsIncrementalOrDecreasing()) {
|
||||
auto cast_op = python_adapter::GetPyFn("mindspore.ops.operations", "Cast")();
|
||||
ValuePtr cast = parse::data_converter::PyDataToValue(cast_op);
|
||||
auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph());
|
||||
|
@ -106,10 +165,13 @@ void TwoCastEliminater::Visit(const AnfNodePtr &node) {
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
// {prim::kPrimCast, X, Y}
|
||||
constexpr size_t cast_size = 3;
|
||||
constexpr size_t cast_data_index = 1;
|
||||
constexpr size_t cast_type_index = 2;
|
||||
if (cnode->size() != cast_size) {
|
||||
return;
|
||||
}
|
||||
x_ = cnode->input(1);
|
||||
x_ = cnode->input(cast_data_index);
|
||||
y_ = cnode->input(cast_type_index);
|
||||
} else {
|
||||
t_ = node;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
|
||||
|
||||
#include <map>
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
|
@ -46,10 +47,26 @@ class TwoCastEliminater : public AnfVisitor {
|
|||
void Reset() {
|
||||
x_ = nullptr;
|
||||
t_ = nullptr;
|
||||
y_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr x_{nullptr}, t_{nullptr};
|
||||
bool CheckTypesIsIncrementalOrDecreasing();
|
||||
bool CheckTwoTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2);
|
||||
bool CheckThreeTypes(const std::map<TypeId, int> &type_map, TypeId type1, TypeId type2, TypeId type3);
|
||||
std::map<TypeId, int> int_map_ = {
|
||||
{kNumberTypeInt, 0}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 3}, {kNumberTypeInt64, 4}};
|
||||
std::map<TypeId, int> uint_map_ = {{kNumberTypeUInt, 0},
|
||||
{kNumberTypeUInt8, 1},
|
||||
{kNumberTypeUInt16, 2},
|
||||
{kNumberTypeUInt32, 3},
|
||||
{kNumberTypeUInt64, 4}};
|
||||
std::map<TypeId, int> float_map_ = {{kNumberTypeFloat, 0},
|
||||
{kNumberTypeFloat16, 1},
|
||||
{kNumberTypeFloat32, 2},
|
||||
{kNumberTypeFloat64, 3},
|
||||
{kNumberTypeDouble, 4}};
|
||||
AnfNodePtr x_{nullptr}, t_{nullptr}, y_{nullptr};
|
||||
};
|
||||
|
||||
class CastEliminater : public OptimizerCaller {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -22,6 +22,8 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore import nn
|
||||
import mindspore as ms
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
|
@ -619,3 +621,25 @@ def test_cast32():
|
|||
assert (output[0].asnumpy() == expected).all()
|
||||
type1 = output[1].asnumpy().dtype
|
||||
assert type1 == 'float64'
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_two_cast():
|
||||
"""
|
||||
Feature: test cast eliminate.
|
||||
Description: test two cast eliminater.
|
||||
Expectation: no exception.
|
||||
"""
|
||||
class TwoCastNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
return ms.ops.Cast()(ms.ops.Cast()(x, ms.int32), ms.float32)
|
||||
|
||||
ms.set_context(mode=context.GRAPH_MODE)
|
||||
input_x = ms.Tensor([1.1], ms.float32)
|
||||
res = TwoCastNet()(input_x)
|
||||
expected = [1.]
|
||||
assert (res.asnumpy() == expected).all()
|
||||
type1 = res.asnumpy().dtype
|
||||
assert type1 == 'float32'
|
||||
|
|
|
@ -238,14 +238,6 @@ TEST_F(TestOptLib, elim_two_reshape) {
|
|||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, elim_two_cast) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("elim_two_cast", "before");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("elim_two_cast", "after");
|
||||
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.cast_eliminate_});
|
||||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_elim_transpose) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_transpose", "before");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_transpose", "after");
|
||||
|
|
|
@ -488,22 +488,6 @@ def elim_two_reshape(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def elim_two_cast(tag):
|
||||
""" elim_two_cast """
|
||||
fns = FnDict()
|
||||
cast = P.Cast()
|
||||
|
||||
@fns
|
||||
def before(x, a, b):
|
||||
return cast(cast(x, a), b)
|
||||
|
||||
@fns
|
||||
def after(x, a, b):
|
||||
return cast(x, b)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_elim_transpose(tag):
|
||||
""" test_elim_transpose """
|
||||
fns = FnDict()
|
||||
|
|
Loading…
Reference in New Issue