Add CheckInputs to the expander DSL and disable the expand fallback on Ascend

This commit is contained in:
zengzitao 2022-04-26 11:51:06 +08:00
parent 8ce6015ae6
commit 751a9e0094
5 changed files with 35 additions and 3 deletions

View File

@ -25,6 +25,7 @@
#include "common/graph_kernel/substitute_dropout.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "common/graph_kernel/adapter/callback_impl.h"
#include "kernel/common_utils.h"
namespace mindspore::graphkernel {
ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
@ -55,6 +56,13 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
}
FuncGraphPtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &kernel_node)> &func) {
auto processor = kernel::GetStrProcessorFromContext();
if (processor == kernel::kProcessorAiCore) {
auto use_expand_fallback = common::GetEnv("EXPANDERFALLBACK");
if (use_expand_fallback.empty()) {
return nullptr;
}
}
auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(node)->Run(node));
if (expand_fg != nullptr) {
auto todos = TopoSort(expand_fg->get_return());

View File

@ -36,6 +36,17 @@ class BiasAdd : public OpDesc {
~BiasAdd() = default;
protected:
bool CheckInputs() override {
auto it = std::find_if(std::begin(inputs_info_), std::end(inputs_info_), [](const inner::NodeBase &input) {
return input.type != kNumberTypeFloat32 && input.type != kNumberTypeFloat16;
});
if (it != std::end(inputs_info_)) {
MS_LOG(INFO) << "In BiasAdd, input's dtype must be float16 or float32, But input's type is " << it->type;
return false;
}
return true;
}
NodePtrList Expand(const NodePtrList &inputs) override {
auto input_x = inputs[0];
auto input_y = inputs[1];

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for squared_difference"""
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
@ -20,6 +21,18 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
class SquaredDifference(Expander):
"""SquaredDifference expander"""
def __init__(self, expand_info):
    """Initialize the expander and cache both input dtypes for validation in _check."""
    super().__init__(expand_info)
    first_input, second_input = self.inputs[0], self.inputs[1]
    self.dtype_x = first_input['data_type']
    self.dtype_y = second_input['data_type']
def _check(self):
    """Validate input dtypes: float64 is unsupported and both inputs must match."""
    dtypes = (self.dtype_x, self.dtype_y)
    if "float64" in dtypes:
        raise GKException("For 'SquaredDifference', the inputs data type must not be float64")
    if self.dtype_x != self.dtype_y:
        raise GKException("For 'SquaredDifference', the inputs data type should be same, but got {} and {}"
                          .format(self.dtype_x, self.dtype_y))
def _expand(self, graph_builder):
input_x = self.inputs[0]
input_y = self.inputs[1]

View File

@ -72,7 +72,7 @@ def test_gpu_fp32():
basic_test(np.float32)
@pytest.mark.level0
@pytest.mark.level2
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_graph_mode_fp32():
@ -85,7 +85,7 @@ def test_ascend_graph_mode_fp32():
basic_test(np.float32)
@pytest.mark.level0
@pytest.mark.level2
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_pynative_mode_fp32():