forked from mindspore-Ecosystem/mindspore
add checkinput in expand dsl and close expandfallback on ascend
This commit is contained in:
parent
8ce6015ae6
commit
751a9e0094
|
@ -25,6 +25,7 @@
|
|||
#include "common/graph_kernel/substitute_dropout.h"
|
||||
#include "common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "common/graph_kernel/adapter/callback_impl.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
|
||||
|
@ -55,6 +56,13 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
|
|||
}
|
||||
|
||||
FuncGraphPtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &kernel_node)> &func) {
|
||||
auto processor = kernel::GetStrProcessorFromContext();
|
||||
if (processor == kernel::kProcessorAiCore) {
|
||||
auto use_expand_fallback = common::GetEnv("EXPANDERFALLBACK");
|
||||
if (use_expand_fallback.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(node)->Run(node));
|
||||
if (expand_fg != nullptr) {
|
||||
auto todos = TopoSort(expand_fg->get_return());
|
||||
|
|
|
@ -36,6 +36,17 @@ class BiasAdd : public OpDesc {
|
|||
~BiasAdd() = default;
|
||||
|
||||
protected:
|
||||
bool CheckInputs() override {
|
||||
auto it = std::find_if(std::begin(inputs_info_), std::end(inputs_info_), [](const inner::NodeBase &input) {
|
||||
return input.type != kNumberTypeFloat32 && input.type != kNumberTypeFloat16;
|
||||
});
|
||||
if (it != std::end(inputs_info_)) {
|
||||
MS_LOG(INFO) << "In BiasAdd, input's dtype must be float16 or float32, But input's type is " << it->type;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NodePtrList Expand(const NodePtrList &inputs) override {
|
||||
auto input_x = inputs[0];
|
||||
auto input_y = inputs[1];
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for squared_difference"""
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
|
@ -20,6 +21,18 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
|
|||
class SquaredDifference(Expander):
|
||||
"""SquaredDifference expander"""
|
||||
|
||||
def __init__(self, expand_info):
|
||||
super().__init__(expand_info)
|
||||
self.dtype_x = self.inputs[0]['data_type']
|
||||
self.dtype_y = self.inputs[1]['data_type']
|
||||
|
||||
def _check(self):
|
||||
if self.dtype_x == "float64" or self.dtype_y == "float64":
|
||||
raise GKException("For 'SquaredDifference', the inputs data type must not be float64")
|
||||
if self.dtype_x != self.dtype_y:
|
||||
raise GKException("For 'SquaredDifference', the inputs data type should be same, but got {} and {}"
|
||||
.format(self.dtype_x, self.dtype_y))
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
input_y = self.inputs[1]
|
||||
|
|
|
@ -72,7 +72,7 @@ def test_gpu_fp32():
|
|||
basic_test(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ascend_graph_mode_fp32():
|
||||
|
@ -85,7 +85,7 @@ def test_ascend_graph_mode_fp32():
|
|||
basic_test(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ascend_pynative_mode_fp32():
|
||||
|
|
Loading…
Reference in New Issue