forked from OSSInnovation/mindspore
!6545 improve the recognition of Parameter object and raise error when convert keywordarg to pydata
Merge pull request !6545 from zhangbuxue/improve_the_recognition_of_Parameter_object_and_raise_error_when_convert_keywordarg_to_pydata
This commit is contained in:
commit
4905de06bd
|
@ -118,5 +118,5 @@ class ClassMemberNamespace(Namespace):
|
|||
except ValueError:
|
||||
raise UnboundLocalError(name)
|
||||
except KeyError:
|
||||
logger.warning(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', "
|
||||
f"so will return None.")
|
||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||
raise AttributeError(name)
|
||||
|
|
|
@ -89,7 +89,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
|
||||
}
|
||||
|
||||
std::string param_name = py::cast<std::string>(name_attr);
|
||||
auto param_name = py::cast<std::string>(name_attr);
|
||||
auto top_graph = Parser::GetTopFuncGraph();
|
||||
// if the parameter node has been created , return it
|
||||
AnfNodePtr para_node = nullptr;
|
||||
|
@ -115,7 +115,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
|
||||
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
|
||||
AnfNodePtr output = nullptr;
|
||||
if (py::hasattr(obj, "__parameter__")) {
|
||||
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto param = ResolveParameterObj(func_graph, obj);
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr";
|
||||
|
|
|
@ -1014,7 +1014,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
// create class instance
|
||||
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type";
|
||||
MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
|
||||
<< " failed, only support create Cell or Primitive object.";
|
||||
}
|
||||
|
||||
// process the object
|
||||
|
|
|
@ -1355,8 +1355,8 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|||
auto params = newfg->parameters();
|
||||
auto manager = Manage({newfg}, false);
|
||||
if (args.size() > params.size()) {
|
||||
MS_EXCEPTION(ValueError) << "The number of arguments " << args.size()
|
||||
<< " is more than the number of parameters required, which is " << params.size();
|
||||
MS_EXCEPTION(TypeError) << "The number of arguments " << args.size()
|
||||
<< " is more than the number of parameters required, which is " << params.size();
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
ValuePtr value = PyAttrValue(args[i]);
|
||||
|
|
|
@ -147,7 +147,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|||
} else if (value->isa<None>()) {
|
||||
ret = py::none();
|
||||
} else {
|
||||
MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
|
||||
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -48,6 +48,8 @@ class Parameter(MetaTensor):
|
|||
Note:
|
||||
Each parameter of Cell is represented by Parameter class.
|
||||
A Parameter has to belong to a Cell.
|
||||
If there is an operator in the network that requires part of the inputs to be Parameter,
|
||||
then the Parameters as this part of the inputs are not allowed to be cast.
|
||||
|
||||
Args:
|
||||
default_input (Union[Tensor, Initializer, Number]): Parameter data, to be set initialized.
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore import context
|
|||
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||
from . import signature as sig
|
||||
|
||||
|
||||
class Primitive(Primitive_):
|
||||
"""
|
||||
Primitive is the base class of primitives in python.
|
||||
|
@ -168,7 +169,7 @@ class Primitive(Primitive_):
|
|||
return type(self)(**self.init_attrs)
|
||||
|
||||
def __repr__(self):
|
||||
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list])
|
||||
attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
|
||||
info_str = f'Prim[{self.name}]'
|
||||
if attr:
|
||||
info_str += f'<{attr}>'
|
||||
|
@ -425,6 +426,7 @@ def prim_attr_register(fn):
|
|||
Returns:
|
||||
function, original function.
|
||||
"""
|
||||
|
||||
def deco(self, *args, **kwargs):
|
||||
if isinstance(self, PrimitiveWithInfer):
|
||||
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
|
||||
|
@ -442,6 +444,7 @@ def prim_attr_register(fn):
|
|||
self.add_prim_attr(name, value)
|
||||
self.init_attrs[name] = value
|
||||
fn(self, *args, **kwargs)
|
||||
|
||||
deco.decorated_func = fn
|
||||
return deco
|
||||
|
||||
|
@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
>>> return len(x)
|
||||
>>> assert tuple_len_class()(a) == 2
|
||||
"""
|
||||
|
||||
def deco(fn):
|
||||
class CompileOp(PrimitiveWithInfer):
|
||||
def __init__(self):
|
||||
|
@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
|
||||
def infer_value(self, *args):
|
||||
return fn(*args)
|
||||
|
||||
if get_instance:
|
||||
return CompileOp()
|
||||
return CompileOp
|
||||
|
||||
if fn is not None:
|
||||
return deco(fn)
|
||||
return deco
|
||||
|
|
|
@ -143,6 +143,7 @@ class TestSummaryOps:
|
|||
(SummaryEnum.TENSOR.value, Tensor(0)),
|
||||
(SummaryEnum.HISTOGRAM.value, Tensor(0))
|
||||
])
|
||||
|
||||
def test_value_shape_invalid(self, summary_type, value):
|
||||
"""Test invalid shape of every summary operators."""
|
||||
net = SummaryNet(summary_type, tag='tag', data=value)
|
||||
|
|
|
@ -122,7 +122,6 @@ class SummaryDemo(nn.Cell):
|
|||
self.s("y1", y)
|
||||
return z
|
||||
|
||||
|
||||
def test_tensor_summary_with_ge():
|
||||
""" test_tensor_summary_with_ge """
|
||||
log.debug("begin test_tensor_summary_with_ge")
|
||||
|
|
Loading…
Reference in New Issue