!26982 [ME][Fallback] Review the local dict and try to eval out the interpret node.

Merge pull request !26982 from Margaret_wangrui/recheck_local_dict
This commit is contained in:
i-robot 2021-12-02 01:09:12 +00:00 committed by Gitee
commit b00b1cc276
3 changed files with 39 additions and 27 deletions

View File

@ -1375,13 +1375,25 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
<< ", filtered_local_dict:" << filtered_local_dict->ToString();
ValuePtr local_dict_value = filtered_local_dict->BuildValue();
py::object local_params_dict = ValueToPyData(local_dict_value);
py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
<< py::str(local_params_dict);
params[1] = local_params_dict;
return params;
}
py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
const auto &keys_values = filtered_local_dict->elements();
py::dict local_params_dict;
for (auto &key_value : keys_values) {
ValuePtr element_value = key_value.second->BuildValue();
MS_EXCEPTION_IF_NULL(element_value);
auto py_data = ValueToPyData(element_value);
local_params_dict[py::str(key_value.first)] = py_data;
}
return local_params_dict;
}
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
std::vector<AbstractAttribute> kv;
const auto &keys_values = abstract_dict->elements();

View File

@ -134,3 +134,29 @@ def test_list_insert():
x.insert(1, 2)
return Tensor(x)
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
@ms_function
def np_fallback_func_tensor_index(x):
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x[x]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
x = Tensor(1, mstype.int32)
output = np_fallback_func_tensor_index(x)
output_expect = Tensor(6, mstype.float32)
assert output == output_expect

View File

@ -156,7 +156,6 @@ def test_div_mod_func2_tensor():
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
# NameError: name 'Tensor' is not defined.
@ms_function
def select_func(cond, x, y):
if isinstance(cond, (tuple, list)):
@ -175,7 +174,6 @@ def test_select_func():
print(select_func(cond, x, y))
# Not interpret 'Tensor'.
@ms_function
def select_func2(cond, x, y):
if isinstance(cond, (tuple, list)):
@ -194,7 +192,6 @@ def test_select_func2():
print(select_func2(cond, x, y))
# NameError: name 'Tensor' is not defined.
@ms_function
def slice_func(a, b):
a[1:3, ::] = b
@ -207,29 +204,6 @@ def test_slice_func():
print(slice_func(a, b))
@ms_function
def np_fallback_func_tensor_index(x):
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x[x]
# NameError: name 'array_x' is not defined.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
x = Tensor(1, mstype.int32)
output = np_fallback_func_tensor_index(x)
output_expect = Tensor(6, mstype.float32)
assert output == output_expect
# EvalCNode: This may be not defined, or it can't be a operator.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_add():