[ME][Fallback] Review the local dict and try to eval out the interpret node

This commit is contained in:
Margaret_wangrui 2021-11-30 11:38:28 +08:00
parent c4762bf362
commit ee1b4d4e6d
3 changed files with 39 additions and 27 deletions

View File

@ -1369,13 +1369,25 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
<< ", filtered_local_dict:" << filtered_local_dict->ToString(); << ", filtered_local_dict:" << filtered_local_dict->ToString();
ValuePtr local_dict_value = filtered_local_dict->BuildValue(); ValuePtr local_dict_value = filtered_local_dict->BuildValue();
py::object local_params_dict = ValueToPyData(local_dict_value); py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> " MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
<< py::str(local_params_dict); << py::str(local_params_dict);
params[1] = local_params_dict; params[1] = local_params_dict;
return params; return params;
} }
py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
const auto &keys_values = filtered_local_dict->elements();
py::dict local_params_dict;
for (auto &key_value : keys_values) {
ValuePtr element_value = key_value.second->BuildValue();
MS_EXCEPTION_IF_NULL(element_value);
auto py_data = ValueToPyData(element_value);
local_params_dict[py::str(key_value.first)] = py_data;
}
return local_params_dict;
}
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const { AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
std::vector<AbstractAttribute> kv; std::vector<AbstractAttribute> kv;
const auto &keys_values = abstract_dict->elements(); const auto &keys_values = abstract_dict->elements();

View File

@ -134,3 +134,29 @@ def test_list_insert():
x.insert(1, 2) x.insert(1, 2)
return Tensor(x) return Tensor(x)
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4])) assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
@ms_function
def np_fallback_func_tensor_index(x):
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x[x]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
x = Tensor(1, mstype.int32)
output = np_fallback_func_tensor_index(x)
output_expect = Tensor(6, mstype.float32)
assert output == output_expect

View File

@ -156,7 +156,6 @@ def test_div_mod_func2_tensor():
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value) assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
# NameError: name 'Tensor' is not defined.
@ms_function @ms_function
def select_func(cond, x, y): def select_func(cond, x, y):
if isinstance(cond, (tuple, list)): if isinstance(cond, (tuple, list)):
@ -175,7 +174,6 @@ def test_select_func():
print(select_func(cond, x, y)) print(select_func(cond, x, y))
# Not interpret 'Tensor'.
@ms_function @ms_function
def select_func2(cond, x, y): def select_func2(cond, x, y):
if isinstance(cond, (tuple, list)): if isinstance(cond, (tuple, list)):
@ -194,7 +192,6 @@ def test_select_func2():
print(select_func2(cond, x, y)) print(select_func2(cond, x, y))
# NameError: name 'Tensor' is not defined.
@ms_function @ms_function
def slice_func(a, b): def slice_func(a, b):
a[1:3, ::] = b a[1:3, ::] = b
@ -207,29 +204,6 @@ def test_slice_func():
print(slice_func(a, b)) print(slice_func(a, b))
@ms_function
def np_fallback_func_tensor_index(x):
array_x = tuple([2, 3, 4, 5])
np_x = np.array(array_x).astype(np.float32)
me_x = Tensor(np_x)
me_x = me_x + me_x
return me_x[x]
# NameError: name 'array_x' is not defined.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func_tensor_index():
"""
Feature: Fallback feature: support Tensor index.
Description: Fallback feature: support Tensor index.
Expectation: Fallback feature: support Tensor index.
"""
x = Tensor(1, mstype.int32)
output = np_fallback_func_tensor_index(x)
output_expect = Tensor(6, mstype.float32)
assert output == output_expect
# EvalCNode: This may be not defined, or it can't be a operator. # EvalCNode: This may be not defined, or it can't be a operator.
@pytest.mark.skip(reason='Not support graph fallback feature yet') @pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_add(): def test_np_tensor_add():