diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 1871b73d32e..19ab6639328 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -540,6 +540,41 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_co (*dic)[ATTR_DTYPE] = arg_tensor->BuildType(); (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue()); } +namespace { +py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_abs) { + auto prim = prim_abs->BuildValue(); + if (prim == nullptr) { + return py::none(); + } + if (prim->isa()) { + auto do_sig_prim = prim->cast(); + auto value = do_sig_prim->function(); + if (!value->isa()) { + return py::none(); + } + auto prim_py = value->cast(); + return prim_py->GetPyObj(); + } + if (prim->isa()) { + auto prim_py = prim->cast(); + return prim_py->GetPyObj(); + } + return py::none(); +} + +bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) { + auto fn = partial_abs->fn(); + if (!fn->isa()) { + return false; + } + auto abs_prim = fn->cast(); + auto prim = abs_prim->prim(); + if (prim->name() == prim::kPrimCallInstance->name()) { + return true; + } + return false; +} +} // namespace void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) { MS_EXCEPTION_IF_NULL(dic); @@ -548,16 +583,29 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict * (*dic)[ATTR_DTYPE] = abs_base->BuildType(); (*dic)[ATTR_VALUE] = py::none(); if (abs_base->isa()) { - AbstractBasePtrList args = abs_base->cast()->args(); + auto partial_abs = abs_base->cast(); + AbstractBasePtrList args = partial_abs->args(); if (!args.empty()) { - MS_EXCEPTION_IF_NULL(args[0]->BuildValue()); - auto value = args[0]->BuildValue()->cast(); - if (value != nullptr) { + auto value = args[0]->BuildValue(); + MS_EXCEPTION_IF_NULL(value); + if (IsCallInstance(partial_abs)) { + auto value_obj = value->cast(); + if (value_obj != nullptr) { + (*dic)[ATTR_DTYPE] = std::make_shared(); + (*dic)[ATTR_VALUE] = value_obj->obj(); + return; + } + } + auto value_obj = value->cast(); + if (value_obj != nullptr) { (*dic)[ATTR_DTYPE] = std::make_shared(); - (*dic)[ATTR_VALUE] = value->obj(); + (*dic)[ATTR_VALUE] = value_obj->obj(); } } } + if (abs_base->isa()) { + (*dic)[ATTR_VALUE] = GetPyObjForPrimitiveAbstract(abs_base->cast()); + } } bool CheckType(const TypePtr &expected_type, const TypePtr &x) { diff --git a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc index 3ffc2ceedf6..259760f8185 100644 --- a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc @@ -176,5 +176,6 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "TypeAnything").def(py::init()); (void)py::class_>(m_sub, "Slice").def(py::init()); (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); + (void)py::class_>(m_sub, "TypeMsClassType").def(py::init()); })); } // namespace mindspore diff --git a/mindspore/core/abstract/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc index ca8dff5e5a3..081df1c8cdf 100644 --- a/mindspore/core/abstract/abstract_function.cc +++ b/mindspore/core/abstract/abstract_function.cc @@ -18,6 +18,7 @@ #include #include "utils/hashing.h" +#include "ops/core_ops.h" namespace mindspore { namespace abstract { diff --git a/mindspore/python/mindspore/common/dtype.py b/mindspore/python/mindspore/common/dtype.py index 46c7e2d0843..e7d95448a95 100644 --- a/mindspore/python/mindspore/common/dtype.py +++ b/mindspore/python/mindspore/common/dtype.py @@ -113,6 +113,7 @@ Dict = typing.Dict Slice = typing.Slice function_type = typing.Function Ellipsis_ = typing.TypeEllipsis +MsClassType = typing.TypeMsClassType none_type = typing.TypeNone env_type_type = typing.EnvType tensor_type = typing.TensorType diff --git a/tests/ut/python/fallback/test_graph_fallback_str_format.py b/tests/ut/python/fallback/test_graph_fallback_str_format.py index 3c5f38c0b56..d3081c4dddf 100644 --- a/tests/ut/python/fallback/test_graph_fallback_str_format.py +++ b/tests/ut/python/fallback/test_graph_fallback_str_format.py @@ -14,7 +14,9 @@ # ============================================================================ """ test graph fallback """ import pytest -from mindspore import ms_function, Tensor +from mindspore import ms_function, Tensor, ms_class, context +from mindspore.ops import prim_attr_register, Primitive +from mindspore.nn import Cell def test_str_format_single_input(): @@ -126,12 +128,13 @@ def test_format_with_key_input(): @ms_function def foo(): - ms_str = "hello {name2},It's me,{name1}" + ms_str = "hello {name2},It's me, {name1}" ms_format_str = ms_str.format(name2="Mind", name1="Spore") return ms_format_str with pytest.raises(TypeError) as ex: - foo() + result_st = foo() + assert result_st == "hello Mind,It's me, Spore" assert "Unsupported parameter type for python primitive," \ " the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value) @@ -154,8 +157,6 @@ def test_format_with_list_index(): assert result_st == "hello Spore,It's me Mind" -@pytest.mark.skip("Need to support kwargs input of primitive " - "operations same as test_format_with_key_input") def test_format_with_map(): """ Feature: JIT Fallback @@ -169,9 +170,12 @@ def test_format_with_map(): names = {"name1": "Mind", "name2": "Spore"} ms_format_str = ms_str.format(names) return ms_format_str + with pytest.raises(TypeError) as ex: + result_st = foo() + assert result_st == "hello Spore,It's me Mind" + assert "Unsupported parameter type for python primitive," \ + " the parameter value is dict: {keys: (name1, name2), values: (Mind, Spore)}" in str(ex.value) - result_st = foo() - assert result_st == "hello Spore,It's me Mind" def test_format_as_function(): @@ -249,3 +253,147 @@ def test_format_padding(): correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ") result_str = foo() assert result_str == correct_str + + +def test_str_format_using_ms_class(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + @ms_class + class TestClass: + def __init__(self, value): + self.value = value + + @ms_function + def test_func(): + test_obj = TestClass(123) + format_str = "value is {0.value}".format(test_obj) + return format_str + format_str = test_func() + assert format_str == "value is 123" + + +def test_str_format_using_ms_class_in_init(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + context.set_context(mode=context.GRAPH_MODE) + + @ms_class + class TestClass: + def __init__(self, value): + self.value = value + + class TestCell(Cell): + def __init__(self): + super(TestCell, self).__init__() + self.obj = TestClass(123) + def construct(self): + format_str = "value is {0.value}".format(self.obj) + return format_str + + test_cell = TestCell() + format_str = test_cell() + assert format_str == "value is 123" + + +def test_str_format_using_primitive(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + class TestPrim(Primitive): + @prim_attr_register + def __init__(self, x): + self.x = x + + @ms_function + def test_func(): + test_obj = TestPrim(123) + format_str = "value is {0.x}".format(test_obj) + return format_str + format_str = test_func() + assert format_str == "value is 123" + + +def test_str_format_using_primitive_in_init(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + class TestPrim(Primitive): + @prim_attr_register + def __init__(self, x): + self.x = x + + class TestCell(Cell): + def __init__(self): + super(TestCell, self).__init__() + self.prim = TestPrim(123) + def construct(self): + format_str = "value is {0.x}".format(self.prim) + return format_str + + test_cell = TestCell() + format_str = test_cell() + assert format_str == "value is 123" + + +@pytest.mark.skip("Not support yet") +def test_str_format_using_cell(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + class TestSubCell(Cell): + def __init__(self, x): + super(TestSubCell, self).__init__() + self.x = x + + class TestCell(Cell): + def construct(self): + test_obj = TestSubCell(123) + format_str = "value is {0.x}".format(test_obj) + return format_str + + test_obj = TestCell() + format_str = test_obj() + assert format_str == "value is 123" + + +@pytest.mark.skip("Not support yet") +def test_str_format_using_cell_in_init(): + """ + Feature: JIT Fallback + Description: Test str.format() in graph mode. + Expectation: No exception.git + """ + + class TestSubCell(Cell): + def __init__(self, x): + super(TestSubCell, self).__init__() + self.x = x + + class TestCell(Cell): + def __init__(self): + super(TestCell, self).__init__() + self.test_sub_cell = TestSubCell(123) + def construct(self): + format_str = "value is {0.x}".format(self.test_sub_cell) + return format_str + + test_cell = TestCell() + format_str = test_cell() + assert format_str == "value is 123"