From 24a10225cf13db49294a58c3d026595e387d9a48 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Mon, 24 Aug 2020 15:55:26 +0800 Subject: [PATCH] change base class of ref to tensor in cpp --- mindspore/_checkparam.py | 188 ++--------------- .../backend/session/anf_runtime_algorithm.cc | 5 - .../operator/composite/do_signature.cc | 92 ++++----- .../operator/composite/multitype_funcgraph.cc | 8 +- .../operator/ops_front_infer_function.cc | 12 ++ .../frontend/parallel/graph_util/node_info.cc | 4 +- .../ccsrc/frontend/parallel/step_parallel.cc | 11 +- mindspore/ccsrc/pipeline/jit/action.cc | 15 +- .../pipeline/jit/parse/data_converter.cc | 6 - .../pipeline/jit/parse/function_block.cc | 22 +- .../ccsrc/pipeline/jit/parse/function_block.h | 3 - mindspore/ccsrc/pipeline/jit/parse/parse.cc | 5 +- mindspore/ccsrc/pipeline/jit/parse/parse.h | 2 +- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 18 +- .../pipeline/pynative/pynative_execute.cc | 25 ++- mindspore/ccsrc/pybind_api/ir/dtype_py.cc | 2 +- .../ccsrc/pybind_api/ir/param_info_py.cc | 4 +- mindspore/ccsrc/pybind_api/ir/tensor_py.cc | 1 + mindspore/common/parameter.py | 24 +-- mindspore/core/abstract/abstract_value.cc | 48 ++--- mindspore/core/abstract/abstract_value.h | 37 ++-- mindspore/core/abstract/prim_others.cc | 13 +- mindspore/core/ir/anf.cc | 11 + mindspore/core/ir/anf.h | 3 +- mindspore/core/ir/dtype.cc | 173 +--------------- mindspore/core/ir/dtype.h | 98 +-------- mindspore/core/ir/dtype/number.h | 2 + mindspore/core/ir/dtype/ref.cc | 8 +- mindspore/core/ir/dtype/ref.h | 27 +-- mindspore/core/ir/dtype/tensor_type.cc | 194 ++++++++++++++++++ mindspore/core/ir/dtype/tensor_type.h | 132 ++++++++++++ mindspore/core/ir/func_graph.h | 3 - mindspore/core/ir/meta_tensor.h | 13 ++ mindspore/core/ir/meta_tensor_extends.cc | 11 +- mindspore/core/ir/named.h | 15 ++ mindspore/core/ir/param_info.h | 9 +- mindspore/core/ir/tensor.cc | 11 +- mindspore/core/ir/value.cc | 10 - mindspore/core/ir/value.h | 19 +- mindspore/lite/test/CMakeLists.txt | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 2 + mindspore/ops/operations/other_ops.py | 2 +- mindspore/ops/primitive.py | 17 +- tests/st/control/test_ascend_control_sink.py | 15 +- tests/ut/python/pipeline/parse/test_parse.py | 58 ++++++ .../python/pipeline/parse/test_while_param.py | 144 +++++++++++++ tests/vm_impl/array_ops_vm_impl.py | 8 +- 47 files changed, 812 insertions(+), 720 deletions(-) create mode 100644 mindspore/core/ir/dtype/tensor_type.cc create mode 100644 mindspore/core/ir/dtype/tensor_type.h create mode 100644 tests/ut/python/pipeline/parse/test_while_param.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 801e3ac554..40a402bb49 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -185,14 +185,23 @@ class Validator: raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") @staticmethod - def check_subclass(arg_name, type_, template_type, prim_name): + def check_subclass(arg_name, type_, template_types, prim_name): """Checks whether some type is subclass of another type""" - if not isinstance(template_type, Iterable): - template_type = (template_type,) - if not any([mstype.issubclass_(type_, x) for x in template_type]): + if not isinstance(template_types, Iterable): + template_types = (template_types,) + hit = False + for template_type in template_types: + if isinstance(template_type, mstype.Type): + if mstype.issubclass_(type_, template_type): + hit = True + break + elif type_ is template_type: + hit = True + break + if not hit: type_str = 
(type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' - f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') + f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') @staticmethod def check_const_input(arg_name, arg_value, prim_name): @@ -206,13 +215,7 @@ class Validator: def _check_tensor_type(arg): arg_key, arg_val = arg elem_type = arg_val - if not elem_type in valid_values: - type_names = [] - for t in valid_values: - type_names.append(str(t)) - types_info = '[' + ', '.join(type_names) + ']' - raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' - f' but got {elem_type}.') + Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) return (arg_key, elem_type) def _check_types_same(arg1, arg2): @@ -335,12 +338,6 @@ class Validator: class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" - @staticmethod - def equal(arg_name, arg_value, cond_str, cond): - """Judging valid value.""" - if not cond: - raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.') - @staticmethod def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): """This method is only used for check int values, since when compare float values, @@ -360,27 +357,6 @@ class ParamValidator: raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') return arg_value - @staticmethod - def check_shape_length(arg_name, arg_value, value, rel): - """Shape length judgment.""" - rel_fn = Rel.get_fns(rel) - type_mismatch = not isinstance(arg_value, int) - if type_mismatch or not rel_fn(arg_value, value): - rel_str = Rel.get_strs(rel).format(value) - raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}') - return arg_value - - @staticmethod - def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel): - """This method is only used for check int values, - since when compare float values, we need consider float error.""" - rel_fn = Rel.get_fns(rel) - type_mismatch = not isinstance(arg_value, int) - if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): - rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.') - return arg_value - @staticmethod def check_isinstance(arg_name, arg_value, classes): """Check arg isinstance of classes""" @@ -388,33 +364,6 @@ class ParamValidator: raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') return arg_value - @staticmethod - def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel): - """Is it necessary to consider error when comparing float values.""" - rel_fn = Rel.get_fns(rel) - if not rel_fn(arg_value, lower_limit, upper_limit): - rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.') - return arg_value - - @staticmethod - def check_subclass(arg_name, type_, template_type, with_type_of=True): - """Check whether some type is subclass of another type""" - if not isinstance(template_type, Iterable): - template_type = (template_type,) - if not any([mstype.issubclass_(type_, x) for x in template_type]): - type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) 
else "") + str(type_) - raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass' - f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') - - @staticmethod - def check_args_tensor(args): - """Check whether args are all tensor.""" - if not isinstance(args, dict): - raise TypeError("The args should be a dict.") - for arg, value in args.items(): - ParamValidator.check_subclass(arg, value, mstype.tensor) - @staticmethod def check_bool(arg_name, arg_value): """Check arg isinstance of bool""" @@ -442,113 +391,6 @@ class ParamValidator: return arg_value raise_error_msg() - @staticmethod - def check_typename(arg_name, arg_type, valid_types): - """Does it contain the _name_ attribute.""" - - def get_typename(t): - return t.__name__ if hasattr(t, '__name__') else str(t) - - if isinstance(arg_type, type(mstype.tensor)): - arg_type = arg_type.element_type() - - if arg_type in valid_types: - return arg_type - type_names = [get_typename(t) for t in valid_types] - if len(valid_types) == 1: - raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},' - f' but got {get_typename(arg_type)}.') - raise ValueError(f'The type of `{arg_name}` should be one of {type_names},' - f' but got {get_typename(arg_type)}.') - - @staticmethod - def check_string(arg_name, arg_value, valid_values): - """String type judgment.""" - if isinstance(arg_value, str) and arg_value in valid_values: - return arg_value - if len(valid_values) == 1: - raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},' - f' but got {arg_value}.') - raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},' - f' but got {arg_value}.') - - @staticmethod - def check_type_same(args, valid_values): - """Determine whether the types are the same.""" - name = list(args.keys())[0] - value = list(args.values())[0] - if isinstance(value, type(mstype.tensor)): - value = value.element_type() - for arg_name, arg_value in args.items(): - if isinstance(arg_value, type(mstype.tensor)): - arg_value = arg_value.element_type() - - if arg_value not in valid_values: - raise TypeError(f'The `{arg_name}` should be in {valid_values},' - f' but `{arg_name}` is {arg_value}.') - if arg_value != value: - raise TypeError(f'`{arg_name}` should be same as `{name}`,' - f' but `{arg_name}` is {arg_value}, `{name}` is {value}.') - - @staticmethod - def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type): - """Determine whether the types of two variables are the same.""" - if arg1_type != arg2_type: - raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.') - - @staticmethod - def check_value_on_integer(arg_name, arg_value, value, rel): - """Judging integer type.""" - rel_fn = Rel.get_fns(rel) - type_match = isinstance(arg_value, int) - if type_match and (not rel_fn(arg_value, value)): - rel_str = Rel.get_strs(rel).format(value) - raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') - return arg_value - - @staticmethod - def check_param_equal(param1_name, param1_value, param2_name, param2_value): - """Judging the equality of parameters.""" - if param1_value != param2_value: - raise ValueError(f"`{param1_name}` must equal `{param2_name}`," - f" but got `{param1_name}` = {param1_value}," - f" `{param2_name}` = {param2_value}.") - - @staticmethod - def check_const_input(arg_name, arg_value): - """Check valid value.""" - if arg_value is None: - raise ValueError(f'The `{arg_name}` must be a const 
input, but got {arg_value}.') - - @staticmethod - def check_float_positive(arg_name, arg_value): - """Float type judgment.""" - if isinstance(arg_value, float): - if arg_value > 0: - return arg_value - raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.") - - raise TypeError(f"`{arg_name}` must be float!") - - @staticmethod - def check_pad_value_by_mode(op_name, pad_mode, padding): - """Validate value of padding according to pad_mode""" - if pad_mode != 'pad' and padding != 0: - raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.") - return padding - - @staticmethod - def check_empty_shape_input(arg_name, arg_value): - """Check zeros value.""" - if 0 in arg_value: - raise ValueError(f"Input `{arg_name}` cannot be empty.") - - @staticmethod - def check_scalar_shape_input(arg_name, arg_value): - """Check scalar shape input.""" - if arg_value != []: - raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}") - def check_int(input_param): """Int type judgment.""" diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 2962ba3be2..63332429e7 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_ return get_single_type((*tuple_ptr)[output_idx]); }; TypePtr type_ptr = node->Type(); - if (type_ptr->isa()) { - auto ref_type_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(ref_type_ptr); - return get_tuple_type(ref_type_ptr->subtype(), output_idx); - } return get_tuple_type(type_ptr, output_idx); } diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index 0bdb3b2d50..706aafe418 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -20,6 +20,7 @@ #include "abstract/abstract_value.h" #include "ir/anf.h" +#include "ir/dtype.h" #include "abstract/dshape.h" #include "abstract/param_validator.h" #include "frontend/operator/cc_implementations.h" @@ -43,15 +44,15 @@ const std::vector &GetSignature(const ValuePtr &function) { return empty; } -void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, - const std::vector &signature, bool has_var, std::vector *const op_inputs) { +void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector &signature, + bool has_var, std::vector *const op_inputs) { std::size_t sig_size = signature.size(); auto positional_size = sig_size; if (has_var) { positional_size = sig_size - 1; } - if (args_spec_list.size() < positional_size) { - for (size_t i = args_spec_list.size(); i < sig_size; ++i) { + if (actual_param_number < positional_size) { + for (size_t i = actual_param_number; i < sig_size; ++i) { auto default_value = signature[i].default_value; if (default_value == nullptr) { MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; @@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ *max_type_number = type_number; } -bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, +bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id, TypeId 
*arg_type = nullptr) { - if (arg_value->isa()) { - auto ref = arg_value->cast(); - arg_value = ref->ref(); - if (!is_write && ref->need_cast()) { - auto tensor_type = ref->target_type(); - *arg_type_id = tensor_type->type_id(); - if (arg_type != nullptr) { - *arg_type = kObjectTypeTensorType; - } - return true; - } - } - if (arg_value->isa()) { - auto tensor = arg_value->cast(); - auto tensor_type = tensor->element()->BuildType(); + if (arg_type_origin->isa()) { + auto tensor = arg_type_origin->cast(); + auto tensor_type = tensor->element(); MS_EXCEPTION_IF_NULL(tensor_type); *arg_type_id = tensor_type->type_id(); if (arg_type != nullptr) { @@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId } return true; } - if (arg_value->isa()) { - auto scalar = arg_value->cast(); - auto scalar_type = scalar->BuildType(); + if (arg_type_origin->isa()) { + auto scalar_type = arg_type_origin->cast(); MS_EXCEPTION_IF_NULL(scalar_type); *arg_type_id = scalar_type->type_id(); if (arg_type != nullptr) { @@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId return false; } -TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, +TypeId GetMaxTypeId(const std::vector &input_types, std::vector indices, const std::set &write_indices) { TypeId max_type_id = kTypeUnknown; size_t max_type_number = 0; @@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve TypeId arg_type_id = kTypeUnknown; TypeId arg_type = kTypeUnknown; auto is_write = (write_indices.find(index) != write_indices.end()); - if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { + if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) { continue; } if (arg_type != kObjectTypeTensorType) { @@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve // Get the largest type of index in the same SignatureEnumDType of arguments. using MaxTypeMap = std::map; -MaxTypeMap GetMaxDtype(const std::vector &dtypes, - const abstract::AbstractBasePtrList &args_spec_list, const std::set &write_indices) { +MaxTypeMap GetMaxDtype(const std::vector &dtypes, const std::vector &input_types, + const std::set &write_indices) { // record index for signature.dtypes of the same type // eg. 
[T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} std::map> type_indices; @@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector &dtypes, } bool has_tensor = false; for (const auto &index : indices) { - AbstractBasePtr arg_value = args_spec_list[index]; - if (arg_value->isa()) { - arg_value = arg_value->cast()->ref(); - } - if (arg_value->isa()) { + auto arg_value = input_types[index]; + if (arg_value->isa()) { has_tensor = true; break; } @@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector &dtypes, (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); continue; } - (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); + (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices))); } return dst_type; } @@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap } void DoAutoCast(const std::string &func_name, const std::vector &signature, - const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, + const std::vector &input_types, const FuncGraphPtr &graph, std::vector *const op_inputs, const std::set &write_indices) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), @@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign return; } // Stat the index of the arguments with the largest type in the same SignatureEnumDType. - std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); + std::map dst_type = GetMaxDtype(dtypes, input_types, write_indices); // Identify which arg requires auto cast - for (size_t i = 0; i < args_spec_list.size(); ++i) { + for (size_t i = 0; i < input_types.size(); ++i) { auto it = dst_type.find(dtypes[i]); if (it == dst_type.end() || it->second == kTypeUnknown) { continue; @@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign auto is_write = (rw_it != write_indices.end()); TypeId arg_type_id = kTypeUnknown; - AbstractBasePtr arg_value = args_spec_list[i]; + auto arg_value = input_types[i]; (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); auto it_map = type_name_map.find(arg_type_id); if (it_map == type_name_map.end()) { @@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign } continue; } - if (arg_value->isa() && arg_type_id == it->second) { + if ((arg_value->isa()) && arg_type_id == it->second) { continue; } MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id @@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func } std::vector op_inputs; std::set write_indices; + std::vector input_types; op_inputs.push_back(NewValueNode(function)); // Assume, the write input of op is always the first input. We check if any write op, // and add cast op on other inputs to keep the same type with assigned parameter. 
@@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func sig = signature[sig_size - 1].rw; } - TypePtr type = args_spec_list[i]->GetTypeTrack(); - if (type && type->type_id() == kObjectTypeRef) { - auto ref_abs = args_spec_list[i]->cast(); + TypePtr type = args_spec_list[i]->BuildType(); + if (type && type->isa()) { + auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); if (sig == SignatureEnumRW::kRWRead) { - param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); - if (ref_abs && ref_abs->need_cast()) { - auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); - param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); + auto source_tensor_type = type->cast(); + if (source_tensor_type != nullptr) { + auto source_element = source_tensor_type->element(); + if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { + auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); + param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph); + type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32; + } } } else if (sig == SignatureEnumRW::kRWWrite) { - param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); write_indices.insert(i); } // If sig is SignatureEnumRW::kRWRef, not do anything. - } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { - MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; + } else if (sig == SignatureEnumRW::kRWWrite && + !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) { + MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " + << type->ToString(); } MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " << args_spec_list[i]->ToString(); + input_types.push_back(type); op_inputs.push_back(param); } // process default - ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); - DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); + ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs); + DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices); return func_graph->NewCNode(op_inputs); } } // namespace diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc index 6284521b23..1768bbd90f 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function & } Register(types_name, py_fn); } -static TypePtr UnwrapRef(const TypePtr &type) { - if (type->isa()) { - return type->cast()->subtype(); - } - return type; -} // Return Exact match if exists, else return non ambiguous sub class match // Return py::none() if matching is ambiguous @@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { } auto match = true; for (size_t i = 0; i < sign.size(); ++i) { - if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { + if (!IsIdentidityOrSubclass(types[i], sign[i])) { match = false; break; } diff --git 
a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
index 611d00deea..3e3149a081 100644
--- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
+++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc
@@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
   return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
 }
+
+AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a variable and a value to be assigned to it.
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+
+  MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
+  return args_spec_list[0];
+}
+
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
@@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                       InferImplBroadcastGradientArgs);
+REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
+
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
index cf8b2f1016..272260b20d 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
@@ -20,6 +20,7 @@
 #include "ir/anf.h"
 #include "ir/param_info.h"
+#include "ir/meta_tensor.h"
 #include "pipeline/jit/parse/python_adapter.h"
 namespace mindspore {
@@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
   if (!para_ptr->has_default()) {
     return false;
   }
-  auto obj = py::cast(para_ptr->default_param());
-  auto param_value = py::cast(obj.attr("_value"));
+  auto param_value = para_ptr->param_info();
   if (param_value == nullptr) {
     return false;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 2b495ec281..fe3a325e87 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
   if (!cloned_parameter->has_default()) {
     return false;
   }
-  auto obj = py::cast(cloned_parameter->default_param());
-  auto param_value = py::cast(obj.attr("_value"));
+  auto param_value = cloned_parameter->param_info();
   if (param_value == nullptr) {
     return false;
   }
@@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     if (!ParameterIsCloned(cloned_parameter_node)) {
       continue;
     }
-    auto obj = py::cast(cloned_parameter->default_param());
-    auto param_value = py::cast(obj.attr("_value"));
+    auto param_value = cloned_parameter->param_info();
     if (param_value == nullptr) {
       continue;
     }
@@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       continue;
     }
-    const auto &param_value_cloned = be_cloned_parameter->default_param();
-
-    auto obj_in = py::cast(param_value_cloned);
-    auto param_value_in = py::cast(obj_in.attr("_value"));
+    auto param_value_in = be_cloned_parameter->param_info();
     if (param_value_in == nullptr) {
       continue;
     }
   }
diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc
index 15222a284c..f18873e169 100644
--- a/mindspore/ccsrc/pipeline/jit/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/action.cc
@@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
     for (const auto &param : func_graph->parameters()) {
       auto param_node = std::static_pointer_cast<Parameter>(param);
       if (param_node->has_default()) {
-        ValuePtr value = param_node->default_param();
-        constexpr bool broaden = true;
-        AbstractBasePtr ptr = abstract::FromValue(value, broaden);
-
-        parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
-        args_spec.push_back(ptr);
-        parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
+        auto value = param_node->default_param();
+        auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
+        auto ref_key = std::make_shared<RefKey>(param_node->name());
+        auto abs_ref_key = ref_key->ToAbstract();
+        auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
+        parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
+        args_spec.push_back(abs_ref);
+        parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
       }
     }
   // Analyze
diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
index 72e17fed95..2f43c7c0cd 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
     converted = env;
   } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
     converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
-  } else if (py::hasattr(obj, "__parameter__")) {
-    auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input"));
-    ret = ConvertData(to_convert, &converted);
   } else {
     ret = ConvertOtherObj(obj, &converted);
   }
@@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
 ValuePtr PyDataToValue(const py::object &obj) {
   py::object to_convert = obj;
-  if (py::hasattr(obj, "__parameter__")) {
-    to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input"));
-  }
   ValuePtr value = nullptr;
   (void)ConvertData(to_convert, &value);
   return value;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
index 1b342f5ec3..255b77af88 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
@@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
 }
 void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
-  state_assign_[target] = readid;
+  const std::string primitive_name("assign");
+  const std::string module_name("mindspore.ops.functional");
+  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
+  auto source = ReadVariable(readid);
+  auto assign = func_graph()->NewCNode({assign_op, target, source});
+  WriteVariable(readid, assign);
+  MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
+  AddAutoDepend(assign);
 }
 void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
@@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
   ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
   ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
-  const std::string primitive_name("assign");
-  const std::string module_name("mindspore.ops.functional");
-  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
-  if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
+
+  if (auto_depends_.size() == 0) {
     return;
   }
   AnfNodePtr state = nullptr;
   std::vector<AnfNodePtr> vec_states;
   vec_states.emplace_back(make_tuple_op);
-  for (auto &item : state_assign_) {
-    auto source = ReadVariable(item.second);
-    auto assign = func_graph()->NewCNode({assign_op, item.first, source});
-    MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
-    vec_states.emplace_back(assign);
-  }
   for (auto &item : auto_depends_) {
     MS_LOG(DEBUG) << "auto_depends " << item->ToString();
     vec_states.emplace_back(item);
@@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
   AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
   func_graph()->set_output(ret, true);
-  state_assign_.clear();
 }
 }  // namespace parse
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
index e598790cd4..70476791d8 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
@@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   // keeps all removable phis which will be removed in one pass.
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
-  // set state nodes need to insert before function return nodes.
-  OrderedMap<AnfNodePtr, std::string> state_assign_;
-
   // hold declared global variables in function
   std::set<std::string> global_vars_;
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index 87c93ac8b1..9e7ed723b5 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
   return func_graph;
 }
-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
-  TypePtr dst_type;
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
   if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
     return kFloat32;
   } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
     return kFloat16;
   } else {
-    return kNone;
+    return nullptr;
   }
 }
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h
index b922248e5e..6244aa7af1 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h
@@ -359,7 +359,7 @@ class ParseAst {
 bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
 }  // namespace parse
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index b476603f12..9d81dc4c3b 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
     auto value = py::cast(obj);
     node->set_default_param(value);
     // set_abstract for parameter
-    constexpr bool broaden = true;
-    node->set_abstract(abstract::FromValue(value, broaden));
+    auto abs = value->ToAbstract();
+    node->set_abstract(abs);
     para_node = node;
   }
-  auto iter = func_graph->make_ref_params().find(para_node);
-  if (iter == func_graph->make_ref_params().end()) {
-    ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
-    AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-    AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
-    AnfNodePtr target_type_node = NewValueNode(target_type);
-    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
-    func_graph->make_ref_params()[para_node] = ref_node;
-    func_graph->add_parameter_obj_node(ref_node);
-    return ref_node;
-  } else {
-    return iter->second;
-  }
+  return para_node;
 }
 bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ff7949cce4..44e1d4ee53 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   size_t size = op_exec_info->op_inputs.size();
   for (size_t i = 0; i < size; i++) {
     auto obj = op_exec_info->op_inputs[i];
-    bool op_mask = py::hasattr(obj, "__parameter__");
+    bool op_mask = false;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
+      if (meta_tensor) {
+        op_mask = meta_tensor->is_parameter();
+      }
+    }
+
     (*op_masks).push_back(op_mask);
     MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;
@@ -988,8 +995,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
   if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
     auto free_param = df_builder_->add_parameter();
     free_param->set_name(param_name);
-    free_param->set_default_param(py::cast(obj));
     free_param->debug_info()->set_name(param_name);
+    auto value = py::cast(obj);
+    free_param->set_default_param(value);
     MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
     graph_info_map_[df_builder_].param_map[obj_id] = free_param;
     return free_param;
@@ -1157,17 +1165,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       auto param_name = py::cast<std::string>(name_attr);
       auto free_param = df_builder_->add_parameter();
       free_param->set_name(param_name);
-      free_param->set_default_param(py::cast(param));
+      auto value = py::cast(param);
+      free_param->set_default_param(value);
       free_param->debug_info()->set_name(param_name);
       para_node = free_param;
     }
-    ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
-    AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-    auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
-    AnfNodePtr ref_key_node = NewValueNode(refkey);
-    AnfNodePtr target_type_node = NewValueNode(target_type);
-    AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
-    w_args.push_back(ref_node);
+    w_args.push_back(para_node);
   }
 } else {
   MS_LOG(DEBUG) << "training not paramter_tuple";
@@ -1195,7 +1198,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
       ValuePtr value = param_node->default_param();
-      AbstractBasePtr ptr = abstract::FromValue(value, true);
+      auto ptr = value->ToAbstract();
       if (ptr == nullptr) {
         MS_LOG(EXCEPTION) << "Args convert error";
       }
diff --git a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc
index b279532d06..1f139cdd27 100644
--- a/mindspore/ccsrc/pybind_api/ir/dtype_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/dtype_py.cc
@@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
     (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
     (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
     (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
-    (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
+    (void)py::class_<RefType, TensorType, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
     (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
     (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
     (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
diff --git a/mindspore/ccsrc/pybind_api/ir/param_info_py.cc b/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
index 9ace33643e..f151462653 100644
--- a/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/param_info_py.cc
@@ -21,7 +21,7 @@ namespace mindspore {
 namespace py = pybind11;
 REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
-                         (void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
+                         (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
                            .def(py::init())
                            .def("clone", &ParamInfo::Clone)
                            .def_property("name", &ParamInfo::name, &ParamInfo::set_name)
@@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
                              if (t.size() != 6) {
                                std::runtime_error("Invalid state for ParamInfo!");
                              }
-                             ParamValuePtr p = std::make_shared<ParamInfo>();
+
ParamInfoPtr p = std::make_shared(); p->set_name(t[1].cast()); p->set_requires_grad(t[2].cast()); p->set_layerwise_parallel(t[3].cast()); diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 7596d75c28..508f2f45c6 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -213,6 +213,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { .def(py::init>(), py::arg("dtype"), py::arg("shape")) .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") + .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) .def(py::pickle( [](const MetaTensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index ee41fdd165..7b02bc86c7 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -42,7 +42,7 @@ class Parameter(MetaTensor): In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor` only saves the shape and type info of a tensor with no memory usage. The shape can be changed while - compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. + compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data. Note: Each parameter of Cell is represented by Parameter class. @@ -108,7 +108,7 @@ class Parameter(MetaTensor): Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): - self._value = ParamInfo() + self._param_info = ParamInfo() self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel @@ -156,13 +156,13 @@ class Parameter(MetaTensor): value_str = MetaTensor.__str__(self) if isinstance(self, Tensor): value_str = Tensor.__str__(self) - return f'Parameter (name={self._value.name}, value={value_str})' + return f'Parameter (name={self._param_info.name}, value={value_str})' def __repr__(self): value_str = MetaTensor.__repr__(self) if isinstance(self, Tensor): value_str = Tensor.__repr__(self) - return f'Parameter (name={self._value.name}, value={value_str})' + return f'Parameter (name={self._param_info.name}, value={value_str})' def __parameter__(self): """For parse check.""" @@ -181,7 +181,7 @@ class Parameter(MetaTensor): @property def name(self): """Get the name of the parameter.""" - return self._value.name + return self._param_info.name @name.setter def name(self, name_): @@ -203,7 +203,7 @@ class Parameter(MetaTensor): format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) else: raise ValueError("The type of the name should be `str` or `None`.") - self._value.name = name_ + self._param_info.name = name_ @property def cast_type(self): @@ -254,8 +254,8 @@ class Parameter(MetaTensor): _check_str_by_regular(prefix) x = copy(self) # pylint: disable=protected-access - x._value = self._value.clone() - x._value.name = prefix + '.' + self._value.name + x._param_info = self._param_info.clone() + x._param_info.name = prefix + '.' 
+ self._param_info.name x.is_init = False if init != 'same': shape = self.shape @@ -265,24 +265,24 @@ class Parameter(MetaTensor): @property def layerwise_parallel(self): - return self._value.layerwise_parallel + return self._param_info.layerwise_parallel @layerwise_parallel.setter def layerwise_parallel(self, value=True): if not isinstance(value, bool): raise TypeError("`layerwise_parallel` parameter must be bool type") - self._value.layerwise_parallel = value + self._param_info.layerwise_parallel = value @property def requires_grad(self): """Return whether the parameter requires gradient.""" - return self._value.requires_grad + return self._param_info.requires_grad @requires_grad.setter def requires_grad(self, value=True): if not isinstance(value, bool): raise TypeError("`requires_grad` parameter must be bool type") - self._value.requires_grad = value + self._param_info.requires_grad = value @property def data(self): diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 389f269d5a..c296c89243 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { } auto other_tensor = dyn_cast(other); if (other_tensor == nullptr) { - auto ref_tensor = dyn_cast(other); - if (ref_tensor != nullptr) { - return this->Join(ref_tensor->ref()); - } MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); } if (*this == *other) { @@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { return std::make_shared(element, shape); } -bool AbstractTensor::operator==(const AbstractTensor &other) const { +bool AbstractTensor::equal_to(const AbstractTensor &other) const { if (&other == this) { return true; } @@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const { return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; } +bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); } + bool AbstractTensor::operator==(const AbstractBase &other) const { if (&other == this) { return true; } - if (other.isa()) { + if (other.tid() == tid()) { auto other_tensor = static_cast(&other); return *this == *other_tensor; } else { @@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const { return buffer.str(); } -AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, - TypePtr cast_target) - : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { +AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value) + : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) { set_type(std::make_shared()); - auto origin_type = ref_value->BuildType(); - if (need_cast && cast_target && origin_type && origin_type->isa()) { - auto tensor_dtype = origin_type->cast()->element(); - if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { - if (cast_target != tensor_dtype) { - need_cast_ = true; - target_type_ = cast_target; - } - } - } if (ref_key && ref_key->isa()) { ref_key_value_ = ref_key->cast()->ref_key_value(); } } -BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } - TypePtr AbstractRef::BuildType() const { - TypePtr subtype = ref_->BuildType(); - TypePtr subtype_origin = subtype; - if 
(need_cast_) { - subtype_origin = std::make_shared(target_type_); - } - return std::make_shared(subtype, subtype_origin); + auto subtype = AbstractTensor::BuildType()->cast(); + return std::make_shared(subtype); } bool AbstractRef::operator==(const AbstractRef &other) const { - return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) && - (!need_cast_ || (*target_type_ == *other.target_type_)); + return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_); } bool AbstractRef::operator==(const AbstractBase &other) const { @@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { auto other_ref = other->cast(); if (other_ref == nullptr) { - auto new_ref = ref_->Join(other); - return std::make_shared(ref_key_, new_ref); + return AbstractTensor::Join(other)->cast(); } if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { return shared_from_base(); } auto ref_key = ref_key_->Join(other_ref->ref_key_); - auto ref = ref_->Join(other_ref->ref()); + auto ref = AbstractTensor::Join(other_ref->ref())->cast(); return std::make_shared(ref_key, ref); } std::string AbstractRef::ToString() const { std::ostringstream buffer; buffer << type_name() << "(" - << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); - if (need_cast_) { - buffer << " cast to: " << target_type_->ToString(); - } + << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString(); auto value = GetValueTrack(); if (value) { buffer << ", value: " << value->ToString(); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index fa768addc6..3b3ccbfa92 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined { AbstractBasePtr Clone() const override; AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr BroadenWithShape() const; - AbstractBasePtr Join(const AbstractBasePtr &other) final; - + AbstractBasePtr Join(const AbstractBasePtr &other); bool operator==(const AbstractTensor &other) const; bool operator==(const AbstractBase &other) const override; - std::string ToString() const override; std::size_t hash() const override { auto value = GetValueTrack(); @@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined { } return hash_sum; } + + protected: + bool equal_to(const AbstractTensor &other) const; }; using AbstractTensorPtr = std::shared_ptr; using AbstractTensorPtrList = std::vector; @@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase { }; using AbstractRefKeyPtr = std::shared_ptr; -class AbstractRef : public AbstractBase { +class AbstractRef : public AbstractTensor { public: - AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, - TypePtr cast_target = nullptr); + AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value); ~AbstractRef() override = default; - MS_DECLARE_PARENT(AbstractRef, AbstractBase) + MS_DECLARE_PARENT(AbstractRef, AbstractTensor) TypePtr BuildType() const override; - BaseShapePtr BuildShape() const override; bool operator==(const AbstractRef &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { - return std::make_shared(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); + 
auto abs_tensor = AbstractTensor::Clone()->cast(); + if (abs_tensor == nullptr) { + return nullptr; + } + return std::make_shared(ref_key_->Clone(), abs_tensor); } std::string ToString() const override; - inline AbstractBasePtr ref() const { return ref_; } + inline AbstractTensorPtr ref() { return shared_from_base(); } inline AbstractBasePtr ref_key() const { return ref_key_; } inline RefKeyPtr ref_key_value() const { return ref_key_value_; } - inline TypePtr target_type() const { return target_type_; } - inline bool need_cast() const { return need_cast_; } AbstractBasePtr Broaden(uint8_t config = 0) const override { // always broaden for ref - return std::make_shared(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_); + auto abs_tensor = AbstractTensor::Broaden()->cast(); + if (abs_tensor == nullptr) { + return nullptr; + } + return std::make_shared(ref_key_->Broaden(config), abs_tensor); } AbstractBasePtr Join(const AbstractBasePtr &other) override; std::size_t hash() const override { - return ref_->hash() ^ (std::hash{}(this->tid()) << 1); // ref_key_->hash() ^ + return AbstractTensor::hash() ^ (std::hash{}(this->tid()) << 1); // ref_key_->hash() ^ } private: AbstractBasePtr ref_key_; - AbstractBasePtr ref_; - // For mix presicion, only float type need to cast to float16 of float32 - bool need_cast_; - TypePtr target_type_; // cache for ref_key after build value, when value is null, return nullptr. RefKeyPtr ref_key_value_; }; diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 358ed75849..3ba005fe2d 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() << "."; } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); - if (type->type_id() != kObjectTypeRefKey) { - MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); - } - auto need_cast = !tensor_target_v->isa(); - if (need_cast && !tensor_target_v->isa()) { - MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); - } - TypePtr cast_target = tensor_target_v->cast(); - return std::make_shared(args_spec_list[0], args_spec_list[1], need_cast, cast_target); + auto tensor = args_spec_list[1]->cast(); + return std::make_shared(args_spec_list[0], tensor); } AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 89f9eec7e4..4d1541a35d 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const { return buffer.str(); } +ParamInfoPtr Parameter::param_info() const { + if (!has_default()) { + return nullptr; + } + auto tensor = default_param()->cast(); + if (tensor == nullptr || !tensor->is_parameter()) { + return nullptr; + } + return tensor->param_info(); +} + std::string ValueNode::ToString() const { MS_EXCEPTION_IF_NULL(value_); if (value_->isa()) { diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index abda9d3885..409c731f38 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr; class AnfIrVisitor; class ParamInfo; 
-using ParamValuePtr = std::shared_ptr; +using ParamInfoPtr = std::shared_ptr; // AnfNode is the basic class of the IR definition derived from Base. // Only two types of nodes are derived: CNode and ANode. @@ -288,6 +288,7 @@ class Parameter : public ANode { has_default_ = true; } ValuePtr default_param() const { return default_param_; } + ParamInfoPtr param_info() const; bool operator==(const AnfNode &other) const override { if (!other.isa()) { diff --git a/mindspore/core/ir/dtype.cc b/mindspore/core/ir/dtype.cc index b01d12d9ff..662f99e333 100644 --- a/mindspore/core/ir/dtype.cc +++ b/mindspore/core/ir/dtype.cc @@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const { std::string Slice::DumpText() const { return ToString(); } -TypePtr UndeterminedType::DeepCopy() const { - MS_EXCEPTION_IF_NULL(element_type_); - if (IsGeneric()) { - return std::make_shared(); - } - return std::make_shared(element_type_->DeepCopy()); -} - -std::string UndeterminedType::ToReprString() const { - if (element_type_ == nullptr) { - return "Undetermined"; - } - return "Undetermined[" + element_type_->ToReprString() + "]"; -} - -std::string UndeterminedType::ToString() const { - if (element_type_ == nullptr) { - return "Undetermined"; - } - return "Undetermined[" + element_type_->ToString() + "]"; -} - -std::string UndeterminedType::DumpText() const { - if (element_type_ == nullptr) { - return "Undetermined"; - } - return "Undetermined[" + element_type_->DumpText() + "]"; -} - -bool UndeterminedType::operator==(const Type &other) const { - if (!IsSameObjectType(*this, other)) { - return false; - } - auto other_elem_type = static_cast(other).element_type_; - if (element_type_ == nullptr && other_elem_type == nullptr) { - return true; - } else if (element_type_ == nullptr || other_elem_type == nullptr) { - return false; - } - return *element_type_ == *other_elem_type; -} - -TypePtr TensorType::DeepCopy() const { - MS_EXCEPTION_IF_NULL(element_type_); - if (IsGeneric()) { - return std::make_shared(); - } - return std::make_shared(element_type_->DeepCopy()); -} - -std::string TensorType::ToReprString() const { - if (element_type_ == nullptr) { - return "tensor"; - } - return "tensor[" + element_type_->ToReprString() + "]"; -} - -std::string TensorType::ToString() const { - if (element_type_ == nullptr) { - return "Tensor"; - } - return "Tensor[" + element_type_->ToString() + "]"; -} - -std::string TensorType::DumpText() const { - if (element_type_ == nullptr) { - return "Tensor"; - } - return "Tensor(" + element_type_->DumpText() + ")"; -} - -bool TensorType::operator==(const Type &other) const { - if (!IsSameObjectType(*this, other)) { - return false; - } - auto other_elem_type = static_cast(other).element_type_; - // When element_type_ = nullptr, which means any type of Array. 
- if (element_type_ == nullptr && other_elem_type == nullptr) { - return true; - } else if (element_type_ == nullptr || other_elem_type == nullptr) { - return false; - } - return *element_type_ == *other_elem_type; -} - -TypePtr RowTensorType::DeepCopy() const { - MS_EXCEPTION_IF_NULL(element_type_); - if (IsGeneric()) { - return std::make_shared(); - } - return std::make_shared(element_type_->DeepCopy()); -} - -std::string RowTensorType::ToReprString() const { - if (element_type_ == nullptr) { - return "RowTensor"; - } - return "RowTensor[" + element_type_->ToReprString() + "]"; -} - -std::string RowTensorType::ToString() const { - if (element_type_ == nullptr) { - return "RowTensor"; - } - return "RowTensor[" + element_type_->ToString() + "]"; -} - -std::string RowTensorType::DumpText() const { - if (element_type_ == nullptr) { - return "RowTensor"; - } - return "RowTensor[" + element_type_->DumpText() + "]"; -} - -bool RowTensorType::operator==(const Type &other) const { - if (!IsSameObjectType(*this, other)) { - return false; - } - auto other_elem_type = static_cast(other).element_type_; - if (element_type_ == nullptr && other_elem_type == nullptr) { - return true; - } else if (element_type_ == nullptr || other_elem_type == nullptr) { - return false; - } - return *element_type_ == *other_elem_type; -} - -TypePtr SparseTensorType::DeepCopy() const { - MS_EXCEPTION_IF_NULL(element_type_); - if (IsGeneric()) { - return std::make_shared(); - } - return std::make_shared(element_type_->DeepCopy()); -} - -std::string SparseTensorType::ToReprString() const { - if (element_type_ == nullptr) { - return "SparseTensor"; - } - return "SparseTensor[" + element_type_->ToReprString() + "]"; -} - -std::string SparseTensorType::ToString() const { - if (element_type_ == nullptr) { - return "SparseTensor"; - } - return "SparseTensor[" + element_type_->ToString() + "]"; -} - -std::string SparseTensorType::DumpText() const { - if (element_type_ == nullptr) { - return "SparseTensor"; - } - return "SparseTensor[" + element_type_->DumpText() + "]"; -} - -bool SparseTensorType::operator==(const Type &other) const { - if (!IsSameObjectType(*this, other)) { - return false; - } - auto other_elem_type = static_cast(other).element_type_; - if (element_type_ == nullptr && other_elem_type == nullptr) { - return true; - } else if (element_type_ == nullptr || other_elem_type == nullptr) { - return false; - } - return *element_type_ == *other_elem_type; -} - Function::Function() : Object(kObjectTypeFunction) { args_ = std::vector(); retval_ = nullptr; @@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr proble os << problem->ToString(); return os; } + +const TypePtr kTensorTypeFP16 = std::make_shared(std::make_shared(16)); +const TypePtr kTensorTypeFP32 = std::make_shared(std::make_shared(32)); + } // namespace mindspore diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h index 969fb4f190..92c36c463d 100644 --- a/mindspore/core/ir/dtype.h +++ b/mindspore/core/ir/dtype.h @@ -32,10 +32,11 @@ #include "ir/named.h" #include "ir/dtype/type.h" -#include "ir/dtype/ref.h" #include "ir/dtype/number.h" #include "ir/dtype/container.h" #include "ir/dtype/empty.h" +#include "ir/dtype/tensor_type.h" +#include "ir/dtype/ref.h" /* namespace to support intermediate representation definition */ namespace mindspore { @@ -108,98 +109,6 @@ class Slice : public Object { }; using SlicePtr = std::shared_ptr; -class UndeterminedType : public Object { - public: - UndeterminedType() : 
Object(kObjectTypeUndeterminedType) {} - explicit UndeterminedType(const TypePtr &ele) - : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} - ~UndeterminedType() override = default; - MS_DECLARE_PARENT(UndeterminedType, Object) - - TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - protected: - TypePtr element_type_; -}; -using MetaTensorTypePtr = std::shared_ptr; - -class TensorType : public Object { - public: - TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} - explicit TensorType(const TypePtr &ele) - : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~TensorType() override = default; - MS_DECLARE_PARENT(TensorType, Object) - - TypeId generic_type_id() const override { return kObjectTypeTensorType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - private: - TypePtr element_type_; -}; -using TensorTypePtr = std::shared_ptr; - -class RowTensorType : public Object { - public: - RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} - explicit RowTensorType(const TypePtr &ele) - : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~RowTensorType() override = default; - MS_DECLARE_PARENT(RowTensorType, Object) - - TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - private: - TypePtr element_type_; -}; -using RowTensorTypePtr = std::shared_ptr; - -class SparseTensorType : public Object { - public: - SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} - explicit SparseTensorType(const TypePtr &ele) - : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~SparseTensorType() override = default; - MS_DECLARE_PARENT(SparseTensorType, Object) - - TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - private: - TypePtr element_type_; -}; -using SparseTensorTypePtr = std::shared_ptr; - class Function : public Object { public: Function(); @@ -353,6 +262,9 @@ extern const TypePtr kDict; extern 
const TypePtr kSlice; extern const TypePtr kKeyword; extern const TypePtr kTensorType; +extern const TypePtr kTensorTypeFP16; +extern const TypePtr kTensorTypeFP32; + } // namespace mindspore #endif // MINDSPORE_CORE_IR_DTYPE_H_ diff --git a/mindspore/core/ir/dtype/number.h b/mindspore/core/ir/dtype/number.h index 673957c825..ae7d65419b 100644 --- a/mindspore/core/ir/dtype/number.h +++ b/mindspore/core/ir/dtype/number.h @@ -68,6 +68,8 @@ class Number : public Object { const int nbits_; }; +using NumberPtr = std::shared_ptr; + // Bool class Bool : public Number { public: diff --git a/mindspore/core/ir/dtype/ref.cc b/mindspore/core/ir/dtype/ref.cc index 1cb601f4ae..3f2f38d2a5 100644 --- a/mindspore/core/ir/dtype/ref.cc +++ b/mindspore/core/ir/dtype/ref.cc @@ -19,15 +19,15 @@ #include #include #include "utils/log_adapter.h" +#include "ir/dtype/tensor_type.h" namespace mindspore { TypePtr RefType::DeepCopy() const { if (IsGeneric()) { return std::make_shared(); } else { - auto subtype = subtype_->DeepCopy(); - auto subtype_origin = subtype_origin_->DeepCopy(); - return std::make_shared(subtype, subtype_origin); + auto subtype = TensorType::DeepCopy()->cast(); + return std::make_shared(subtype); } } @@ -39,7 +39,7 @@ std::string RefType::DumpText() const { buffer << "Ref"; } else { buffer << "Ref["; - buffer << subtype_->DumpText() << "]"; + buffer << TensorType::DumpText() << "]"; } return buffer.str(); } diff --git a/mindspore/core/ir/dtype/ref.h b/mindspore/core/ir/dtype/ref.h index 79a596a90e..ccdcb6cf6b 100644 --- a/mindspore/core/ir/dtype/ref.h +++ b/mindspore/core/ir/dtype/ref.h @@ -17,21 +17,13 @@ #ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_ #define MINDSPORE_CORE_IR_DTYPE_REF_H_ -#include -#include -#include -#include #include -#include -#include #include -#include -#include -#include -#include + #include "base/base.h" #include "ir/named.h" #include "ir/dtype/type.h" +#include "ir/dtype/tensor_type.h" namespace mindspore { // TypeRefKey type @@ -48,23 +40,16 @@ class RefKeyType : public Object { }; // TypeRef type -class RefType : public Object { +class RefType : public TensorType { public: - RefType() : Object(kObjectTypeRef) {} - RefType(const TypePtr &subtype, const TypePtr &subtype_origin) - : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} + RefType() : TensorType() {} + explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {} ~RefType() override {} - MS_DECLARE_PARENT(RefType, Object) + MS_DECLARE_PARENT(RefType, TensorType) - TypePtr subtype() const { return subtype_; } - TypeId generic_type_id() const override { return kObjectTypeRef; } TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; - - private: - TypePtr subtype_; - TypePtr subtype_origin_; }; using RefTypePtr = std::shared_ptr; diff --git a/mindspore/core/ir/dtype/tensor_type.cc b/mindspore/core/ir/dtype/tensor_type.cc new file mode 100644 index 0000000000..98bd4363c2 --- /dev/null +++ b/mindspore/core/ir/dtype/tensor_type.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype/tensor_type.h" +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { + +TypePtr UndeterminedType::DeepCopy() const { + MS_EXCEPTION_IF_NULL(element_type_); + if (IsGeneric()) { + return std::make_shared(); + } + return std::make_shared(element_type_->DeepCopy()); +} + +std::string UndeterminedType::ToReprString() const { + if (element_type_ == nullptr) { + return "Undetermined"; + } + return "Undetermined[" + element_type_->ToReprString() + "]"; +} + +std::string UndeterminedType::ToString() const { + if (element_type_ == nullptr) { + return "Undetermined"; + } + return "Undetermined[" + element_type_->ToString() + "]"; +} + +std::string UndeterminedType::DumpText() const { + if (element_type_ == nullptr) { + return "Undetermined"; + } + return "Undetermined[" + element_type_->DumpText() + "]"; +} + +bool UndeterminedType::operator==(const Type &other) const { + if (!IsSameObjectType(*this, other)) { + return false; + } + auto other_elem_type = static_cast(other).element_type_; + if (element_type_ == nullptr && other_elem_type == nullptr) { + return true; + } else if (element_type_ == nullptr || other_elem_type == nullptr) { + return false; + } + return *element_type_ == *other_elem_type; +} + +TypePtr TensorType::DeepCopy() const { + MS_EXCEPTION_IF_NULL(element_type_); + if (IsGeneric()) { + return std::make_shared(); + } + return std::make_shared(element_type_->DeepCopy()); +} + +std::string TensorType::ToReprString() const { + if (element_type_ == nullptr) { + return "tensor"; + } + return "tensor[" + element_type_->ToReprString() + "]"; +} + +std::string TensorType::ToString() const { + if (element_type_ == nullptr) { + return "Tensor"; + } + return "Tensor[" + element_type_->ToString() + "]"; +} + +std::string TensorType::DumpText() const { + if (element_type_ == nullptr) { + return "Tensor"; + } + return "Tensor(" + element_type_->DumpText() + ")"; +} + +bool TensorType::operator==(const Type &other) const { + if (!IsSameObjectType(*this, other)) { + return false; + } + auto other_elem_type = static_cast(other).element_type_; + // When element_type_ = nullptr, which means any type of Array. 
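+  // Two element-less (generic) TensorTypes compare equal; a generic and a
+  // concrete TensorType do not.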
+ if (element_type_ == nullptr && other_elem_type == nullptr) { + return true; + } else if (element_type_ == nullptr || other_elem_type == nullptr) { + return false; + } + return *element_type_ == *other_elem_type; +} + +TypePtr RowTensorType::DeepCopy() const { + MS_EXCEPTION_IF_NULL(element_type_); + if (IsGeneric()) { + return std::make_shared(); + } + return std::make_shared(element_type_->DeepCopy()); +} + +std::string RowTensorType::ToReprString() const { + if (element_type_ == nullptr) { + return "RowTensor"; + } + return "RowTensor[" + element_type_->ToReprString() + "]"; +} + +std::string RowTensorType::ToString() const { + if (element_type_ == nullptr) { + return "RowTensor"; + } + return "RowTensor[" + element_type_->ToString() + "]"; +} + +std::string RowTensorType::DumpText() const { + if (element_type_ == nullptr) { + return "RowTensor"; + } + return "RowTensor[" + element_type_->DumpText() + "]"; +} + +bool RowTensorType::operator==(const Type &other) const { + if (!IsSameObjectType(*this, other)) { + return false; + } + auto other_elem_type = static_cast(other).element_type_; + if (element_type_ == nullptr && other_elem_type == nullptr) { + return true; + } else if (element_type_ == nullptr || other_elem_type == nullptr) { + return false; + } + return *element_type_ == *other_elem_type; +} + +TypePtr SparseTensorType::DeepCopy() const { + MS_EXCEPTION_IF_NULL(element_type_); + if (IsGeneric()) { + return std::make_shared(); + } + return std::make_shared(element_type_->DeepCopy()); +} + +std::string SparseTensorType::ToReprString() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->ToReprString() + "]"; +} + +std::string SparseTensorType::ToString() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->ToString() + "]"; +} + +std::string SparseTensorType::DumpText() const { + if (element_type_ == nullptr) { + return "SparseTensor"; + } + return "SparseTensor[" + element_type_->DumpText() + "]"; +} + +bool SparseTensorType::operator==(const Type &other) const { + if (!IsSameObjectType(*this, other)) { + return false; + } + auto other_elem_type = static_cast(other).element_type_; + if (element_type_ == nullptr && other_elem_type == nullptr) { + return true; + } else if (element_type_ == nullptr || other_elem_type == nullptr) { + return false; + } + return *element_type_ == *other_elem_type; +} + +} // namespace mindspore diff --git a/mindspore/core/ir/dtype/tensor_type.h b/mindspore/core/ir/dtype/tensor_type.h new file mode 100644 index 0000000000..6909adec5b --- /dev/null +++ b/mindspore/core/ir/dtype/tensor_type.h @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ +#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type.h" + +namespace mindspore { + +class UndeterminedType : public Object { + public: + UndeterminedType() : Object(kObjectTypeUndeterminedType) {} + explicit UndeterminedType(const TypePtr &ele) + : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} + ~UndeterminedType() override = default; + MS_DECLARE_PARENT(UndeterminedType, Object) + + TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + protected: + TypePtr element_type_; +}; +using MetaTensorTypePtr = std::shared_ptr; + +class TensorType : public Object { + public: + TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} + explicit TensorType(const TypePtr &ele) + : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~TensorType() override = default; + MS_DECLARE_PARENT(TensorType, Object) + + TypeId generic_type_id() const override { return kObjectTypeTensorType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using TensorTypePtr = std::shared_ptr; + +class RowTensorType : public Object { + public: + RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {} + explicit RowTensorType(const TypePtr &ele) + : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~RowTensorType() override = default; + MS_DECLARE_PARENT(RowTensorType, Object) + + TypeId generic_type_id() const override { return kObjectTypeRowTensorType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using RowTensorTypePtr = std::shared_ptr; + +class SparseTensorType : public Object { + public: + SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {} + explicit SparseTensorType(const TypePtr &ele) + : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~SparseTensorType() override = default; + MS_DECLARE_PARENT(SparseTensorType, Object) + + TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr 
DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using SparseTensorTypePtr = std::shared_ptr; + +} // namespace mindspore + +#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_ diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 01808ca4c6..fe4ca7afa9 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase { const std::vector ¶mter_obj_nodes() const { return paramter_obj_nodes_; } void add_parameter_obj_node(const AnfNodePtr &p); - std::unordered_map &make_ref_params() { return make_ref_params_; } - std::unordered_map attrs_; std::vector joined_shapes_; std::unordered_map transforms_; // parameter default value std::map parameter_default_value_; - std::unordered_map make_ref_params_; size_t seen_; std::list GetOrderedCnodes(); diff --git a/mindspore/core/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h index 100c3cc59e..dc8835e8b9 100644 --- a/mindspore/core/ir/meta_tensor.h +++ b/mindspore/core/ir/meta_tensor.h @@ -23,6 +23,7 @@ #include #include "base/base.h" +#include "ir/param_info.h" #include "ir/dtype.h" #include "utils/convert_utils_base.h" #include "utils/hashing.h" @@ -163,6 +164,15 @@ class MetaTensor : public Value { return false; } } + // Get tensor's param_info info. + ParamInfoPtr param_info() const { return param_info_; } + bool is_parameter() const { return is_parameter_; } + + // Set tensor's param_info info. + void set_param_info(const ParamInfoPtr ¶m_info) { + is_parameter_ = true; + param_info_ = param_info; + } protected: // brief Data type of the tensor. @@ -184,6 +194,9 @@ class MetaTensor : public Value { // // Includes the format and data type of a tensor on device. DeviceInfo device_info_; + + bool is_parameter_{false}; + ParamInfoPtr param_info_{nullptr}; }; using MetaTensorPtr = std::shared_ptr; diff --git a/mindspore/core/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc index 53fc58eb78..b2476628cb 100644 --- a/mindspore/core/ir/meta_tensor_extends.cc +++ b/mindspore/core/ir/meta_tensor_extends.cc @@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { } auto tensor_shape = tens->shape(); auto abs_tensor = std::make_shared(dtype, tensor_shape); - abs_tensor->set_value(shared_from_base()); + + // if is parameter always no value. 
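+  // A Parameter never carries a constant value here: build a RefKey from the
+  // parameter name and pair it with the tensor abstract, so downstream passes
+  // see a reference to the parameter rather than a baked-in value.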
+ if (is_parameter()) { + auto param_name = param_info()->name(); + auto ref_key = std::make_shared(param_name); + auto abs_ref_key = ref_key->ToAbstract(); + abs_tensor = std::make_shared(abs_ref_key, abs_tensor); + } else { + abs_tensor->set_value(shared_from_base()); + } return abs_tensor; } diff --git a/mindspore/core/ir/named.h b/mindspore/core/ir/named.h index 74fbf005a7..041bef12b0 100644 --- a/mindspore/core/ir/named.h +++ b/mindspore/core/ir/named.h @@ -62,6 +62,21 @@ class Named : public Value { }; using NamedPtr = std::shared_ptr; +struct NamedHasher { + std::size_t operator()(NamedPtr const &name) const { + std::size_t hash = name->Hash(); + return hash; + } +}; + +struct NamedEqual { + bool operator()(NamedPtr const &t1, NamedPtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return *t1 == *t2; + } +}; + class None : public Named { public: None() : Named("None") {} diff --git a/mindspore/core/ir/param_info.h b/mindspore/core/ir/param_info.h index c4a1ab5ee5..2f71e60d61 100644 --- a/mindspore/core/ir/param_info.h +++ b/mindspore/core/ir/param_info.h @@ -21,10 +21,13 @@ #include #include #include -#include "ir/anf.h" -#include "ir/tensor.h" + +#include "ir/dtype.h" namespace mindspore { +class ParamInfo; +using ParamInfoPtr = std::shared_ptr; + class ParamInfo { public: ParamInfo() {} @@ -55,7 +58,7 @@ class ParamInfo { int32_t cloned_index() const { return cloned_index_; } // Make a cloned parameter and update clone info. - ParamValuePtr Clone() { + ParamInfoPtr Clone() { static std::atomic parameter_cloned_index{1}; int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); auto clone = std::make_shared(*this); diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index efffc2ba26..8a722a8e9e 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -461,6 +461,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { } return *this; } + abstract::AbstractBasePtr Tensor::ToAbstract() { auto tens = shared_from_base(); auto dtype = tens->Dtype(); @@ -469,7 +470,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { } auto tensor_shape = tens->shape(); auto abs_tensor = std::make_shared(dtype, tensor_shape); - abs_tensor->set_value(shared_from_base()); + // if is parameter always no value. 
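+  // Same handling as MetaTensor::ToAbstract: abstract a Parameter as a
+  // reference keyed by its name instead of attaching its current value.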
+ if (is_parameter()) { + auto param_name = param_info()->name(); + auto ref_key = std::make_shared(param_name); + auto abs_ref_key = ref_key->ToAbstract(); + abs_tensor = std::make_shared(abs_ref_key, abs_tensor); + } else { + abs_tensor->set_value(shared_from_base()); + } return abs_tensor; } diff --git a/mindspore/core/ir/value.cc b/mindspore/core/ir/value.cc index 560247b8ce..c1a71cbae4 100644 --- a/mindspore/core/ir/value.cc +++ b/mindspore/core/ir/value.cc @@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const { } bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } -bool RefKey::operator==(const Value &other) const { - if (other.isa()) { - auto other_ = static_cast(other); - return *this == other_; - } else { - return false; - } -} -bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } - bool AnyValue::operator==(const Value &other) const { if (other.isa()) { return true; diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h index 6288aa6c67..c01f772232 100644 --- a/mindspore/core/ir/value.h +++ b/mindspore/core/ir/value.h @@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr; IMM_TRAITS(StringImmPtr, std::string) IMM_TRAITS(StringImmPtr, const char *) -class RefKey : public Value { +class RefKey : public Named { public: - explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} + explicit RefKey(const std::string &tag) : Named(tag) {} ~RefKey() override = default; - MS_DECLARE_PARENT(RefKey, Value) - std::size_t hash() const override { return hash_; } - const std::string &tag() const { return tag_; } - bool operator==(const Value &other) const override; - bool operator==(const RefKey &other) const; + MS_DECLARE_PARENT(RefKey, Named) + const std::string &tag() const { return name(); } abstract::AbstractBasePtr ToAbstract() override; - std::string ToString() const override { return "RefKey[" + tag_ + "]"; } + std::string ToString() const override { return "RefKey[" + name() + "]"; } std::string DumpText() const override { std::ostringstream oss; - oss << "RefKey[\"" << tag_ << "\"]"; + oss << "RefKey[\"" << name() << "\"]"; return oss.str(); } - - private: - std::string tag_; - std::size_t hash_ = 0; }; using RefKeyPtr = std::shared_ptr; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 51aa509a00..1a324cb850 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -43,6 +43,8 @@ if(BUILD_CONVERTER) ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index d98b37a6da..ddf1c74be2 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -29,6 +29,8 @@ set(ANF_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc + 
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index e7c057ac70..02c9490c4a 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -23,7 +23,7 @@
 from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 
 
-class Assign(PrimitiveWithInfer):
+class Assign(Primitive):
     """
     Assign `Parameter` with a value.
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index b371ccb0df..ec5d181469 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -18,7 +18,6 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
-from mindspore.common import Parameter
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
@@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
     if op_name == "Cast" or obj.update_parameter:
         cast_args = args
     else:
-        cast_args = list()
-        for arg in args:
-            if isinstance(arg, Parameter):
-                if arg.cast_type:
-                    cast_args.append(cast(arg, arg.cast_type))
-                else:
-                    cast_args.append(arg)
-            else:
-                cast_args.append(arg)
-    output = real_run_op(obj, op_name, tuple(cast_args))
+        cast_args = list(args)
+        for idx, arg in enumerate(args):
+            cast_type = getattr(arg, "cast_type", None)
+            if cast_type:
+                cast_args[idx] = cast(arg, cast_type)
+    output = real_run_op(obj, op_name, tuple(cast_args))
     if not output:
         raise RuntimeError("Pynative run op %s failed!"
% op_name) if len(output) == 1: diff --git a/tests/st/control/test_ascend_control_sink.py b/tests/st/control/test_ascend_control_sink.py index 0e416e205e..8c6bf19aaa 100644 --- a/tests/st/control/test_ascend_control_sink.py +++ b/tests/st/control/test_ascend_control_sink.py @@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell): self.var = Parameter(initializer(1, (1), mstype.float32), name="var") def construct(self, x, y, z, c2, c4): - out = self.assign(self.var, c4) + out = c4 + self.assign(self.var, c4) while x < c2: - y = self.assign(self.var, c4) + y = c4 + self.assign(self.var, c4) while y < c2 and x < c2: if 2 * y < c2: y = y + 2 else: y = y + 1 out = out + y - z = self.assign(self.var, c4) + z = c4 + self.assign(self.var, c4) while z < c2: z = z + 1 out = out + z x = x + 1 out = out + x while x < 2 * c2: - y = self.assign(self.var, c4) + y = c4 + self.assign(self.var, c4) x = x + 1 while y < c2: - z = self.assign(self.var, c4) + z = c4 + self.assign(self.var, c4) while z < c2: z = z + 1 if x < c2: diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py index 8bafdf26c7..642539a178 100644 --- a/tests/ut/python/pipeline/parse/test_parse.py +++ b/tests/ut/python/pipeline/parse/test_parse.py @@ -27,6 +27,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore import context from mindspore.ops import composite as C +from mindspore.ops import operations as P from mindspore.common.api import ms_function, _executor from mindspore.ops._grad.grad_base import bprop_getters from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer @@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape(): net = BpropWithWrongOutputShapeCell() net.set_grad() grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) + +class AssignWhenInsertGrad(nn.Cell): + """ NetWithNDarray definition """ + + def __init__(self): + super(AssignWhenInsertGrad, self).__init__() + self.gather = P.GatherV2() + self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32)) + self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False) + self.freq = Tensor(278, ms.int32) + self.getG = P.InsertGradientOf(self.save_gradient) + + def save_gradient(self, dout): + self.cov_step = self.cov_step + self.freq + return dout + + def construct(self, x): + self.gather(self.damping, self.cov_step, 0) + out = P.ReLU()(x) + out = self.getG(out) + return out + +grad_all = C.GradOperation('get_all', get_all=True) + +class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + out = self.net(*inputs) + return out, grad_all(self.net)(*inputs) + +def test_assign_in_insert_grad(): + context.set_context(mode=context.GRAPH_MODE) + net = AssignWhenInsertGrad().to_float(ms.float16) + input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32') + net_back = GradNet(net) + net_back(ms.Tensor(input_data)) + +class Assign(nn.Cell): + """ NetWithNDarray definition """ + + def __init__(self): + super(Assign, self).__init__() + self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False) + + def construct(self, x): + self.cov_step = self.cov_step + x + return self.cov_step + +def test_assign(): + context.set_context(mode=context.GRAPH_MODE) + net = Assign() + input_data = ms.Tensor(np.array(1).astype(np.int32)) + net_back = GradNet(net) + net_back(input_data) diff --git a/tests/ut/python/pipeline/parse/test_while_param.py b/tests/ut/python/pipeline/parse/test_while_param.py 
new file mode 100644 index 0000000000..7bd7ff9680 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_while_param.py @@ -0,0 +1,144 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_cont_break """ +import numpy as np + +import mindspore as ms +from mindspore import Tensor, context, nn, ms_function +from mindspore.nn import Cell +from mindspore.ops import operations as P + + +class WhileSubGraphParam(Cell): + def __init__(self): + super().__init__() + self.update = ms.Parameter(Tensor(1, ms.float32), "update") + + def construct(self, x, y, z): + out1 = z + while x < y: + self.update = self.update + 1 + out1 = out1 + 1 + x = x + 1 + return out1, self.update + + +def test_while_loop_phi(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(0, ms.float32) + y = Tensor(10, ms.float32) + z = Tensor(100, ms.float32) + + net = WhileSubGraphParam() + net(x, y, z) + +class WhileSubGraphParam2(Cell): + def __init__(self): + super().__init__() + self.update = ms.Parameter(Tensor(1, ms.float32), "update") + + def construct(self, x, y, z): + out1 = z + i = self.update + while x < y: + i = i + 1 + out1 = out1 + 1 + x = x + 1 + return out1, self.update + + +def test_while_loop_phi_2(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(0, ms.float32) + y = Tensor(10, ms.float32) + z = Tensor(100, ms.float32) + + net = WhileSubGraphParam2() + net(x, y, z) + + +class WhileSubGraphParam3(Cell): + def __init__(self, initial_input_x): + super().__init__() + self.initial_input_x = initial_input_x + self.X = ms.Parameter(initial_input_x, name="parameter_x") + self.Y = ms.Parameter(self.initial_input_x, name="parameter_y") + + def construct(self): + a = 0 + while a < 3: + self.X = self.X + self.Y + a += 1 + return self.X + + +def test_while_loop_phi_3(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(0, ms.float32) + + net = WhileSubGraphParam3(x) + net() + +class ControlMixedWhileIf(nn.Cell): + def __init__(self): + super().__init__() + self.assign = P.Assign() + self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var") + + @ms_function + def construct(self, x, y, z, c2, c4): + out = self.assign(self.var, c4) + while x < c2: + y = self.assign(self.var, c4) + while y < c2 and x < c2: + if 2 * y < c2: + y = y + 2 + else: + y = y + 1 + out = out + y + z = self.assign(self.var, c4) + while z < c2: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + while x < 2 * c2: + y = self.assign(self.var, c4) + x = x + 1 + while y < c2: + z = self.assign(self.var, c4) + while z < c2: + z = z + 1 + if x < c2: + y = y - 1 + else: + y = y + 1 + out = out + z + out = out + y + out = out + x + return out + +def test_mixed_while_if(): + context.set_context(mode=context.PYNATIVE_MODE) + x = np.array(2).astype(np.int32) + y = np.array(14).astype(np.int32) + z = np.array(1).astype(np.int32) + c2 = 
Tensor([14], ms.int32) + c4 = Tensor([0], ms.int32) + net = ControlMixedWhileIf() + output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) + expect = np.array(3318).astype(np.int32) + assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + context.set_context(mode=context.GRAPH_MODE) diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py index 9f54533213..21628060cc 100644 --- a/tests/vm_impl/array_ops_vm_impl.py +++ b/tests/vm_impl/array_ops_vm_impl.py @@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from .vm_interface import vm # pylint: disable=unused-argument - +@vm_impl_getters.register(P.Assign) +def vm_impl_assign(self): + """Generate vm_impl function for Assign""" + def vm_impl(x, value): + x.assign_value(value) + return x + return vm_impl @vm_impl_getters.register(P.ExpandDims) def vm_impl_expand_dims(self):