forked from mindspore-Ecosystem/mindspore
support cell and primitive for str.format
This commit is contained in:
parent
9afc5afd3d
commit
bd619120ac
|
@ -540,6 +540,41 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_co
|
|||
(*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
|
||||
(*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
|
||||
}
|
||||
namespace {
|
||||
py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_abs) {
|
||||
auto prim = prim_abs->BuildValue();
|
||||
if (prim == nullptr) {
|
||||
return py::none();
|
||||
}
|
||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||
auto do_sig_prim = prim->cast<prim::DoSignaturePrimitivePtr>();
|
||||
auto value = do_sig_prim->function();
|
||||
if (!value->isa<PrimitivePy>()) {
|
||||
return py::none();
|
||||
}
|
||||
auto prim_py = value->cast<PrimitivePyPtr>();
|
||||
return prim_py->GetPyObj();
|
||||
}
|
||||
if (prim->isa<PrimitivePy>()) {
|
||||
auto prim_py = prim->cast<PrimitivePyPtr>();
|
||||
return prim_py->GetPyObj();
|
||||
}
|
||||
return py::none();
|
||||
}
|
||||
|
||||
bool IsCallInstance(const PartialAbstractClosurePtr &partial_abs) {
|
||||
auto fn = partial_abs->fn();
|
||||
if (!fn->isa<PrimitiveAbstractClosure>()) {
|
||||
return false;
|
||||
}
|
||||
auto abs_prim = fn->cast<PrimitiveAbstractClosurePtr>();
|
||||
auto prim = abs_prim->prim();
|
||||
if (prim->name() == prim::kPrimCallInstance->name()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
|
||||
MS_EXCEPTION_IF_NULL(dic);
|
||||
|
@ -548,16 +583,29 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *
|
|||
(*dic)[ATTR_DTYPE] = abs_base->BuildType();
|
||||
(*dic)[ATTR_VALUE] = py::none();
|
||||
if (abs_base->isa<PartialAbstractClosure>()) {
|
||||
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
|
||||
auto partial_abs = abs_base->cast<PartialAbstractClosurePtr>();
|
||||
AbstractBasePtrList args = partial_abs->args();
|
||||
if (!args.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
|
||||
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
|
||||
if (value != nullptr) {
|
||||
auto value = args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (IsCallInstance(partial_abs)) {
|
||||
auto value_obj = value->cast<parse::MsClassObjectPtr>();
|
||||
if (value_obj != nullptr) {
|
||||
(*dic)[ATTR_DTYPE] = std::make_shared<MsClassType>();
|
||||
(*dic)[ATTR_VALUE] = value_obj->obj();
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto value_obj = value->cast<parse::ClassTypePtr>();
|
||||
if (value_obj != nullptr) {
|
||||
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
|
||||
(*dic)[ATTR_VALUE] = value->obj();
|
||||
(*dic)[ATTR_VALUE] = value_obj->obj();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (abs_base->isa<PrimitiveAbstractClosure>()) {
|
||||
(*dic)[ATTR_VALUE] = GetPyObjForPrimitiveAbstract(abs_base->cast<PrimitiveAbstractClosurePtr>());
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
|
||||
|
|
|
@ -176,5 +176,6 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
|
||||
(void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
|
||||
(void)py::class_<MsClassType, Type, std::shared_ptr<MsClassType>>(m_sub, "TypeMsClassType").def(py::init());
|
||||
}));
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include "utils/hashing.h"
|
||||
#include "ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
|
|
@ -113,6 +113,7 @@ Dict = typing.Dict
|
|||
Slice = typing.Slice
|
||||
function_type = typing.Function
|
||||
Ellipsis_ = typing.TypeEllipsis
|
||||
MsClassType = typing.TypeMsClassType
|
||||
none_type = typing.TypeNone
|
||||
env_type_type = typing.EnvType
|
||||
tensor_type = typing.TensorType
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import pytest
|
||||
from mindspore import ms_function, Tensor
|
||||
from mindspore import ms_function, Tensor, ms_class, context
|
||||
from mindspore.ops import prim_attr_register, Primitive
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
def test_str_format_single_input():
|
||||
|
@ -126,12 +128,13 @@ def test_format_with_key_input():
|
|||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "hello {name2},It's me,{name1}"
|
||||
ms_str = "hello {name2},It's me, {name1}"
|
||||
ms_format_str = ms_str.format(name2="Mind", name1="Spore")
|
||||
return ms_format_str
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
result_st = foo()
|
||||
assert result_st == "hello Mind,It's me, Spore"
|
||||
assert "Unsupported parameter type for python primitive," \
|
||||
" the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value)
|
||||
|
||||
|
@ -154,8 +157,6 @@ def test_format_with_list_index():
|
|||
assert result_st == "hello Spore,It's me Mind"
|
||||
|
||||
|
||||
@pytest.mark.skip("Need to support kwargs input of primitive "
|
||||
"operations same as test_format_with_key_input")
|
||||
def test_format_with_map():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -169,9 +170,12 @@ def test_format_with_map():
|
|||
names = {"name1": "Mind", "name2": "Spore"}
|
||||
ms_format_str = ms_str.format(names)
|
||||
return ms_format_str
|
||||
with pytest.raises(TypeError) as ex:
|
||||
result_st = foo()
|
||||
assert result_st == "hello Spore,It's me Mind"
|
||||
assert "Unsupported parameter type for python primitive," \
|
||||
" the parameter value is dict: {keys: (name1, name2), values: (Mind, Spore)}" in str(ex.value)
|
||||
|
||||
result_st = foo()
|
||||
assert result_st == "hello Spore,It's me Mind"
|
||||
|
||||
|
||||
def test_format_as_function():
|
||||
|
@ -249,3 +253,147 @@ def test_format_padding():
|
|||
correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ")
|
||||
result_str = foo()
|
||||
assert result_str == correct_str
|
||||
|
||||
|
||||
def test_str_format_using_ms_class():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
@ms_class
|
||||
class TestClass:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
@ms_function
|
||||
def test_func():
|
||||
test_obj = TestClass(123)
|
||||
format_str = "value is {0.value}".format(test_obj)
|
||||
return format_str
|
||||
format_str = test_func()
|
||||
assert format_str == "value is 123"
|
||||
|
||||
|
||||
def test_str_format_using_ms_class_in_init():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
@ms_class
|
||||
class TestClass:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
class TestCell(Cell):
|
||||
def __init__(self):
|
||||
super(TestCell, self).__init__()
|
||||
self.obj = TestClass(123)
|
||||
def construct(self):
|
||||
format_str = "value is {0.value}".format(self.obj)
|
||||
return format_str
|
||||
|
||||
test_cell = TestCell()
|
||||
format_str = test_cell()
|
||||
assert format_str == "value is 123"
|
||||
|
||||
|
||||
def test_str_format_using_primitive():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
class TestPrim(Primitive):
|
||||
@prim_attr_register
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
@ms_function
|
||||
def test_func():
|
||||
test_obj = TestPrim(123)
|
||||
format_str = "value is {0.x}".format(test_obj)
|
||||
return format_str
|
||||
format_str = test_func()
|
||||
assert format_str == "value is 123"
|
||||
|
||||
|
||||
def test_str_format_using_primitive_in_init():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
class TestPrim(Primitive):
|
||||
@prim_attr_register
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
class TestCell(Cell):
|
||||
def __init__(self):
|
||||
super(TestCell, self).__init__()
|
||||
self.prim = TestPrim(123)
|
||||
def construct(self):
|
||||
format_str = "value is {0.x}".format(self.prim)
|
||||
return format_str
|
||||
|
||||
test_cell = TestCell()
|
||||
format_str = test_cell()
|
||||
assert format_str == "value is 123"
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet")
|
||||
def test_str_format_using_cell():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
class TestSubCell(Cell):
|
||||
def __init__(self, x):
|
||||
super(TestSubCell, self).__init__()
|
||||
self.x = x
|
||||
|
||||
class TestCell(Cell):
|
||||
def construct(self):
|
||||
test_obj = TestSubCell(123)
|
||||
format_str = "value is {0.x}".format(test_obj)
|
||||
return format_str
|
||||
|
||||
test_obj = TestCell()
|
||||
format_str = test_obj()
|
||||
assert format_str == "value is 123"
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet")
|
||||
def test_str_format_using_cell_in_init():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
class TestSubCell(Cell):
|
||||
def __init__(self, x):
|
||||
super(TestSubCell, self).__init__()
|
||||
self.x = x
|
||||
|
||||
class TestCell(Cell):
|
||||
def __init__(self):
|
||||
super(TestCell, self).__init__()
|
||||
self.test_sub_cell = TestSubCell(123)
|
||||
def construct(self):
|
||||
format_str = "value is {0.x}".format(self.test_sub_cell)
|
||||
return format_str
|
||||
|
||||
test_cell = TestCell()
|
||||
format_str = test_cell()
|
||||
assert format_str == "value is 123"
|
||||
|
|
Loading…
Reference in New Issue