forked from mindspore-Ecosystem/mindspore
[ME][Fallback] Review the local dict and try to eval out the interpret node
This commit is contained in:
parent
c4762bf362
commit
ee1b4d4e6d
|
@ -1369,13 +1369,25 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||||
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
|
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
|
||||||
<< ", filtered_local_dict:" << filtered_local_dict->ToString();
|
<< ", filtered_local_dict:" << filtered_local_dict->ToString();
|
||||||
ValuePtr local_dict_value = filtered_local_dict->BuildValue();
|
ValuePtr local_dict_value = filtered_local_dict->BuildValue();
|
||||||
py::object local_params_dict = ValueToPyData(local_dict_value);
|
py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
|
||||||
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
|
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
|
||||||
<< py::str(local_params_dict);
|
<< py::str(local_params_dict);
|
||||||
params[1] = local_params_dict;
|
params[1] = local_params_dict;
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
|
||||||
|
const auto &keys_values = filtered_local_dict->elements();
|
||||||
|
py::dict local_params_dict;
|
||||||
|
for (auto &key_value : keys_values) {
|
||||||
|
ValuePtr element_value = key_value.second->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(element_value);
|
||||||
|
auto py_data = ValueToPyData(element_value);
|
||||||
|
local_params_dict[py::str(key_value.first)] = py_data;
|
||||||
|
}
|
||||||
|
return local_params_dict;
|
||||||
|
}
|
||||||
|
|
||||||
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
|
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
|
||||||
std::vector<AbstractAttribute> kv;
|
std::vector<AbstractAttribute> kv;
|
||||||
const auto &keys_values = abstract_dict->elements();
|
const auto &keys_values = abstract_dict->elements();
|
||||||
|
|
|
@ -134,3 +134,29 @@ def test_list_insert():
|
||||||
x.insert(1, 2)
|
x.insert(1, 2)
|
||||||
return Tensor(x)
|
return Tensor(x)
|
||||||
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
|
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def np_fallback_func_tensor_index(x):
|
||||||
|
array_x = tuple([2, 3, 4, 5])
|
||||||
|
np_x = np.array(array_x).astype(np.float32)
|
||||||
|
me_x = Tensor(np_x)
|
||||||
|
me_x = me_x + me_x
|
||||||
|
return me_x[x]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_np_fallback_func_tensor_index():
|
||||||
|
"""
|
||||||
|
Feature: Fallback feature: support Tensor index.
|
||||||
|
Description: Fallback feature: support Tensor index.
|
||||||
|
Expectation: Fallback feature: support Tensor index.
|
||||||
|
"""
|
||||||
|
x = Tensor(1, mstype.int32)
|
||||||
|
output = np_fallback_func_tensor_index(x)
|
||||||
|
output_expect = Tensor(6, mstype.float32)
|
||||||
|
assert output == output_expect
|
||||||
|
|
|
@ -156,7 +156,6 @@ def test_div_mod_func2_tensor():
|
||||||
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
|
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
|
||||||
|
|
||||||
|
|
||||||
# NameError: name 'Tensor' is not defined.
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def select_func(cond, x, y):
|
def select_func(cond, x, y):
|
||||||
if isinstance(cond, (tuple, list)):
|
if isinstance(cond, (tuple, list)):
|
||||||
|
@ -175,7 +174,6 @@ def test_select_func():
|
||||||
print(select_func(cond, x, y))
|
print(select_func(cond, x, y))
|
||||||
|
|
||||||
|
|
||||||
# Not interpret 'Tensor'.
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def select_func2(cond, x, y):
|
def select_func2(cond, x, y):
|
||||||
if isinstance(cond, (tuple, list)):
|
if isinstance(cond, (tuple, list)):
|
||||||
|
@ -194,7 +192,6 @@ def test_select_func2():
|
||||||
print(select_func2(cond, x, y))
|
print(select_func2(cond, x, y))
|
||||||
|
|
||||||
|
|
||||||
# NameError: name 'Tensor' is not defined.
|
|
||||||
@ms_function
|
@ms_function
|
||||||
def slice_func(a, b):
|
def slice_func(a, b):
|
||||||
a[1:3, ::] = b
|
a[1:3, ::] = b
|
||||||
|
@ -207,29 +204,6 @@ def test_slice_func():
|
||||||
print(slice_func(a, b))
|
print(slice_func(a, b))
|
||||||
|
|
||||||
|
|
||||||
@ms_function
|
|
||||||
def np_fallback_func_tensor_index(x):
|
|
||||||
array_x = tuple([2, 3, 4, 5])
|
|
||||||
np_x = np.array(array_x).astype(np.float32)
|
|
||||||
me_x = Tensor(np_x)
|
|
||||||
me_x = me_x + me_x
|
|
||||||
return me_x[x]
|
|
||||||
|
|
||||||
|
|
||||||
# NameError: name 'array_x' is not defined.
|
|
||||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
|
||||||
def test_np_fallback_func_tensor_index():
|
|
||||||
"""
|
|
||||||
Feature: Fallback feature: support Tensor index.
|
|
||||||
Description: Fallback feature: support Tensor index.
|
|
||||||
Expectation: Fallback feature: support Tensor index.
|
|
||||||
"""
|
|
||||||
x = Tensor(1, mstype.int32)
|
|
||||||
output = np_fallback_func_tensor_index(x)
|
|
||||||
output_expect = Tensor(6, mstype.float32)
|
|
||||||
assert output == output_expect
|
|
||||||
|
|
||||||
|
|
||||||
# EvalCNode: This may be not defined, or it can't be a operator.
|
# EvalCNode: This may be not defined, or it can't be a operator.
|
||||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
def test_np_tensor_add():
|
def test_np_tensor_add():
|
||||||
|
|
Loading…
Reference in New Issue