support cell and primitive for str.format

This commit is contained in:
lianliguang 2022-05-07 17:46:24 +08:00
parent 9afc5afd3d
commit bd619120ac
5 changed files with 211 additions and 12 deletions

View File

@ -540,6 +540,41 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_co
(*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
(*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
}
namespace {
py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_abs) {
auto prim = prim_abs->BuildValue();
if (prim == nullptr) {
return py::none();
}
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_sig_prim = prim->cast<prim::DoSignaturePrimitivePtr>();
auto value = do_sig_prim->function();
if (!value->isa<PrimitivePy>()) {
return py::none();
}
auto prim_py = value->cast<PrimitivePyPtr>();
return prim_py->GetPyObj();
}
if (prim->isa<PrimitivePy>()) {
auto prim_py = prim->cast<PrimitivePyPtr>();
return prim_py->GetPyObj();
}
return py::none();
}
bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) {
auto fn = partial_abs->fn();
if (!fn->isa<PrimitiveAbstractClosure>()) {
return false;
}
auto abs_prim = fn->cast<PrimitiveAbstractClosurePtr>();
auto prim = abs_prim->prim();
if (prim->name() == prim::kPrimCallInstance->name()) {
return true;
}
return false;
}
} // namespace
void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
MS_EXCEPTION_IF_NULL(dic);
@ -548,16 +583,29 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
(*dic)[ATTR_DTYPE] = abs_base->BuildType();
(*dic)[ATTR_VALUE] = py::none();
if (abs_base->isa<PartialAbstractClosure>()) {
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
auto partial_abs = abs_base->cast<PartialAbstractClosurePtr>();
AbstractBasePtrList args = partial_abs->args();
if (!args.empty()) {
MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
if (value != nullptr) {
auto value = args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(value);
if (IsCallInstance(partial_abs)) {
auto value_obj = value->cast<parse::MsClassObjectPtr>();
if (value_obj != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<MsClassType>();
(*dic)[ATTR_VALUE] = value_obj->obj();
return;
}
}
auto value_obj = value->cast<parse::ClassTypePtr>();
if (value_obj != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
(*dic)[ATTR_VALUE] = value->obj();
(*dic)[ATTR_VALUE] = value_obj->obj();
}
}
}
if (abs_base->isa<PrimitiveAbstractClosure>()) {
(*dic)[ATTR_VALUE] = GetPyObjForPrimitiveAbstract(abs_base->cast<PrimitiveAbstractClosurePtr>());
}
}
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {

View File

@ -176,5 +176,6 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
(void)py::class_<MsClassType, Type, std::shared_ptr<MsClassType>>(m_sub, "TypeMsClassType").def(py::init());
}));
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <vector>
#include "utils/hashing.h"
#include "ops/core_ops.h"
namespace mindspore {
namespace abstract {

View File

@ -113,6 +113,7 @@ Dict = typing.Dict
Slice = typing.Slice
function_type = typing.Function
Ellipsis_ = typing.TypeEllipsis
MsClassType = typing.TypeMsClassType
none_type = typing.TypeNone
env_type_type = typing.EnvType
tensor_type = typing.TensorType

View File

@ -14,7 +14,9 @@
# ============================================================================
""" test graph fallback """
import pytest
from mindspore import ms_function, Tensor
from mindspore import ms_function, Tensor, ms_class, context
from mindspore.ops import prim_attr_register, Primitive
from mindspore.nn import Cell
def test_str_format_single_input():
@ -126,12 +128,13 @@ def test_format_with_key_input():
@ms_function
def foo():
ms_str = "hello {name2},It's me,{name1}"
ms_str = "hello {name2},It's me, {name1}"
ms_format_str = ms_str.format(name2="Mind", name1="Spore")
return ms_format_str
with pytest.raises(TypeError) as ex:
foo()
result_st = foo()
assert result_st == "hello Mind,It's me, Spore"
assert "Unsupported parameter type for python primitive," \
" the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value)
@ -154,8 +157,6 @@ def test_format_with_list_index():
assert result_st == "hello Spore,It's me Mind"
@pytest.mark.skip("Need to support kwargs input of primitive "
"operations same as test_format_with_key_input")
def test_format_with_map():
"""
Feature: JIT Fallback
@ -169,9 +170,12 @@ def test_format_with_map():
names = {"name1": "Mind", "name2": "Spore"}
ms_format_str = ms_str.format(names)
return ms_format_str
with pytest.raises(TypeError) as ex:
result_st = foo()
assert result_st == "hello Spore,It's me Mind"
assert "Unsupported parameter type for python primitive," \
" the parameter value is dict: {keys: (name1, name2), values: (Mind, Spore)}" in str(ex.value)
result_st = foo()
assert result_st == "hello Spore,It's me Mind"
def test_format_as_function():
@ -249,3 +253,147 @@ def test_format_padding():
correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ")
result_str = foo()
assert result_str == correct_str
def test_str_format_using_ms_class():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
@ms_class
class TestClass:
def __init__(self, value):
self.value = value
@ms_function
def test_func():
test_obj = TestClass(123)
format_str = "value is {0.value}".format(test_obj)
return format_str
format_str = test_func()
assert format_str == "value is 123"
def test_str_format_using_ms_class_in_init():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
context.set_context(mode=context.GRAPH_MODE)
@ms_class
class TestClass:
def __init__(self, value):
self.value = value
class TestCell(Cell):
def __init__(self):
super(TestCell, self).__init__()
self.obj = TestClass(123)
def construct(self):
format_str = "value is {0.value}".format(self.obj)
return format_str
test_cell = TestCell()
format_str = test_cell()
assert format_str == "value is 123"
def test_str_format_using_primitive():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
class TestPrim(Primitive):
@prim_attr_register
def __init__(self, x):
self.x = x
@ms_function
def test_func():
test_obj = TestPrim(123)
format_str = "value is {0.x}".format(test_obj)
return format_str
format_str = test_func()
assert format_str == "value is 123"
def test_str_format_using_primitive_in_init():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
class TestPrim(Primitive):
@prim_attr_register
def __init__(self, x):
self.x = x
class TestCell(Cell):
def __init__(self):
super(TestCell, self).__init__()
self.prim = TestPrim(123)
def construct(self):
format_str = "value is {0.x}".format(self.prim)
return format_str
test_cell = TestCell()
format_str = test_cell()
assert format_str == "value is 123"
@pytest.mark.skip("Not support yet")
def test_str_format_using_cell():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
class TestSubCell(Cell):
def __init__(self, x):
super(TestSubCell, self).__init__()
self.x = x
class TestCell(Cell):
def construct(self):
test_obj = TestSubCell(123)
format_str = "value is {0.x}".format(test_obj)
return format_str
test_obj = TestCell()
format_str = test_obj()
assert format_str == "value is 123"
@pytest.mark.skip("Not support yet")
def test_str_format_using_cell_in_init():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
class TestSubCell(Cell):
def __init__(self, x):
super(TestSubCell, self).__init__()
self.x = x
class TestCell(Cell):
def __init__(self):
super(TestCell, self).__init__()
self.test_sub_cell = TestSubCell(123)
def construct(self):
format_str = "value is {0.x}".format(self.test_sub_cell)
return format_str
test_cell = TestCell()
format_str = test_cell()
assert format_str == "value is 123"