!49121 Support None is input of real kernel.

Merge pull request !49121 from Margaret_wangrui/new_fallback_none_real_kernel
This commit is contained in:
i-robot 2023-02-20 11:47:44 +00:00 committed by Gitee
commit 8744bbfe3d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 51 additions and 21 deletions

View File

@ -34,6 +34,7 @@
#include "ir/value.h"
#include "pipeline/jit/parse/resolve.h"
#include "utils/hash_map.h"
#include "utils/anf_utils.h"
namespace mindspore {
/* namespace to support opt */
@ -969,6 +970,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
if (!support_fallback_runtime) {
return;
}
if (AnfUtils::IsRealKernel(cnode)) {
return;
}
const auto &inputs = cnode->inputs();
const auto &cur_func = cnode->func_graph();
MS_EXCEPTION_IF_NULL(cur_func);

View File

@ -70,8 +70,7 @@ class PrintConstStringWrapper : public AnfVisitor {
return std::any_of(elements.cbegin(), elements.cend(),
[&](const AbstractBasePtr &ele) { return CheckNeedConvert(ele); });
}
return !abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractTensor>() &&
!abs->isa<abstract::AbstractNone>();
return !abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractTensor>();
}
AnfNodePtr ConvertString(const AbstractBasePtr &abs) const {

View File

@ -69,7 +69,6 @@ class PyExecuteInitializer {
// so special handling of None is required.
if (script->ToString() == "None") {
const auto &output = py::none();
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
PushPyExecuteOutput(output);
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
return abstract::MakeAbstract(infer_shape, kFloat64);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -493,19 +493,14 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
}
return ref_tuple;
}
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
size_t ref_idx = 0;
for (size_t i = 0; i < seq_abs->size(); i++) {
auto elem_abs = seq_abs->elements()[i];
if (elem_abs->isa<abstract::AbstractNone>()) {
continue;
}
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
ref_idx++;
}
if (ref_idx != value_size) {
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
<< ref_idx;
MS_LOG(EXCEPTION) << "The size of elements should be equal to " << value_size << ", but got " << ref_idx;
}
ret = ref_tuple;
return ret;

View File

@ -16,8 +16,11 @@ import pytest
import numpy as np
from mindspore import Tensor, jit, context, Parameter
from mindspore.nn import Cell
from mindspore.nn.probability import distribution
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore as ms
context.set_context(mode=context.GRAPH_MODE)
@ -222,7 +225,6 @@ def test_none_is_slice_in_list():
assert res == 0
@pytest.mark.skip(reason="No support print None.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -397,3 +399,44 @@ def test_none_is_input_of_tuple_return():
out = foo()
assert out == (1, "a", None)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_none_is_input_of_tuple_return_2():
"""
Feature: Support None.
Description: Support None is input of tuple, and the tuple is return.
Expectation: No exception.
"""
class BernoulliCrossEntropy(Cell):
def __init__(self, probs, seed=10, dtype=ms.int32, name='Bernoulli', dist='Bernoulli'):
super().__init__()
self.b = distribution.Bernoulli(probs, seed, dtype, name)
self.dist = dist
def construct(self, probs1_b, probs1=None):
if probs1 is None:
out1 = self.b.cross_entropy(self.dist, probs1_b)
out2 = self.b.kl_loss(self.dist, probs1_b)
else:
out1 = self.b.cross_entropy(self.dist, probs1_b, probs1)
out2 = self.b.kl_loss(self.dist, probs1_b, probs1)
out3 = self.b.probs
return out1, out2, out3
probs = None
probs1 = 0.2
probs1_b = 0.9
context.set_context(mode=context.GRAPH_MODE)
net_graph = BernoulliCrossEntropy(probs)
out_me_graph = net_graph(Tensor(probs1_b), Tensor(probs1))
print("out_me_graph: ", out_me_graph)
context.set_context(mode=context.PYNATIVE_MODE)
net_pynative = BernoulliCrossEntropy(probs)
out_me_pynative = net_pynative(Tensor(probs1_b), Tensor(probs1))
print("out_me_pynative: ", out_me_pynative)
assert out_me_graph == out_me_pynative

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
import os
import math
import pytest
import numpy as np
@ -94,7 +93,6 @@ class GRUWeightBias():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
input_size = 32
hidden_size = 16
has_bias = True
@ -117,7 +115,6 @@ def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
net.gru.b_ih_list = b_ih_list
net.gru.b_hh_list = b_hh_list
out, hy = net(input_ms, h0)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
@ -138,7 +135,6 @@ def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sit_gru_grad_input_3_32_32_is_32_hs_16():
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
input_size = 32
hidden_size = 16
has_bias = True
@ -166,7 +162,6 @@ def test_sit_gru_grad_input_3_32_32_is_32_hs_16():
out_grad, _ = grad_net_inp(input_ms, h0)
x_grad = out_grad[0].asnumpy()
h_grad = out_grad[1].asnumpy()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
import os
import math
import pytest
import numpy as np
@ -94,7 +93,6 @@ class LSTMWeightBias():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
input_s = 32
hidden_s = 16
has_bias = True
@ -118,7 +116,6 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
net.lstm.b_ih_list = b_ih_list
net.lstm.b_hh_list = b_hh_list
out, (hy, cy) = net(input_ms, h0, c0)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
@ -141,7 +138,6 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
input_s = 32
hidden_s = 16
has_bias = True
@ -171,7 +167,6 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
x_grad = out_grad[0].asnumpy()
h_grad = out_grad[1].asnumpy()
c_grad = out_grad[2].asnumpy()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)