forked from mindspore-Ecosystem/mindspore
!26982 [ME][Fallback] Review the local dict and try to eval out the interpret node.
Merge pull request !26982 from Margaret_wangrui/recheck_local_dict
This commit is contained in:
commit
b00b1cc276
|
@ -1375,13 +1375,25 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
|
||||
<< ", filtered_local_dict:" << filtered_local_dict->ToString();
|
||||
ValuePtr local_dict_value = filtered_local_dict->BuildValue();
|
||||
py::object local_params_dict = ValueToPyData(local_dict_value);
|
||||
py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
|
||||
MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
|
||||
<< py::str(local_params_dict);
|
||||
params[1] = local_params_dict;
|
||||
return params;
|
||||
}
|
||||
|
||||
py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
|
||||
const auto &keys_values = filtered_local_dict->elements();
|
||||
py::dict local_params_dict;
|
||||
for (auto &key_value : keys_values) {
|
||||
ValuePtr element_value = key_value.second->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(element_value);
|
||||
auto py_data = ValueToPyData(element_value);
|
||||
local_params_dict[py::str(key_value.first)] = py_data;
|
||||
}
|
||||
return local_params_dict;
|
||||
}
|
||||
|
||||
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
|
||||
std::vector<AbstractAttribute> kv;
|
||||
const auto &keys_values = abstract_dict->elements();
|
||||
|
|
|
@ -134,3 +134,29 @@ def test_list_insert():
|
|||
x.insert(1, 2)
|
||||
return Tensor(x)
|
||||
assert np.all(list_insert().asnumpy() == np.array([1, 2, 3, 4]))
|
||||
|
||||
|
||||
@ms_function
|
||||
def np_fallback_func_tensor_index(x):
|
||||
array_x = tuple([2, 3, 4, 5])
|
||||
np_x = np.array(array_x).astype(np.float32)
|
||||
me_x = Tensor(np_x)
|
||||
me_x = me_x + me_x
|
||||
return me_x[x]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_fallback_func_tensor_index():
|
||||
"""
|
||||
Feature: Fallback feature: support Tensor index.
|
||||
Description: Fallback feature: support Tensor index.
|
||||
Expectation: Fallback feature: support Tensor index.
|
||||
"""
|
||||
x = Tensor(1, mstype.int32)
|
||||
output = np_fallback_func_tensor_index(x)
|
||||
output_expect = Tensor(6, mstype.float32)
|
||||
assert output == output_expect
|
||||
|
|
|
@ -156,7 +156,6 @@ def test_div_mod_func2_tensor():
|
|||
assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
|
||||
|
||||
|
||||
# NameError: name 'Tensor' is not defined.
|
||||
@ms_function
|
||||
def select_func(cond, x, y):
|
||||
if isinstance(cond, (tuple, list)):
|
||||
|
@ -175,7 +174,6 @@ def test_select_func():
|
|||
print(select_func(cond, x, y))
|
||||
|
||||
|
||||
# Not interpret 'Tensor'.
|
||||
@ms_function
|
||||
def select_func2(cond, x, y):
|
||||
if isinstance(cond, (tuple, list)):
|
||||
|
@ -194,7 +192,6 @@ def test_select_func2():
|
|||
print(select_func2(cond, x, y))
|
||||
|
||||
|
||||
# NameError: name 'Tensor' is not defined.
|
||||
@ms_function
|
||||
def slice_func(a, b):
|
||||
a[1:3, ::] = b
|
||||
|
@ -207,29 +204,6 @@ def test_slice_func():
|
|||
print(slice_func(a, b))
|
||||
|
||||
|
||||
@ms_function
|
||||
def np_fallback_func_tensor_index(x):
|
||||
array_x = tuple([2, 3, 4, 5])
|
||||
np_x = np.array(array_x).astype(np.float32)
|
||||
me_x = Tensor(np_x)
|
||||
me_x = me_x + me_x
|
||||
return me_x[x]
|
||||
|
||||
|
||||
# NameError: name 'array_x' is not defined.
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_fallback_func_tensor_index():
|
||||
"""
|
||||
Feature: Fallback feature: support Tensor index.
|
||||
Description: Fallback feature: support Tensor index.
|
||||
Expectation: Fallback feature: support Tensor index.
|
||||
"""
|
||||
x = Tensor(1, mstype.int32)
|
||||
output = np_fallback_func_tensor_index(x)
|
||||
output_expect = Tensor(6, mstype.float32)
|
||||
assert output == output_expect
|
||||
|
||||
|
||||
# EvalCNode: This may be not defined, or it can't be a operator.
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_tensor_add():
|
||||
|
|
Loading…
Reference in New Issue