!3271 make reftype a subtype of MetaTensor and try to mark ref in node input

Merge pull request !3271 from vlne-v1/ref_demo
mindspore-ci-bot 2020-08-26 15:45:48 +08:00 committed by Gitee
commit 95212b55a0
47 changed files with 812 additions and 720 deletions

View File

@@ -185,14 +185,23 @@ class Validator:
             raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
 
     @staticmethod
-    def check_subclass(arg_name, type_, template_type, prim_name):
+    def check_subclass(arg_name, type_, template_types, prim_name):
         """Checks whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
+        if not isinstance(template_types, Iterable):
+            template_types = (template_types,)
+        hit = False
+        for template_type in template_types:
+            if isinstance(template_type, mstype.Type):
+                if mstype.issubclass_(type_, template_type):
+                    hit = True
+                    break
+            elif type_ is template_type:
+                hit = True
+                break
+        if not hit:
             type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
             raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
+                            f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
 
     @staticmethod
     def check_const_input(arg_name, arg_value, prim_name):
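The reworked check_subclass above now distinguishes mindspore type objects from plain entries in template_types. A minimal standalone sketch of that matching rule (the helper name matches is hypothetical; mstype is mindspore.common.dtype, as imported by this file):

import mindspore.common.dtype as mstype

def matches(type_, template_types):
    # Mirrors the new check_subclass loop: mstype.Type entries match via
    # issubclass_, anything else must match by object identity.
    if not isinstance(template_types, (tuple, list)):
        template_types = (template_types,)
    for template_type in template_types:
        if isinstance(template_type, mstype.Type):
            if mstype.issubclass_(type_, template_type):
                return True
        elif type_ is template_type:
            return True
    return False

matches(mstype.tensor, (mstype.tensor,))  # True via the subtype branch
matches(str, (str,))                      # True via the identity branch

The identity branch is what lets template_types carry entries that are not mstype.Type instances at all.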
@@ -206,13 +215,7 @@ class Validator:
         def _check_tensor_type(arg):
             arg_key, arg_val = arg
             elem_type = arg_val
-            if not elem_type in valid_values:
-                type_names = []
-                for t in valid_values:
-                    type_names.append(str(t))
-                types_info = '[' + ', '.join(type_names) + ']'
-                raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
-                                f' but got {elem_type}.')
+            Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
             return (arg_key, elem_type)
 
         def _check_types_same(arg1, arg2):
@@ -335,12 +338,6 @@ class Validator:
 class ParamValidator:
     """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
 
-    @staticmethod
-    def equal(arg_name, arg_value, cond_str, cond):
-        """Judging valid value."""
-        if not cond:
-            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
-
     @staticmethod
     def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
         """This method is only used for check int values, since when compare float values,
@@ -360,27 +357,6 @@ class ParamValidator:
             raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
         return arg_value
 
-    @staticmethod
-    def check_shape_length(arg_name, arg_value, value, rel):
-        """Shape length judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
-        return arg_value
-
-    @staticmethod
-    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """This method is only used for check int values,
-        since when compare float values, we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
     @staticmethod
     def check_isinstance(arg_name, arg_value, classes):
         """Check arg isinstance of classes"""
@@ -388,33 +364,6 @@ class ParamValidator:
             raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
         return arg_value
 
-    @staticmethod
-    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """Is it necessary to consider error when comparing float values."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_subclass(arg_name, type_, template_type, with_type_of=True):
-        """Check whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
-            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
-
-    @staticmethod
-    def check_args_tensor(args):
-        """Check whether args are all tensor."""
-        if not isinstance(args, dict):
-            raise TypeError("The args should be a dict.")
-        for arg, value in args.items():
-            ParamValidator.check_subclass(arg, value, mstype.tensor)
-
     @staticmethod
     def check_bool(arg_name, arg_value):
         """Check arg isinstance of bool"""
@@ -442,113 +391,6 @@ class ParamValidator:
             return arg_value
         raise_error_msg()
 
-    @staticmethod
-    def check_typename(arg_name, arg_type, valid_types):
-        """Does it contain the _name_ attribute."""
-
-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
-
-        if isinstance(arg_type, type(mstype.tensor)):
-            arg_type = arg_type.element_type()
-
-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        if len(valid_types) == 1:
-            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
-                             f' but got {get_typename(arg_type)}.')
-        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
-                         f' but got {get_typename(arg_type)}.')
-
-    @staticmethod
-    def check_string(arg_name, arg_value, valid_values):
-        """String type judgment."""
-        if isinstance(arg_value, str) and arg_value in valid_values:
-            return arg_value
-        if len(valid_values) == 1:
-            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
-                             f' but got {arg_value}.')
-        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
-                         f' but got {arg_value}.')
-
-    @staticmethod
-    def check_type_same(args, valid_values):
-        """Determine whether the types are the same."""
-        name = list(args.keys())[0]
-        value = list(args.values())[0]
-        if isinstance(value, type(mstype.tensor)):
-            value = value.element_type()
-        for arg_name, arg_value in args.items():
-            if isinstance(arg_value, type(mstype.tensor)):
-                arg_value = arg_value.element_type()
-            if arg_value not in valid_values:
-                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
-                                f' but `{arg_name}` is {arg_value}.')
-            if arg_value != value:
-                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
-                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
-
-    @staticmethod
-    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
-        """Determine whether the types of two variables are the same."""
-        if arg1_type != arg2_type:
-            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
-
-    @staticmethod
-    def check_value_on_integer(arg_name, arg_value, value, rel):
-        """Judging integer type."""
-        rel_fn = Rel.get_fns(rel)
-        type_match = isinstance(arg_value, int)
-        if type_match and (not rel_fn(arg_value, value)):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
-        """Judging the equality of parameters."""
-        if param1_value != param2_value:
-            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
-                             f" but got `{param1_name}` = {param1_value},"
-                             f" `{param2_name}` = {param2_value}.")
-
-    @staticmethod
-    def check_const_input(arg_name, arg_value):
-        """Check valid value."""
-        if arg_value is None:
-            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
-
-    @staticmethod
-    def check_float_positive(arg_name, arg_value):
-        """Float type judgment."""
-        if isinstance(arg_value, float):
-            if arg_value > 0:
-                return arg_value
-            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
-        raise TypeError(f"`{arg_name}` must be float!")
-
-    @staticmethod
-    def check_pad_value_by_mode(op_name, pad_mode, padding):
-        """Validate value of padding according to pad_mode"""
-        if pad_mode != 'pad' and padding != 0:
-            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
-        return padding
-
-    @staticmethod
-    def check_empty_shape_input(arg_name, arg_value):
-        """Check zeros value."""
-        if 0 in arg_value:
-            raise ValueError(f"Input `{arg_name}` cannot be empty.")
-
-    @staticmethod
-    def check_scalar_shape_input(arg_name, arg_value):
-        """Check scalar shape input."""
-        if arg_value != []:
-            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
-
 
 def check_int(input_param):
     """Int type judgment."""

View File

@@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
     return get_single_type((*tuple_ptr)[output_idx]);
   };
 
   TypePtr type_ptr = node->Type();
-  if (type_ptr->isa<RefType>()) {
-    auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
-    MS_EXCEPTION_IF_NULL(ref_type_ptr);
-    return get_tuple_type(ref_type_ptr->subtype(), output_idx);
-  }
   return get_tuple_type(type_ptr, output_idx);
 }

View File

@@ -20,6 +20,7 @@
 #include "abstract/abstract_value.h"
 #include "ir/anf.h"
+#include "ir/dtype.h"
 #include "abstract/dshape.h"
 #include "abstract/param_validator.h"
 #include "frontend/operator/cc_implementations.h"
@@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
   return empty;
 }
 
-void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
-                    const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
+void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
+                    bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
   std::size_t sig_size = signature.size();
   auto positional_size = sig_size;
   if (has_var) {
     positional_size = sig_size - 1;
   }
-  if (args_spec_list.size() < positional_size) {
-    for (size_t i = args_spec_list.size(); i < sig_size; ++i) {
+  if (actual_param_number < positional_size) {
+    for (size_t i = actual_param_number; i < sig_size; ++i) {
       auto default_value = signature[i].default_value;
       if (default_value == nullptr) {
         MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
@@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
   *max_type_number = type_number;
 }
 
-bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
+bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
                                TypeId *arg_type = nullptr) {
-  if (arg_value->isa<abstract::AbstractRef>()) {
-    auto ref = arg_value->cast<abstract::AbstractRefPtr>();
-    arg_value = ref->ref();
-    if (!is_write && ref->need_cast()) {
-      auto tensor_type = ref->target_type();
-      *arg_type_id = tensor_type->type_id();
-      if (arg_type != nullptr) {
-        *arg_type = kObjectTypeTensorType;
-      }
-      return true;
-    }
-  }
-  if (arg_value->isa<abstract::AbstractTensor>()) {
-    auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
-    auto tensor_type = tensor->element()->BuildType();
+  if (arg_type_origin->isa<TensorType>()) {
+    auto tensor = arg_type_origin->cast<TensorTypePtr>();
+    auto tensor_type = tensor->element();
     MS_EXCEPTION_IF_NULL(tensor_type);
     *arg_type_id = tensor_type->type_id();
     if (arg_type != nullptr) {
@@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
     }
     return true;
   }
-  if (arg_value->isa<abstract::AbstractScalar>()) {
-    auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
-    auto scalar_type = scalar->BuildType();
+  if (arg_type_origin->isa<Number>()) {
+    auto scalar_type = arg_type_origin->cast<NumberPtr>();
     MS_EXCEPTION_IF_NULL(scalar_type);
     *arg_type_id = scalar_type->type_id();
     if (arg_type != nullptr) {
@@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
   return false;
 }
 
-TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
+TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices,
                     const std::set<size_t> &write_indices) {
   TypeId max_type_id = kTypeUnknown;
   size_t max_type_number = 0;
@@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
     TypeId arg_type_id = kTypeUnknown;
     TypeId arg_type = kTypeUnknown;
     auto is_write = (write_indices.find(index) != write_indices.end());
-    if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
+    if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
       continue;
    }
    if (arg_type != kObjectTypeTensorType) {
@@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
 // Get the largest type of index in the same SignatureEnumDType of arguments.
 using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
 
-MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
-                       const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) {
+MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
+                       const std::set<size_t> &write_indices) {
   // record index for signature.dtypes of the same type
   // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
   std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
@@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
     }
     bool has_tensor = false;
     for (const auto &index : indices) {
-      AbstractBasePtr arg_value = args_spec_list[index];
-      if (arg_value->isa<abstract::AbstractRef>()) {
-        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
-      }
-      if (arg_value->isa<abstract::AbstractTensor>()) {
+      auto arg_value = input_types[index];
+      if (arg_value->isa<TensorType>()) {
         has_tensor = true;
         break;
      }
@@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
       (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
       continue;
     }
-    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
+    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices)));
   }
   return dst_type;
 }
@@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
 }
 
 void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
-                const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
+                const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
                 std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
@@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     return;
   }
   // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
-  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
+  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
   // Identify which arg requires auto cast
-  for (size_t i = 0; i < args_spec_list.size(); ++i) {
+  for (size_t i = 0; i < input_types.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;
@@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     auto is_write = (rw_it != write_indices.end());
 
     TypeId arg_type_id = kTypeUnknown;
-    AbstractBasePtr arg_value = args_spec_list[i];
+    auto arg_value = input_types[i];
     (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
     auto it_map = type_name_map.find(arg_type_id);
     if (it_map == type_name_map.end()) {
@@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
       }
       continue;
     }
-    if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
+    if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
       continue;
     }
     MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
@@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
   }
   std::vector<AnfNodePtr> op_inputs;
   std::set<size_t> write_indices;
+  std::vector<TypePtr> input_types;
   op_inputs.push_back(NewValueNode(function));
   // Assume, the write input of op is always the first input. We check if any write op,
   // and add cast op on other inputs to keep the same type with assigned parameter.
@@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
       sig = signature[sig_size - 1].rw;
     }
-    TypePtr type = args_spec_list[i]->GetTypeTrack();
-    if (type && type->type_id() == kObjectTypeRef) {
-      auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
+    TypePtr type = args_spec_list[i]->BuildType();
+    if (type && type->isa<RefType>()) {
+      auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
       if (sig == SignatureEnumRW::kRWRead) {
-        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
-        if (ref_abs && ref_abs->need_cast()) {
-          auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
-          param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
+        auto source_tensor_type = type->cast<TensorTypePtr>();
+        if (source_tensor_type != nullptr) {
+          auto source_element = source_tensor_type->element();
+          if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
+            auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+            param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph);
+            type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
+          }
         }
       } else if (sig == SignatureEnumRW::kRWWrite) {
-        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
         write_indices.insert(i);
      }
      // If sig is SignatureEnumRW::kRWRef, not do anything.
-    } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
-      MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
+    } else if (sig == SignatureEnumRW::kRWWrite &&
+               !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
+      MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
+                              << type->ToString();
    }
    MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
                  << args_spec_list[i]->ToString();
+    input_types.push_back(type);
    op_inputs.push_back(param);
  }
  // process default
-  ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
-  DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
+  ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
+  DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
   return func_graph->NewCNode(op_inputs);
 }
 } // namespace
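With GetRefValue gone, a ref-typed input is used directly as a tensor here; reads get an inline mixed-precision cast when the graph requests one, and writes only record the index. User-visibly, the kRWWrite branch still requires a Parameter. A hedged sketch of both sides of that check (P.Assign taken as a representative op whose first input is declared writable):

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Parameter, Tensor

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.assign = P.Assign()  # first input carries a write (RW) signature
        self.var = Parameter(Tensor(np.zeros((1,), np.float32)), name="var")

    def construct(self, x):
        # OK: self.var is a Parameter, so its type is a RefType here.
        return self.assign(self.var, x)
        # self.assign(x, x) would raise the TypeError built above:
        # "... input 0 should be a Parameter, but ..." (exact text per the hunk)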

View File

@@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
   }
   Register(types_name, py_fn);
 }
 
-static TypePtr UnwrapRef(const TypePtr &type) {
-  if (type->isa<RefType>()) {
-    return type->cast<RefTypePtr>()->subtype();
-  }
-  return type;
-}
 // Return Exact match if exists, else return non ambiguous sub class match
 // Return py::none() if matching is ambiguous
@@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
     }
     auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
-      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
+      if (!IsIdentidityOrSubclass(types[i], sign[i])) {
        match = false;
        break;
      }

View File

@@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
   return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
 }
 
+AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+
+  MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
+  return args_spec_list[0];
+}
+
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
@@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
 REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                       InferImplBroadcastGradientArgs);
+REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
+
 } // namespace abstract
 } // namespace mindspore
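InferImplAssign forwards the abstract of its first argument, so a functional assign infers to the target Parameter's type and shape. A small usage sketch (assign is resolved from mindspore.ops.functional, the same lookup the parser hunk further below performs):

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.ops import functional as F

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.var = Parameter(Tensor(np.zeros((2,), np.float32)), name="var")

    def construct(self, x):
        out = F.assign(self.var, x)  # out's abstract follows self.var (this hunk)
        return out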

View File

@@ -20,6 +20,7 @@
 #include "ir/anf.h"
 #include "ir/param_info.h"
+#include "ir/meta_tensor.h"
 #include "pipeline/jit/parse/python_adapter.h"
 
 namespace mindspore {
@@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
   if (!para_ptr->has_default()) {
     return false;
   }
-  auto obj = py::cast(para_ptr->default_param());
-  auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+  auto param_value = para_ptr->param_info();
   if (param_value == nullptr) {
     return false;
   }

View File

@@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
   if (!cloned_parameter->has_default()) {
     return false;
   }
-  auto obj = py::cast(cloned_parameter->default_param());
-  auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+  auto param_value = cloned_parameter->param_info();
   if (param_value == nullptr) {
     return false;
   }
@@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     if (!ParameterIsCloned(cloned_parameter_node)) {
       continue;
     }
-    auto obj = py::cast(cloned_parameter->default_param());
-    auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
+    auto param_value = cloned_parameter->param_info();
     if (param_value == nullptr) {
       continue;
     }
@@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
         continue;
       }
 
-      const auto &param_value_cloned = be_cloned_parameter->default_param();
-      auto obj_in = py::cast(param_value_cloned);
-      auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
+      auto param_value_in = be_cloned_parameter->param_info();
       if (param_value_in == nullptr) {
         continue;
       }

View File

@@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   for (const auto &param : func_graph->parameters()) {
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
-      ValuePtr value = param_node->default_param();
-      constexpr bool broaden = true;
-      AbstractBasePtr ptr = abstract::FromValue(value, broaden);
-
-      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
-      args_spec.push_back(ptr);
-      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
+      auto value = param_node->default_param();
+      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
+      auto ref_key = std::make_shared<RefKey>(param_node->name());
+      auto abs_ref_key = ref_key->ToAbstract();
+      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
+      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
+      args_spec.push_back(abs_ref);
+      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
     }
   }
   // Analyze

View File

@@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
     converted = env;
   } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
     converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
-  } else if (py::hasattr(obj, "__parameter__")) {
-    auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
-    ret = ConvertData(to_convert, &converted);
   } else {
     ret = ConvertOtherObj(obj, &converted);
   }
@@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
 ValuePtr PyDataToValue(const py::object &obj) {
   py::object to_convert = obj;
-  if (py::hasattr(obj, "__parameter__")) {
-    to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
-  }
   ValuePtr value = nullptr;
   (void)ConvertData(to_convert, &value);
   return value;

View File

@@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
 }
 
 void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
-  state_assign_[target] = readid;
+  const std::string primitive_name("assign");
+  const std::string module_name("mindspore.ops.functional");
+  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
+  auto source = ReadVariable(readid);
+  auto assign = func_graph()->NewCNode({assign_op, target, source});
+  WriteVariable(readid, assign);
+  MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
+  AddAutoDepend(assign);
 }
 
 void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
@@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
   ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
   ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
-  const std::string primitive_name("assign");
-  const std::string module_name("mindspore.ops.functional");
-  ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
-  if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
+
+  if (auto_depends_.size() == 0) {
     return;
   }
   AnfNodePtr state = nullptr;
   std::vector<AnfNodePtr> vec_states;
   vec_states.emplace_back(make_tuple_op);
-  for (auto &item : state_assign_) {
-    auto source = ReadVariable(item.second);
-    auto assign = func_graph()->NewCNode({assign_op, item.first, source});
-    MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
-    vec_states.emplace_back(assign);
-  }
   for (auto &item : auto_depends_) {
     MS_LOG(DEBUG) << "auto_depends " << item->ToString();
     vec_states.emplace_back(item);
@@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
   AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
   func_graph()->set_output(ret, true);
-  state_assign_.clear();
 }
 } // namespace parse
 } // namespace mindspore
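SetStateAssgin now builds the assign node eagerly, re-binds the variable via WriteVariable, and hands the node to AddAutoDepend, instead of queueing (target, readid) pairs in the removed state_assign_ map. For reference, a sketch of the Python pattern this path serves (assuming in-construct writes to a Parameter attribute are what reach SetStateAssgin, as the readid naming suggests):

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class Accumulator(nn.Cell):
    def __init__(self):
        super().__init__()
        self.acc = Parameter(Tensor(np.zeros((1,), np.float32)), name="acc")

    def construct(self, x):
        self.acc = self.acc + x  # parsed into assign(self.acc, ...) plus an auto-depend
        return self.acc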

View File

@@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   // keeps all removable phis which will be removed in one pass.
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
-  // set state nodes need to insert before function return nodes.
-  OrderedMap<AnfNodePtr, std::string> state_assign_;
-
   // hold declared global variables in function
   std::set<std::string> global_vars_;

View File

@@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
   return func_graph;
 }
 
-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
-  TypePtr dst_type;
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
   if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
     return kFloat32;
   } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
     return kFloat16;
   } else {
-    return kNone;
+    return nullptr;
   }
 }

View File

@@ -364,7 +364,7 @@ class ParseAst {
 bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
 
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
-ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
+TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
 
 } // namespace parse
 } // namespace mindspore

View File

@@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
     auto value = py::cast<tensor::MetaTensorPtr>(obj);
     node->set_default_param(value);
     // set_abstract for parameter
-    constexpr bool broaden = true;
-    node->set_abstract(abstract::FromValue(value, broaden));
+    auto abs = value->ToAbstract();
+    node->set_abstract(abs);
     para_node = node;
   }
-  auto iter = func_graph->make_ref_params().find(para_node);
-  if (iter == func_graph->make_ref_params().end()) {
-    ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
-    AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-    AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
-    AnfNodePtr target_type_node = NewValueNode(target_type);
-    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
-    func_graph->make_ref_params()[para_node] = ref_node;
-    func_graph->add_parameter_obj_node(ref_node);
-    return ref_node;
-  } else {
-    return iter->second;
-  }
+  return para_node;
 }
 
 bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {

View File

@@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
   size_t size = op_exec_info->op_inputs.size();
   for (size_t i = 0; i < size; i++) {
     auto obj = op_exec_info->op_inputs[i];
-    bool op_mask = py::hasattr(obj, "__parameter__");
+    bool op_mask = false;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
+      if (meta_tensor) {
+        op_mask = meta_tensor->is_parameter();
+      }
+    }
     (*op_masks).push_back(op_mask);
     MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;
@@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
     if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
       auto free_param = df_builder_->add_parameter();
       free_param->set_name(param_name);
-      free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
       free_param->debug_info()->set_name(param_name);
+      auto value = py::cast<tensor::TensorPtr>(obj);
+      free_param->set_default_param(value);
       MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
       graph_info_map_[df_builder_].param_map[obj_id] = free_param;
       return free_param;
@@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       auto param_name = py::cast<std::string>(name_attr);
       auto free_param = df_builder_->add_parameter();
       free_param->set_name(param_name);
-      free_param->set_default_param(py::cast<tensor::TensorPtr>(param));
+      auto value = py::cast<tensor::TensorPtr>(param);
+      free_param->set_default_param(value);
       free_param->debug_info()->set_name(param_name);
       para_node = free_param;
     }
-    ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
-    AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
-    auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
-    AnfNodePtr ref_key_node = NewValueNode(refkey);
-    AnfNodePtr target_type_node = NewValueNode(target_type);
-    AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
-    w_args.push_back(ref_node);
+    w_args.push_back(para_node);
   }
 } else {
   MS_LOG(DEBUG) << "training not paramter_tuple";
@@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
       ValuePtr value = param_node->default_param();
-      AbstractBasePtr ptr = abstract::FromValue(value, true);
+      auto ptr = value->ToAbstract();
       if (ptr == nullptr) {
         MS_LOG(EXCEPTION) << "Args convert error";
       }

View File

@@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
     (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
     (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
     (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
-    (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
+    (void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
     (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
     (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
     (void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
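This binding is the Python-visible half of the commit title: RefType is now registered with TensorType as a base class. A hedged probe (the internal module path mindspore._c_expression.typing is an assumption):

from mindspore._c_expression import typing  # internal pybind module; path assumed

print(issubclass(typing.RefType, typing.TensorType))  # True after this change
print(issubclass(typing.RefType, typing.Type))        # still True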

View File

@@ -21,7 +21,7 @@ namespace mindspore {
 namespace py = pybind11;
 
 REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
-                         (void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
+                         (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
                            .def(py::init())
                            .def("clone", &ParamInfo::Clone)
                            .def_property("name", &ParamInfo::name, &ParamInfo::set_name)
@@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
                                if (t.size() != 6) {
                                  std::runtime_error("Invalid state for ParamInfo!");
                                }
-                               ParamValuePtr p = std::make_shared<ParamInfo>();
+                               ParamInfoPtr p = std::make_shared<ParamInfo>();
                                p->set_name(t[1].cast<std::string>());
                                p->set_requires_grad(t[2].cast<bool>());
                                p->set_layerwise_parallel(t[3].cast<bool>());

View File

@@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                            .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
                            .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
                            .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
+                           .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
                            .def(py::pickle(
                              [](const MetaTensor &t) {  // __getstate__
                                /* Return a tuple that fully encodes the state of the object */

View File

@@ -42,7 +42,7 @@ class Parameter(MetaTensor):
     In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
     an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
     only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
-    compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
+    compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
 
     Note:
         Each parameter of Cell is represented by Parameter class.
@@ -108,7 +108,7 @@
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
 
     def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
-        self._value = ParamInfo()
+        self._param_info = ParamInfo()
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
@@ -156,13 +156,13 @@
         value_str = MetaTensor.__str__(self)
         if isinstance(self, Tensor):
             value_str = Tensor.__str__(self)
-        return f'Parameter (name={self._value.name}, value={value_str})'
+        return f'Parameter (name={self._param_info.name}, value={value_str})'
 
     def __repr__(self):
         value_str = MetaTensor.__repr__(self)
         if isinstance(self, Tensor):
             value_str = Tensor.__repr__(self)
-        return f'Parameter (name={self._value.name}, value={value_str})'
+        return f'Parameter (name={self._param_info.name}, value={value_str})'
 
     def __parameter__(self):
         """For parse check."""
@@ -181,7 +181,7 @@
     @property
     def name(self):
         """Get the name of the parameter."""
-        return self._value.name
+        return self._param_info.name
 
     @name.setter
     def name(self, name_):
@@ -203,7 +203,7 @@
                                  format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
         else:
             raise ValueError("The type of the name should be `str` or `None`.")
-        self._value.name = name_
+        self._param_info.name = name_
 
     @property
     def cast_type(self):
@@ -254,8 +254,8 @@
         _check_str_by_regular(prefix)
         x = copy(self)
         # pylint: disable=protected-access
-        x._value = self._value.clone()
-        x._value.name = prefix + '.' + self._value.name
+        x._param_info = self._param_info.clone()
+        x._param_info.name = prefix + '.' + self._param_info.name
         x.is_init = False
         if init != 'same':
             shape = self.shape
@@ -265,24 +265,24 @@
 
     @property
     def layerwise_parallel(self):
-        return self._value.layerwise_parallel
+        return self._param_info.layerwise_parallel
 
     @layerwise_parallel.setter
     def layerwise_parallel(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
-        self._value.layerwise_parallel = value
+        self._param_info.layerwise_parallel = value
 
     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
-        return self._value.requires_grad
+        return self._param_info.requires_grad
 
     @requires_grad.setter
     def requires_grad(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`requires_grad` parameter must be bool type")
-        self._value.requires_grad = value
+        self._param_info.requires_grad = value
 
     @property
     def data(self):
View File

@@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   }
   auto other_tensor = dyn_cast<AbstractTensor>(other);
   if (other_tensor == nullptr) {
-    auto ref_tensor = dyn_cast<AbstractRef>(other);
-    if (ref_tensor != nullptr) {
-      return this->Join(ref_tensor->ref());
-    }
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
   if (*this == *other) {
@@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   return std::make_shared<AbstractTensor>(element, shape);
 }
 
-bool AbstractTensor::operator==(const AbstractTensor &other) const {
+bool AbstractTensor::equal_to(const AbstractTensor &other) const {
   if (&other == this) {
     return true;
   }
@@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
   return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
 }
 
+bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
+
 bool AbstractTensor::operator==(const AbstractBase &other) const {
   if (&other == this) {
     return true;
   }
-  if (other.isa<AbstractTensor>()) {
+  if (other.tid() == tid()) {
     auto other_tensor = static_cast<const AbstractTensor *>(&other);
     return *this == *other_tensor;
   } else {
@@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
   return buffer.str();
 }
 
-AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
-                         TypePtr cast_target)
-    : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
+AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
+    : AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
   set_type(std::make_shared<RefType>());
-  auto origin_type = ref_value->BuildType();
-  if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
-    auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
-    if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
-      if (cast_target != tensor_dtype) {
-        need_cast_ = true;
-        target_type_ = cast_target;
-      }
-    }
-  }
   if (ref_key && ref_key->isa<AbstractRefKey>()) {
     ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
   }
 }
 
-BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
-
 TypePtr AbstractRef::BuildType() const {
-  TypePtr subtype = ref_->BuildType();
-  TypePtr subtype_origin = subtype;
-  if (need_cast_) {
-    subtype_origin = std::make_shared<TensorType>(target_type_);
-  }
-  return std::make_shared<RefType>(subtype, subtype_origin);
+  auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>();
+  return std::make_shared<RefType>(subtype);
 }
 
 bool AbstractRef::operator==(const AbstractRef &other) const {
-  return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
-         (!need_cast_ || (*target_type_ == *other.target_type_));
+  return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
 }
 
 bool AbstractRef::operator==(const AbstractBase &other) const {
@@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
 AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
   auto other_ref = other->cast<AbstractRefPtr>();
   if (other_ref == nullptr) {
-    auto new_ref = ref_->Join(other);
-    return std::make_shared<AbstractRef>(ref_key_, new_ref);
+    return AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
   }
   if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
     return shared_from_base<AbstractBase>();
   }
   auto ref_key = ref_key_->Join(other_ref->ref_key_);
-  auto ref = ref_->Join(other_ref->ref());
+  auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>();
   return std::make_shared<AbstractRef>(ref_key, ref);
 }
 
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
-         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
-  if (need_cast_) {
-    buffer << " cast to: " << target_type_->ToString();
-  }
+         << "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
   auto value = GetValueTrack();
   if (value) {
     buffer << ", value: " << value->ToString();

View File

@@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
   AbstractBasePtr Clone() const override;
   AbstractBasePtr Broaden(uint8_t config = 0) const override;
   AbstractBasePtr BroadenWithShape() const;
-  AbstractBasePtr Join(const AbstractBasePtr &other) final;
-
+  AbstractBasePtr Join(const AbstractBasePtr &other);
   bool operator==(const AbstractTensor &other) const;
   bool operator==(const AbstractBase &other) const override;
-
   std::string ToString() const override;
   std::size_t hash() const override {
     auto value = GetValueTrack();
@@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
     }
     return hash_sum;
   }
+
+ protected:
+  bool equal_to(const AbstractTensor &other) const;
 };
 using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
 using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
@@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
 };
 using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
 
-class AbstractRef : public AbstractBase {
+class AbstractRef : public AbstractTensor {
  public:
-  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
-              TypePtr cast_target = nullptr);
+  AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
   ~AbstractRef() override = default;
-  MS_DECLARE_PARENT(AbstractRef, AbstractBase)
+  MS_DECLARE_PARENT(AbstractRef, AbstractTensor)
   TypePtr BuildType() const override;
-  BaseShapePtr BuildShape() const override;
   bool operator==(const AbstractRef &other) const;
   bool operator==(const AbstractBase &other) const override;
   AbstractBasePtr Clone() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
+    auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
+    if (abs_tensor == nullptr) {
+      return nullptr;
+    }
+    return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
   }
   std::string ToString() const override;
-  inline AbstractBasePtr ref() const { return ref_; }
+  inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
   inline AbstractBasePtr ref_key() const { return ref_key_; }
   inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
-  inline TypePtr target_type() const { return target_type_; }
-  inline bool need_cast() const { return need_cast_; }
   AbstractBasePtr Broaden(uint8_t config = 0) const override {
     // always broaden for ref
-    return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
+    auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
+    if (abs_tensor == nullptr) {
+      return nullptr;
+    }
+    return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
   }
   AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
-    return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
+    return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
   }
 
  private:
   AbstractBasePtr ref_key_;
-  AbstractBasePtr ref_;
-  // For mix presicion, only float type need to cast to float16 of float32
-  bool need_cast_;
-  TypePtr target_type_;
   // cache for ref_key after build value, when value is null, return nullptr.
   RefKeyPtr ref_key_value_;
 };
@@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
    MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
                      << ".";
  }
-  TypePtr type = args_spec_list[0]->GetTypeTrack();
-  ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
-  if (type->type_id() != kObjectTypeRefKey) {
-    MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
-  }
-  auto need_cast = !tensor_target_v->isa<None>();
-  if (need_cast && !tensor_target_v->isa<Type>()) {
-    MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
-  }
-  TypePtr cast_target = tensor_target_v->cast<TypePtr>();
-  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
+  auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
+  return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
}

AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
@@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
  return buffer.str();
}

+ParamInfoPtr Parameter::param_info() const {
+  if (!has_default()) {
+    return nullptr;
+  }
+  auto tensor = default_param()->cast<tensor::MetaTensorPtr>();
+  if (tensor == nullptr || !tensor->is_parameter()) {
+    return nullptr;
+  }
+  return tensor->param_info();
+}
+
std::string ValueNode::ToString() const {
  MS_EXCEPTION_IF_NULL(value_);
  if (value_->isa<FuncGraph>()) {
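The new Parameter::param_info() accessor just looks through the default value: if the parameter carries a default MetaTensor flagged as a parameter, that tensor's ParamInfo is returned, otherwise nullptr. A rough Python equivalent of the lookup (hypothetical classes, for illustration only):

class MetaTensor:
    def __init__(self, param_info=None):
        self.param_info = param_info               # ParamInfo-like dict or None
        self.is_parameter = param_info is not None


class Parameter:
    def __init__(self, default_param=None):
        self.default_param = default_param         # usually a MetaTensor

    def param_info(self):
        if self.default_param is None:             # no default => no info
            return None
        t = self.default_param
        if not isinstance(t, MetaTensor) or not t.is_parameter:
            return None                            # cast failed or plain tensor
        return t.param_info


print(Parameter(MetaTensor({"name": "w"})).param_info())  # {'name': 'w'}
print(Parameter().param_info())                           # None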
@@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
class AnfIrVisitor;

class ParamInfo;
-using ParamValuePtr = std::shared_ptr<ParamInfo>;
+using ParamInfoPtr = std::shared_ptr<ParamInfo>;

// AnfNode is the basic class of the IR definition derived from Base.
// Only two types of nodes are derived: CNode and ANode.
@@ -288,6 +288,7 @@ class Parameter : public ANode {
    has_default_ = true;
  }
  ValuePtr default_param() const { return default_param_; }
+  ParamInfoPtr param_info() const;

  bool operator==(const AnfNode &other) const override {
    if (!other.isa<Parameter>()) {
@@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
std::string Slice::DumpText() const { return ToString(); }
TypePtr UndeterminedType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<UndeterminedType>();
}
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}
std::string UndeterminedType::ToReprString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToReprString() + "]";
}
std::string UndeterminedType::ToString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToString() + "]";
}
std::string UndeterminedType::DumpText() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->DumpText() + "]";
}
bool UndeterminedType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<TensorType>();
}
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
std::string TensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "tensor";
}
return "tensor[" + element_type_->ToReprString() + "]";
}
std::string TensorType::ToString() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor[" + element_type_->ToString() + "]";
}
std::string TensorType::DumpText() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor(" + element_type_->DumpText() + ")";
}
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr RowTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<RowTensorType>();
}
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
}
std::string RowTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToReprString() + "]";
}
std::string RowTensorType::ToString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToString() + "]";
}
std::string RowTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->DumpText() + "]";
}
bool RowTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
Function::Function() : Object(kObjectTypeFunction) {
  args_ = std::vector<TypePtr>();
  retval_ = nullptr;
@@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
  os << problem->ToString();
  return os;
}
+
+const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
+const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
}  // namespace mindspore
@@ -32,10 +32,11 @@
#include "ir/named.h"
#include "ir/dtype/type.h"
-#include "ir/dtype/ref.h"
#include "ir/dtype/number.h"
#include "ir/dtype/container.h"
#include "ir/dtype/empty.h"
+#include "ir/dtype/tensor_type.h"
+#include "ir/dtype/ref.h"

/* namespace to support intermediate representation definition */
namespace mindspore {
@@ -108,98 +109,6 @@ class Slice : public Object {
};
using SlicePtr = std::shared_ptr<Slice>;
class UndeterminedType : public Object {
public:
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
explicit UndeterminedType(const TypePtr &ele)
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType, Object)
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
protected:
TypePtr element_type_;
};
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
explicit TensorType(const TypePtr &ele)
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using TensorTypePtr = std::shared_ptr<TensorType>;
class RowTensorType : public Object {
public:
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
explicit RowTensorType(const TypePtr &ele)
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~RowTensorType() override = default;
MS_DECLARE_PARENT(RowTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
class Function : public Object {
 public:
  Function();
@@ -353,6 +262,9 @@ extern const TypePtr kDict;
extern const TypePtr kSlice;
extern const TypePtr kKeyword;
extern const TypePtr kTensorType;
+extern const TypePtr kTensorTypeFP16;
+extern const TypePtr kTensorTypeFP32;
}  // namespace mindspore
#endif  // MINDSPORE_CORE_IR_DTYPE_H_
@@ -68,6 +68,8 @@ class Number : public Object {
  const int nbits_;
};
+using NumberPtr = std::shared_ptr<Number>;
+
// Bool
class Bool : public Number {
 public:
@@ -19,15 +19,15 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
+#include "ir/dtype/tensor_type.h"

namespace mindspore {
TypePtr RefType::DeepCopy() const {
  if (IsGeneric()) {
    return std::make_shared<RefType>();
  } else {
-    auto subtype = subtype_->DeepCopy();
-    auto subtype_origin = subtype_origin_->DeepCopy();
-    return std::make_shared<RefType>(subtype, subtype_origin);
+    auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>();
+    return std::make_shared<RefType>(subtype);
  }
}
@@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
    buffer << "Ref";
  } else {
    buffer << "Ref[";
-    buffer << subtype_->DumpText() << "]";
+    buffer << TensorType::DumpText() << "]";
  }
  return buffer.str();
}
@@ -17,21 +17,13 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
#define MINDSPORE_CORE_IR_DTYPE_REF_H_

-#include <cstddef>
-#include <iostream>
-#include <initializer_list>
-#include <map>
#include <memory>
-#include <utility>
-#include <sstream>
#include <string>
-#include <vector>
-#include <type_traits>
-#include <unordered_map>
-#include <algorithm>

#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
+#include "ir/dtype/tensor_type.h"

namespace mindspore {
// TypeRefKey type
@@ -48,23 +40,16 @@ class RefKeyType : public Object {
};

// TypeRef type
-class RefType : public Object {
+class RefType : public TensorType {
 public:
-  RefType() : Object(kObjectTypeRef) {}
-  RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
-      : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
+  RefType() : TensorType() {}
+  explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}
  ~RefType() override {}
-  MS_DECLARE_PARENT(RefType, Object)
-  TypePtr subtype() const { return subtype_; }
-  TypeId generic_type_id() const override { return kObjectTypeRef; }
+  MS_DECLARE_PARENT(RefType, TensorType)
  TypePtr DeepCopy() const override;
  std::string ToString() const override;
  std::string DumpText() const override;
-
- private:
-  TypePtr subtype_;
-  TypePtr subtype_origin_;
};
using RefTypePtr = std::shared_ptr<RefType>;
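This is the headline change of the commit: RefType now derives from TensorType (and, through TensorType's generic parent, from the undetermined/meta-tensor hierarchy), so every check written against tensor types accepts refs for free and the separate subtype_/subtype_origin_ bookkeeping disappears. A small Python model of the relationship (names invented for the sketch):

class TensorType:
    def __init__(self, element=None):
        self.element = element                     # element dtype, e.g. "float32"


class RefType(TensorType):
    # A ref is now just a tensor type that additionally marks a reference;
    # it keeps the element type of the tensor it refers to.
    def __init__(self, subtype=None):
        super().__init__(None if subtype is None else subtype.element)


ref = RefType(TensorType("float32"))
# Validators that ask "is this a tensor type?" now accept refs as well.
print(isinstance(ref, TensorType), ref.element)    # True float32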
@@ -0,0 +1,194 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype/tensor_type.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
TypePtr UndeterminedType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<UndeterminedType>();
}
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}
std::string UndeterminedType::ToReprString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToReprString() + "]";
}
std::string UndeterminedType::ToString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToString() + "]";
}
std::string UndeterminedType::DumpText() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->DumpText() + "]";
}
bool UndeterminedType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<TensorType>();
}
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
std::string TensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "tensor";
}
return "tensor[" + element_type_->ToReprString() + "]";
}
std::string TensorType::ToString() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor[" + element_type_->ToString() + "]";
}
std::string TensorType::DumpText() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor(" + element_type_->DumpText() + ")";
}
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr RowTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<RowTensorType>();
}
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
}
std::string RowTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToReprString() + "]";
}
std::string RowTensorType::ToString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToString() + "]";
}
std::string RowTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->DumpText() + "]";
}
bool RowTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
} // namespace mindspore
@@ -0,0 +1,132 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>
#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
namespace mindspore {
class UndeterminedType : public Object {
public:
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
explicit UndeterminedType(const TypePtr &ele)
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType, Object)
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
protected:
TypePtr element_type_;
};
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
explicit TensorType(const TypePtr &ele)
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using TensorTypePtr = std::shared_ptr<TensorType>;
class RowTensorType : public Object {
public:
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
explicit RowTensorType(const TypePtr &ele)
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~RowTensorType() override = default;
MS_DECLARE_PARENT(RowTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
@@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
  const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
  void add_parameter_obj_node(const AnfNodePtr &p);
-  std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }

  std::unordered_map<std::string, ValuePtr> attrs_;
  std::vector<BaseShapePtr> joined_shapes_;
  std::unordered_map<std::string, FuncGraphTransform> transforms_;
  // parameter default value
  std::map<std::string, AnfNodePtr> parameter_default_value_;
-  std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
  size_t seen_;

  std::list<CNodePtr> GetOrderedCnodes();
@@ -23,6 +23,7 @@
#include <string>

#include "base/base.h"
+#include "ir/param_info.h"
#include "ir/dtype.h"
#include "utils/convert_utils_base.h"
#include "utils/hashing.h"
@@ -163,6 +164,15 @@ class MetaTensor : public Value {
      return false;
    }
  }

+  // Get tensor's param_info info.
+  ParamInfoPtr param_info() const { return param_info_; }
+  bool is_parameter() const { return is_parameter_; }
+  // Set tensor's param_info info.
+  void set_param_info(const ParamInfoPtr &param_info) {
+    is_parameter_ = true;
+    param_info_ = param_info;
+  }
+
 protected:
  // brief Data type of the tensor.
@@ -184,6 +194,9 @@ class MetaTensor : public Value {
  //
  // Includes the format and data type of a tensor on device.
  DeviceInfo device_info_;
+
+  bool is_parameter_{false};
+  ParamInfoPtr param_info_{nullptr};
};
using MetaTensorPtr = std::shared_ptr<MetaTensor>;
@@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
  }
  auto tensor_shape = tens->shape();
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
-  abs_tensor->set_value(shared_from_base<MetaTensor>());
+  // if is parameter always no value.
+  if (is_parameter()) {
+    auto param_name = param_info()->name();
+    auto ref_key = std::make_shared<RefKey>(param_name);
+    auto abs_ref_key = ref_key->ToAbstract();
+    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
+  } else {
+    abs_tensor->set_value(shared_from_base<MetaTensor>());
+  }
  return abs_tensor;
}
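Because the tensor now knows whether it backs a parameter, ToAbstract can branch locally: a parameter abstracts to an AbstractRef keyed by the parameter name and carries no value (it may be mutated), while an ordinary tensor keeps its concrete value. A Python sketch of that branch (dict-based stand-ins, a model only, not the C++ above):

from types import SimpleNamespace

def to_abstract(t):
    abs_tensor = {"kind": "Tensor", "dtype": t.dtype, "shape": t.shape}
    if t.is_parameter:
        # A parameter always abstracts to a ref keyed by its name, no value attached.
        return {"kind": "Ref", "key": t.param_info["name"], "tensor": abs_tensor}
    abs_tensor["value"] = t                        # constants keep their value
    return abs_tensor

w = SimpleNamespace(dtype="float32", shape=(2,), is_parameter=True,
                    param_info={"name": "w"})
print(to_abstract(w)["kind"])                      # Ref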
@@ -62,6 +62,21 @@ class Named : public Value {
};
using NamedPtr = std::shared_ptr<Named>;

+struct NamedHasher {
+  std::size_t operator()(NamedPtr const &name) const {
+    std::size_t hash = name->Hash();
+    return hash;
+  }
+};
+
+struct NamedEqual {
+  bool operator()(NamedPtr const &t1, NamedPtr const &t2) const {
+    MS_EXCEPTION_IF_NULL(t1);
+    MS_EXCEPTION_IF_NULL(t2);
+    return *t1 == *t2;
+  }
+};
+
class None : public Named {
 public:
  None() : Named("None") {}
@@ -21,10 +21,13 @@
#include <memory>
#include <string>
#include <vector>
-#include "ir/anf.h"
-#include "ir/tensor.h"
+#include "ir/dtype.h"

namespace mindspore {
+class ParamInfo;
+using ParamInfoPtr = std::shared_ptr<ParamInfo>;
+
class ParamInfo {
 public:
  ParamInfo() {}
@@ -55,7 +58,7 @@ class ParamInfo {
  int32_t cloned_index() const { return cloned_index_; }

  // Make a cloned parameter and update clone info.
-  ParamValuePtr Clone() {
+  ParamInfoPtr Clone() {
    static std::atomic<int32_t> parameter_cloned_index{1};
    int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
    auto clone = std::make_shared<ParamInfo>(*this);
@@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
  }
  return *this;
}
+
abstract::AbstractBasePtr Tensor::ToAbstract() {
  auto tens = shared_from_base<Tensor>();
  auto dtype = tens->Dtype();
@@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
  }
  auto tensor_shape = tens->shape();
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
-  abs_tensor->set_value(shared_from_base<Tensor>());
+  // if is parameter always no value.
+  if (is_parameter()) {
+    auto param_name = param_info()->name();
+    auto ref_key = std::make_shared<RefKey>(param_name);
+    auto abs_ref_key = ref_key->ToAbstract();
+    abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
+  } else {
+    abs_tensor->set_value(shared_from_base<Tensor>());
+  }
  return abs_tensor;
}
@@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
}
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }

-bool RefKey::operator==(const Value &other) const {
-  if (other.isa<RefKey>()) {
-    auto other_ = static_cast<const RefKey &>(other);
-    return *this == other_;
-  } else {
-    return false;
-  }
-}
-bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
-
bool AnyValue::operator==(const Value &other) const {
  if (other.isa<AnyValue>()) {
    return true;
@@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
IMM_TRAITS(StringImmPtr, std::string)
IMM_TRAITS(StringImmPtr, const char *)

-class RefKey : public Value {
+class RefKey : public Named {
 public:
-  explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
+  explicit RefKey(const std::string &tag) : Named(tag) {}
  ~RefKey() override = default;
-  MS_DECLARE_PARENT(RefKey, Value)
-  std::size_t hash() const override { return hash_; }
-  const std::string &tag() const { return tag_; }
-  bool operator==(const Value &other) const override;
-  bool operator==(const RefKey &other) const;
+  MS_DECLARE_PARENT(RefKey, Named)
+  const std::string &tag() const { return name(); }
  abstract::AbstractBasePtr ToAbstract() override;
-  std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
+  std::string ToString() const override { return "RefKey[" + name() + "]"; }
  std::string DumpText() const override {
    std::ostringstream oss;
-    oss << "RefKey[\"" << tag_ << "\"]";
+    oss << "RefKey[\"" << name() << "\"]";
    return oss.str();
  }
-
- private:
-  std::string tag_;
-  std::size_t hash_ = 0;
};
using RefKeyPtr = std::shared_ptr<RefKey>;
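Deriving RefKey from Named is what lets the hand-written operator==, hash() and tag_ members above be deleted; together with the NamedHasher/NamedEqual helpers added in named.h, ref keys can sit directly in hash containers. A Python analogue of the inherited behaviour (illustrative sketch only):

class Named:
    def __init__(self, name):
        self._name = name

    def __eq__(self, other):
        return isinstance(other, Named) and self._name == other._name

    def __hash__(self):                            # plays the NamedHasher role
        return hash(self._name)


class RefKey(Named):
    def __repr__(self):
        return 'RefKey["%s"]' % self._name


# Equality and hashing are inherited, so RefKeys work as dict keys out of the box.
table = {RefKey("w"): 1}
print(table[RefKey("w")])                          # 1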
@@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc
@@ -29,6 +29,8 @@ set(ANF_SRC
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc
@@ -23,7 +23,7 @@ from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register

-class Assign(PrimitiveWithInfer):
+class Assign(Primitive):
    """
    Assign `Parameter` with a value.
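With parameters now inferred as AbstractRef (itself an AbstractTensor), Assign no longer needs a hand-written infer rule, so it can drop PrimitiveWithInfer. Call sites are unchanged; a minimal usage sketch in the style of the tests added elsewhere in this commit (hypothetical Net, behaviour assumed unchanged across modes):

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        self.var = ms.Parameter(ms.Tensor([1.0], ms.float32), name="var")

    def construct(self, value):
        self.assign(self.var, value)   # typed via the generic tensor rules
        return self.var

net = Net()
net(ms.Tensor([5.0], ms.float32))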
@@ -18,7 +18,6 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
-from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
@@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
    if op_name == "Cast" or obj.update_parameter:
        cast_args = args
    else:
-        cast_args = list()
-        for arg in args:
-            if isinstance(arg, Parameter):
-                if arg.cast_type:
-                    cast_args.append(cast(arg, arg.cast_type))
-                else:
-                    cast_args.append(arg)
-            else:
-                cast_args.append(arg)
-    output = real_run_op(obj, op_name, tuple(cast_args))
+        cast_args = args
+        for idx, arg in enumerate(args):
+            cast_type = getattr(arg, "cast_type", None)
+            if cast_type:
+                cast_args[idx] = cast(arg, cast_type)
+    output = real_run_op(obj, op_name, cast_args)
    if not output:
        raise RuntimeError("Pynative run op %s failed!" % op_name)
    if len(output) == 1:
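The PyNative hook no longer special-cases Parameter: any argument exposing a cast_type attribute (set on parameters by the mixed-precision machinery, e.g. to_float) is cast in place. A standalone model of the new loop, with stand-ins for the real cast and real_run_op (all names here invented):

def run_op_model(op_name, args, cast, update_parameter=False):
    if op_name == "Cast" or update_parameter:
        cast_args = args
    else:
        cast_args = args                     # note: aliases args, mutated in place
        for idx, arg in enumerate(args):
            cast_type = getattr(arg, "cast_type", None)
            if cast_type:
                cast_args[idx] = cast(arg, cast_type)
    return cast_args                         # would be fed to real_run_op


class FakeParam(float):
    cast_type = "float16"                    # what to_float would record


print(run_op_model("Add", [FakeParam(1.0), 2.0], lambda a, t: (float(a), t)))
# [(1.0, 'float16'), 2.0]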
@@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
        self.var = Parameter(initializer(1, (1), mstype.float32), name="var")

    def construct(self, x, y, z, c2, c4):
-        out = self.assign(self.var, c4)
+        out = c4
+        self.assign(self.var, c4)
        while x < c2:
-            y = self.assign(self.var, c4)
+            y = c4
+            self.assign(self.var, c4)
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
                out = out + y
-            z = self.assign(self.var, c4)
+            z = c4
+            self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        while x < 2 * c2:
-            y = self.assign(self.var, c4)
+            y = c4
+            self.assign(self.var, c4)
            x = x + 1
            while y < c2:
-                z = self.assign(self.var, c4)
+                z = c4
+                self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                if x < c2:
@@ -27,6 +27,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
+from mindspore.ops import operations as P
from mindspore.common.api import ms_function, _executor
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
@@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
    net = BpropWithWrongOutputShapeCell()
    net.set_grad()
    grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
class AssignWhenInsertGrad(nn.Cell):
    """ NetWithNDarray definition """
    def __init__(self):
        super(AssignWhenInsertGrad, self).__init__()
        self.gather = P.GatherV2()
        self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
        self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
        self.freq = Tensor(278, ms.int32)
        self.getG = P.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        self.cov_step = self.cov_step + self.freq
        return dout

    def construct(self, x):
        self.gather(self.damping, self.cov_step, 0)
        out = P.ReLU()(x)
        out = self.getG(out)
        return out

grad_all = C.GradOperation('get_all', get_all=True)

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        out = self.net(*inputs)
        return out, grad_all(self.net)(*inputs)

def test_assign_in_insert_grad():
    context.set_context(mode=context.GRAPH_MODE)
    net = AssignWhenInsertGrad().to_float(ms.float16)
    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
    net_back = GradNet(net)
    net_back(ms.Tensor(input_data))

class Assign(nn.Cell):
    """ NetWithNDarray definition """
    def __init__(self):
        super(Assign, self).__init__()
        self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)

    def construct(self, x):
        self.cov_step = self.cov_step + x
        return self.cov_step

def test_assign():
    context.set_context(mode=context.GRAPH_MODE)
    net = Assign()
    input_data = ms.Tensor(np.array(1).astype(np.int32))
    net_back = GradNet(net)
    net_back(input_data)
@@ -0,0 +1,144 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
import numpy as np
import mindspore as ms
from mindspore import Tensor, context, nn, ms_function
from mindspore.nn import Cell
from mindspore.ops import operations as P
class WhileSubGraphParam(Cell):
    def __init__(self):
        super().__init__()
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        out1 = z
        while x < y:
            self.update = self.update + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update

def test_while_loop_phi():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)
    y = Tensor(10, ms.float32)
    z = Tensor(100, ms.float32)
    net = WhileSubGraphParam()
    net(x, y, z)

class WhileSubGraphParam2(Cell):
    def __init__(self):
        super().__init__()
        self.update = ms.Parameter(Tensor(1, ms.float32), "update")

    def construct(self, x, y, z):
        out1 = z
        i = self.update
        while x < y:
            i = i + 1
            out1 = out1 + 1
            x = x + 1
        return out1, self.update

def test_while_loop_phi_2():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)
    y = Tensor(10, ms.float32)
    z = Tensor(100, ms.float32)
    net = WhileSubGraphParam2()
    net(x, y, z)

class WhileSubGraphParam3(Cell):
    def __init__(self, initial_input_x):
        super().__init__()
        self.initial_input_x = initial_input_x
        self.X = ms.Parameter(initial_input_x, name="parameter_x")
        self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")

    def construct(self):
        a = 0
        while a < 3:
            self.X = self.X + self.Y
            a += 1
        return self.X

def test_while_loop_phi_3():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    x = Tensor(0, ms.float32)
    net = WhileSubGraphParam3(x)
    net()

class ControlMixedWhileIf(nn.Cell):
    def __init__(self):
        super().__init__()
        self.assign = P.Assign()
        self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")

    @ms_function
    def construct(self, x, y, z, c2, c4):
        out = self.assign(self.var, c4)
        while x < c2:
            y = self.assign(self.var, c4)
            while y < c2 and x < c2:
                if 2 * y < c2:
                    y = y + 2
                else:
                    y = y + 1
                out = out + y
            z = self.assign(self.var, c4)
            while z < c2:
                z = z + 1
            out = out + z
            x = x + 1
        out = out + x
        while x < 2 * c2:
            y = self.assign(self.var, c4)
            x = x + 1
            while y < c2:
                z = self.assign(self.var, c4)
                while z < c2:
                    z = z + 1
                if x < c2:
                    y = y - 1
                else:
                    y = y + 1
                out = out + z
            out = out + y
        out = out + x
        return out

def test_mixed_while_if():
    context.set_context(mode=context.PYNATIVE_MODE)
    x = np.array(2).astype(np.int32)
    y = np.array(14).astype(np.int32)
    z = np.array(1).astype(np.int32)
    c2 = Tensor([14], ms.int32)
    c4 = Tensor([0], ms.int32)
    net = ControlMixedWhileIf()
    output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
    expect = np.array(3318).astype(np.int32)
    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
    context.set_context(mode=context.GRAPH_MODE)
@@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
# pylint: disable=unused-argument

+@vm_impl_getters.register(P.Assign)
+def vm_impl_assign(self):
+    """Generate vm_impl function for Assign"""
+    def vm_impl(x, value):
+        x.assign_value(value)
+        return x
+    return vm_impl
+
@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):