diff --git a/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc b/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc index 4864f7bf51c..aa50dfdbcea 100644 --- a/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc +++ b/mindspore/ccsrc/common/graph_kernel/adapter/expander.cc @@ -25,6 +25,7 @@ #include "common/graph_kernel/substitute_dropout.h" #include "common/graph_kernel/graph_kernel_helper.h" #include "common/graph_kernel/adapter/callback_impl.h" +#include "kernel/common_utils.h" namespace mindspore::graphkernel { ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) { @@ -55,6 +56,13 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) { } FuncGraphPtr TryExpandCNode(const AnfNodePtr &node, const std::function &func) { + auto processor = kernel::GetStrProcessorFromContext(); + if (processor == kernel::kProcessorAiCore) { + auto use_expand_fallback = common::GetEnv("EXPANDERFALLBACK"); + if (use_expand_fallback.empty()) { + return nullptr; + } + } auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(node)->Run(node)); if (expand_fg != nullptr) { auto todos = TopoSort(expand_fg->get_return()); diff --git a/mindspore/ccsrc/common/graph_kernel/expanders/bias_add.cc b/mindspore/ccsrc/common/graph_kernel/expanders/bias_add.cc index 0643995ec10..e47f9b6c3f7 100644 --- a/mindspore/ccsrc/common/graph_kernel/expanders/bias_add.cc +++ b/mindspore/ccsrc/common/graph_kernel/expanders/bias_add.cc @@ -36,6 +36,17 @@ class BiasAdd : public OpDesc { ~BiasAdd() = default; protected: + bool CheckInputs() override { + auto it = std::find_if(std::begin(inputs_info_), std::end(inputs_info_), [](const inner::NodeBase &input) { + return input.type != kNumberTypeFloat32 && input.type != kNumberTypeFloat16; + }); + if (it != std::end(inputs_info_)) { + MS_LOG(INFO) << "In BiasAdd, input's dtype must be float16 or float32, But input's type is " << it->type; + return false; + } + return true; + } + NodePtrList Expand(const NodePtrList &inputs) override { auto input_x = inputs[0]; auto input_y = inputs[1]; diff --git a/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py b/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py index 873a3ca2344..b9f4ae5cc2c 100644 --- a/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +++ b/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py b/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py index 316b000e346..ed707cdc899 100644 --- a/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py +++ b/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py @@ -13,6 +13,7 @@ # limitations under the License. # =========================================================================== """generate json desc for squared_difference""" +from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException from ._utils import Expander, ExpanderInfoValidator as VLD @@ -20,6 +21,18 @@ from ._utils import Expander, ExpanderInfoValidator as VLD class SquaredDifference(Expander): """SquaredDifference expander""" + def __init__(self, expand_info): + super().__init__(expand_info) + self.dtype_x = self.inputs[0]['data_type'] + self.dtype_y = self.inputs[1]['data_type'] + + def _check(self): + if self.dtype_x == "float64" or self.dtype_y == "float64": + raise GKException("For 'SquaredDifference', the inputs data type must not be float64") + if self.dtype_x != self.dtype_y: + raise GKException("For 'SquaredDifference', the inputs data type should be same, but got {} and {}" + .format(self.dtype_x, self.dtype_y)) + def _expand(self, graph_builder): input_x = self.inputs[0] input_y = self.inputs[1] diff --git a/tests/st/ops/graph_kernel/test_equal_count.py b/tests/st/ops/graph_kernel/test_equal_count.py index f127d6e650b..011e6a5921e 100644 --- a/tests/st/ops/graph_kernel/test_equal_count.py +++ b/tests/st/ops/graph_kernel/test_equal_count.py @@ -72,7 +72,7 @@ def test_gpu_fp32(): basic_test(np.float32) -@pytest.mark.level0 +@pytest.mark.level2 @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard def test_ascend_graph_mode_fp32(): @@ -85,7 +85,7 @@ def test_ascend_graph_mode_fp32(): basic_test(np.float32) -@pytest.mark.level0 +@pytest.mark.level2 @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard def test_ascend_pynative_mode_fp32():