!6545 improve the recognition of Parameter object and raise error when convert keywordarg to pydata

Merge pull request !6545 from zhangbuxue/improve_the_recognition_of_Parameter_object_and_raise_error_when_convert_keywordarg_to_pydata
This commit is contained in:
mindspore-ci-bot 2020-09-19 21:25:02 +08:00 committed by Gitee
commit 4905de06bd
9 changed files with 19 additions and 10 deletions

View File

@ -118,5 +118,5 @@ class ClassMemberNamespace(Namespace):
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.warning(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', "
f"so will return None.")
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)

View File

@ -89,7 +89,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
std::string param_name = py::cast<std::string>(name_attr);
auto param_name = py::cast<std::string>(name_attr);
auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
@ -115,7 +115,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
AnfNodePtr output = nullptr;
if (py::hasattr(obj, "__parameter__")) {
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
auto param = ResolveParameterObj(func_graph, obj);
if (param == nullptr) {
MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";

View File

@ -1014,7 +1014,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// create class instance
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type";
MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
<< " failed, only support create Cell or Primitive object.";
}
// process the object

View File

@ -1355,8 +1355,8 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
if (args.size() > params.size()) {
MS_EXCEPTION(ValueError) << "The number of arguments " << args.size()
<< " is more than the number of parameters required, which is " << params.size();
MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
<< " is more than the number of parameters required, which is " << params.size();
}
for (size_t i = 0; i < args.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);

View File

@ -147,7 +147,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
} else if (value->isa<None>()) {
ret = py::none();
} else {
MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
}
return ret;
}

View File

@ -48,6 +48,8 @@ class Parameter(MetaTensor):
Note:
Each parameter of Cell is represented by Parameter class.
A Parameter has to belong to a Cell.
If there is an operator in the network that requires part of the inputs to be Parameter,
then the Parameters as this part of the inputs are not allowed to be cast.
Args:
default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized.

View File

@ -22,6 +22,7 @@ from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from . import signature as sig
class Primitive(Primitive_):
"""
Primitive is the base class of primitives in python.
@ -168,7 +169,7 @@ class Primitive(Primitive_):
return type(self)(**self.init_attrs)
def __repr__(self):
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list])
attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
info_str = f'Prim[{self.name}]'
if attr:
info_str += f'<{attr}>'
@ -425,6 +426,7 @@ def prim_attr_register(fn):
Returns:
function, original function.
"""
def deco(self, *args, **kwargs):
if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
@ -442,6 +444,7 @@ def prim_attr_register(fn):
self.add_prim_attr(name, value)
self.init_attrs[name] = value
fn(self, *args, **kwargs)
deco.decorated_func = fn
return deco
@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None):
>>> return len(x)
>>> assert tuple_len_class()(a) == 2
"""
def deco(fn):
class CompileOp(PrimitiveWithInfer):
def __init__(self):
@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None):
def infer_value(self, *args):
return fn(*args)
if get_instance:
return CompileOp()
return CompileOp
if fn is not None:
return deco(fn)
return deco

View File

@ -143,6 +143,7 @@ class TestSummaryOps:
(SummaryEnum.TENSOR.value, Tensor(0)),
(SummaryEnum.HISTOGRAM.value, Tensor(0))
])
def test_value_shape_invalid(self, summary_type, value):
"""Test invalid shape of every summary operators."""
net = SummaryNet(summary_type, tag='tag', data=value)

View File

@ -122,7 +122,6 @@ class SummaryDemo(nn.Cell):
self.s("y1", y)
return z
def test_tensor_summary_with_ge():
""" test_tensor_summary_with_ge """
log.debug("begin test_tensor_summary_with_ge")