!22124 fix uint8 overflow bug
Merge pull request !22124 from yuchaojie/ub_fusion2
This commit is contained in:
commit
ff4932cfbd
|
@ -28,12 +28,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
const session::KernelGraph & /*kernel_graph*/,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto batch_matmul = cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(batch_matmul);
|
||||
if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
|
||||
|
|
|
@ -33,8 +33,6 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(relu_input);
|
||||
auto add = relu_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
|
|
|
@ -33,8 +33,6 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
auto getitem = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode,
|
||||
const session::KernelGraph &,
|
||||
const session::KernelGraph & /*kernel_graph*/,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
|
|
|
@ -33,8 +33,6 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto conv = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(conv);
|
||||
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name() &&
|
||||
|
|
|
@ -29,12 +29,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
const session::KernelGraph & /*kernel_graph*/,
|
||||
FusedNodeRecord *candidate_fusion, bool is_order) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (is_order) {
|
||||
// DepthwiseConvolution--->Elemwise
|
||||
auto depthwise_conv = cnode->input(kIndex1);
|
||||
|
|
|
@ -28,12 +28,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
const session::KernelGraph & /*kernel_graph*/,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto matmul = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
|
||||
|
|
|
@ -28,12 +28,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
const session::KernelGraph & /*kernel_graph*/,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (fusion_id_allocator->HasFusionIdAttr(relu_input)) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -31,8 +31,6 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::unordered_set<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
|
|
|
@ -72,11 +72,11 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
|
|||
MS_EXCEPTION_IF_NULL(fusion_op);
|
||||
|
||||
std::vector<std::string> input_names;
|
||||
for (uint8_t i = 0; i < inputs_list.size(); i++) {
|
||||
for (size_t i = 0; i < inputs_list.size(); i++) {
|
||||
(void)input_names.emplace_back("input" + std::to_string(i));
|
||||
}
|
||||
std::vector<std::string> output_names;
|
||||
for (uint8_t i = 0; i < outputs_list.size(); i++) {
|
||||
for (size_t i = 0; i < outputs_list.size(); i++) {
|
||||
(void)output_names.emplace_back("output" + std::to_string(i));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue