forked from mindspore-Ecosystem/mindspore
!49121 Support None is input of real kernel.
Merge pull request !49121 from Margaret_wangrui/new_fallback_none_real_kernel
This commit is contained in:
commit
8744bbfe3d
|
@ -34,6 +34,7 @@
|
|||
#include "ir/value.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
|
@ -969,6 +970,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
if (!support_fallback_runtime) {
|
||||
return;
|
||||
}
|
||||
if (AnfUtils::IsRealKernel(cnode)) {
|
||||
return;
|
||||
}
|
||||
const auto &inputs = cnode->inputs();
|
||||
const auto &cur_func = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_func);
|
||||
|
|
|
@ -70,8 +70,7 @@ class PrintConstStringWrapper : public AnfVisitor {
|
|||
return std::any_of(elements.cbegin(), elements.cend(),
|
||||
[&](const AbstractBasePtr &ele) { return CheckNeedConvert(ele); });
|
||||
}
|
||||
return !abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractTensor>() &&
|
||||
!abs->isa<abstract::AbstractNone>();
|
||||
return !abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractTensor>();
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertString(const AbstractBasePtr &abs) const {
|
||||
|
|
|
@ -69,7 +69,6 @@ class PyExecuteInitializer {
|
|||
// so special handling of None is required.
|
||||
if (script->ToString() == "None") {
|
||||
const auto &output = py::none();
|
||||
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
|
||||
PushPyExecuteOutput(output);
|
||||
const auto &infer_shape = std::make_shared<abstract::Shape>(ShapeVector({1}));
|
||||
return abstract::MakeAbstract(infer_shape, kFloat64);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -493,19 +493,14 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
|
|||
}
|
||||
return ref_tuple;
|
||||
}
|
||||
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
|
||||
size_t ref_idx = 0;
|
||||
for (size_t i = 0; i < seq_abs->size(); i++) {
|
||||
auto elem_abs = seq_abs->elements()[i];
|
||||
if (elem_abs->isa<abstract::AbstractNone>()) {
|
||||
continue;
|
||||
}
|
||||
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
|
||||
ref_idx++;
|
||||
}
|
||||
if (ref_idx != value_size) {
|
||||
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
|
||||
<< ref_idx;
|
||||
MS_LOG(EXCEPTION) << "The size of elements should be equal to " << value_size << ", but got " << ref_idx;
|
||||
}
|
||||
ret = ref_tuple;
|
||||
return ret;
|
||||
|
|
|
@ -16,8 +16,11 @@ import pytest
|
|||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, jit, context, Parameter
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn.probability import distribution
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore as ms
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -222,7 +225,6 @@ def test_none_is_slice_in_list():
|
|||
assert res == 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support print None.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -397,3 +399,44 @@ def test_none_is_input_of_tuple_return():
|
|||
|
||||
out = foo()
|
||||
assert out == (1, "a", None)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_none_is_input_of_tuple_return_2():
|
||||
"""
|
||||
Feature: Support None.
|
||||
Description: Support None is input of tuple, and the tuple is return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class BernoulliCrossEntropy(Cell):
|
||||
def __init__(self, probs, seed=10, dtype=ms.int32, name='Bernoulli', dist='Bernoulli'):
|
||||
super().__init__()
|
||||
self.b = distribution.Bernoulli(probs, seed, dtype, name)
|
||||
self.dist = dist
|
||||
|
||||
def construct(self, probs1_b, probs1=None):
|
||||
if probs1 is None:
|
||||
out1 = self.b.cross_entropy(self.dist, probs1_b)
|
||||
out2 = self.b.kl_loss(self.dist, probs1_b)
|
||||
else:
|
||||
out1 = self.b.cross_entropy(self.dist, probs1_b, probs1)
|
||||
out2 = self.b.kl_loss(self.dist, probs1_b, probs1)
|
||||
out3 = self.b.probs
|
||||
return out1, out2, out3
|
||||
|
||||
probs = None
|
||||
probs1 = 0.2
|
||||
probs1_b = 0.9
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net_graph = BernoulliCrossEntropy(probs)
|
||||
out_me_graph = net_graph(Tensor(probs1_b), Tensor(probs1))
|
||||
print("out_me_graph: ", out_me_graph)
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net_pynative = BernoulliCrossEntropy(probs)
|
||||
out_me_pynative = net_pynative(Tensor(probs1_b), Tensor(probs1))
|
||||
print("out_me_pynative: ", out_me_pynative)
|
||||
assert out_me_graph == out_me_pynative
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import math
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
@ -94,7 +93,6 @@ class GRUWeightBias():
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
input_size = 32
|
||||
hidden_size = 16
|
||||
has_bias = True
|
||||
|
@ -117,7 +115,6 @@ def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
|
|||
net.gru.b_ih_list = b_ih_list
|
||||
net.gru.b_hh_list = b_hh_list
|
||||
out, hy = net(input_ms, h0)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
@ -138,7 +135,6 @@ def test_sit_gru_forward_input_3_32_32_is_32_hs_16():
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sit_gru_grad_input_3_32_32_is_32_hs_16():
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
input_size = 32
|
||||
hidden_size = 16
|
||||
has_bias = True
|
||||
|
@ -166,7 +162,6 @@ def test_sit_gru_grad_input_3_32_32_is_32_hs_16():
|
|||
out_grad, _ = grad_net_inp(input_ms, h0)
|
||||
x_grad = out_grad[0].asnumpy()
|
||||
h_grad = out_grad[1].asnumpy()
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import math
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
@ -94,7 +93,6 @@ class LSTMWeightBias():
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
input_s = 32
|
||||
hidden_s = 16
|
||||
has_bias = True
|
||||
|
@ -118,7 +116,6 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
|||
net.lstm.b_ih_list = b_ih_list
|
||||
net.lstm.b_hh_list = b_hh_list
|
||||
out, (hy, cy) = net(input_ms, h0, c0)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
@ -141,7 +138,6 @@ def test_sit_lstm_forward_input_3_32_32_is_32_hs_16():
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
input_s = 32
|
||||
hidden_s = 16
|
||||
has_bias = True
|
||||
|
@ -171,7 +167,6 @@ def test_sit_lstm_grad_input_3_32_32_is_32_hs_16():
|
|||
x_grad = out_grad[0].asnumpy()
|
||||
h_grad = out_grad[1].asnumpy()
|
||||
c_grad = out_grad[2].asnumpy()
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
|
Loading…
Reference in New Issue