change base class of ref to tensor in cpp

This commit is contained in:
Wei Luning 2020-08-24 15:55:26 +08:00
parent 01aa83388e
commit 24a10225cf
47 changed files with 812 additions and 720 deletions

View File

@ -185,14 +185,23 @@ class Validator:
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
@staticmethod
def check_subclass(arg_name, type_, template_type, prim_name):
def check_subclass(arg_name, type_, template_types, prim_name):
"""Checks whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
if not isinstance(template_types, Iterable):
template_types = (template_types,)
hit = False
for template_type in template_types:
if isinstance(template_type, mstype.Type):
if mstype.issubclass_(type_, template_type):
hit = True
break
elif type_ is template_type:
hit = True
break
if not hit:
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
@ -206,13 +215,7 @@ class Validator:
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
if not elem_type in valid_values:
type_names = []
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ', '.join(type_names) + ']'
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
f' but got {elem_type}.')
Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
@ -335,12 +338,6 @@ class Validator:
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
@ -360,27 +357,6 @@ class ParamValidator:
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
@ -388,33 +364,6 @@ class ParamValidator:
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
@ -442,113 +391,6 @@ class ParamValidator:
return arg_value
raise_error_msg()
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param):
"""Int type judgment."""

View File

@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
return get_single_type((*tuple_ptr)[output_idx]);
};
TypePtr type_ptr = node->Type();
if (type_ptr->isa<RefType>()) {
auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
MS_EXCEPTION_IF_NULL(ref_type_ptr);
return get_tuple_type(ref_type_ptr->subtype(), output_idx);
}
return get_tuple_type(type_ptr, output_idx);
}

View File

@ -20,6 +20,7 @@
#include "abstract/abstract_value.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"
@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
return empty;
}
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
std::size_t sig_size = signature.size();
auto positional_size = sig_size;
if (has_var) {
positional_size = sig_size - 1;
}
if (args_spec_list.size() < positional_size) {
for (size_t i = args_spec_list.size(); i < sig_size; ++i) {
if (actual_param_number < positional_size) {
for (size_t i = actual_param_number; i < sig_size; ++i) {
auto default_value = signature[i].default_value;
if (default_value == nullptr) {
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
*max_type_number = type_number;
}
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
TypeId *arg_type = nullptr) {
if (arg_value->isa<abstract::AbstractRef>()) {
auto ref = arg_value->cast<abstract::AbstractRefPtr>();
arg_value = ref->ref();
if (!is_write && ref->need_cast()) {
auto tensor_type = ref->target_type();
*arg_type_id = tensor_type->type_id();
if (arg_type != nullptr) {
*arg_type = kObjectTypeTensorType;
}
return true;
}
}
if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
auto tensor_type = tensor->element()->BuildType();
if (arg_type_origin->isa<TensorType>()) {
auto tensor = arg_type_origin->cast<TensorTypePtr>();
auto tensor_type = tensor->element();
MS_EXCEPTION_IF_NULL(tensor_type);
*arg_type_id = tensor_type->type_id();
if (arg_type != nullptr) {
@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
}
return true;
}
if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType();
if (arg_type_origin->isa<Number>()) {
auto scalar_type = arg_type_origin->cast<NumberPtr>();
MS_EXCEPTION_IF_NULL(scalar_type);
*arg_type_id = scalar_type->type_id();
if (arg_type != nullptr) {
@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
return false;
}
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices,
const std::set<size_t> &write_indices) {
TypeId max_type_id = kTypeUnknown;
size_t max_type_number = 0;
@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId arg_type_id = kTypeUnknown;
TypeId arg_type = kTypeUnknown;
auto is_write = (write_indices.find(index) != write_indices.end());
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
continue;
}
if (arg_type != kObjectTypeTensorType) {
@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) {
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
const std::set<size_t> &write_indices) {
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
}
bool has_tensor = false;
for (const auto &index : indices) {
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
if (arg_value->isa<abstract::AbstractTensor>()) {
auto arg_value = input_types[index];
if (arg_value->isa<TensorType>()) {
has_tensor = true;
break;
}
@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
continue;
}
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices)));
}
return dst_type;
}
@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
}
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
// Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) {
for (size_t i = 0; i < input_types.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) {
continue;
@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
auto is_write = (rw_it != write_indices.end());
TypeId arg_type_id = kTypeUnknown;
AbstractBasePtr arg_value = args_spec_list[i];
auto arg_value = input_types[i];
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
auto it_map = type_name_map.find(arg_type_id);
if (it_map == type_name_map.end()) {
@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
}
continue;
}
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
continue;
}
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
std::vector<AnfNodePtr> op_inputs;
std::set<size_t> write_indices;
std::vector<TypePtr> input_types;
op_inputs.push_back(NewValueNode(function));
// Assume, the write input of op is always the first input. We check if any write op,
// and add cast op on other inputs to keep the same type with assigned parameter.
@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
sig = signature[sig_size - 1].rw;
}
TypePtr type = args_spec_list[i]->GetTypeTrack();
if (type && type->type_id() == kObjectTypeRef) {
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
TypePtr type = args_spec_list[i]->BuildType();
if (type && type->isa<RefType>()) {
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
if (sig == SignatureEnumRW::kRWRead) {
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
if (ref_abs && ref_abs->need_cast()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
auto source_tensor_type = type->cast<TensorTypePtr>();
if (source_tensor_type != nullptr) {
auto source_element = source_tensor_type->element();
if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph);
type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
}
}
} else if (sig == SignatureEnumRW::kRWWrite) {
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
write_indices.insert(i);
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
} else if (sig == SignatureEnumRW::kRWWrite &&
!((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
<< type->ToString();
}
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
<< args_spec_list[i]->ToString();
input_types.push_back(type);
op_inputs.push_back(param);
}
// process default
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
return func_graph->NewCNode(op_inputs);
}
} // namespace

View File

@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
}
Register(types_name, py_fn);
}
static TypePtr UnwrapRef(const TypePtr &type) {
if (type->isa<RefType>()) {
return type->cast<RefTypePtr>()->subtype();
}
return type;
}
// Return Exact match if exists, else return non ambiguous sub class match
// Return py::none() if matching is ambiguous
@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
}
auto match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
if (!IsIdentidityOrSubclass(types[i], sign[i])) {
match = false;
break;
}

View File

@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor
CheckArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
return args_spec_list[0];
}
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs);
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
} // namespace abstract
} // namespace mindspore

View File

@ -20,6 +20,7 @@
#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/meta_tensor.h"
#include "pipeline/jit/parse/python_adapter.h"
namespace mindspore {
@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if (!para_ptr->has_default()) {
return false;
}
auto obj = py::cast(para_ptr->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
auto param_value = para_ptr->param_info();
if (param_value == nullptr) {
return false;
}

View File

@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
if (!cloned_parameter->has_default()) {
return false;
}
auto obj = py::cast(cloned_parameter->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
auto param_value = cloned_parameter->param_info();
if (param_value == nullptr) {
return false;
}
@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (!ParameterIsCloned(cloned_parameter_node)) {
continue;
}
auto obj = py::cast(cloned_parameter->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
auto param_value = cloned_parameter->param_info();
if (param_value == nullptr) {
continue;
}
@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
continue;
}
const auto &param_value_cloned = be_cloned_parameter->default_param();
auto obj_in = py::cast(param_value_cloned);
auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
auto param_value_in = be_cloned_parameter->param_info();
if (param_value_in == nullptr) {
continue;
}

View File

@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
ValuePtr value = param_node->default_param();
constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
auto value = param_node->default_param();
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
args_spec.push_back(abs_ref);
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
}
}
// Analyze

View File

@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
converted = env;
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
} else if (py::hasattr(obj, "__parameter__")) {
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ret = ConvertData(to_convert, &converted);
} else {
ret = ConvertOtherObj(obj, &converted);
}
@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
ValuePtr PyDataToValue(const py::object &obj) {
py::object to_convert = obj;
if (py::hasattr(obj, "__parameter__")) {
to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
}
ValuePtr value = nullptr;
(void)ConvertData(to_convert, &value);
return value;

View File

@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
}
void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
state_assign_[target] = readid;
const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
auto source = ReadVariable(readid);
auto assign = func_graph()->NewCNode({assign_op, target, source});
WriteVariable(readid, assign);
MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
AddAutoDepend(assign);
}
void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
if (auto_depends_.size() == 0) {
return;
}
AnfNodePtr state = nullptr;
std::vector<AnfNodePtr> vec_states;
vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) {
auto source = ReadVariable(item.second);
auto assign = func_graph()->NewCNode({assign_op, item.first, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign);
}
for (auto &item : auto_depends_) {
MS_LOG(DEBUG) << "auto_depends " << item->ToString();
vec_states.emplace_back(item);
@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
func_graph()->set_output(ret, true);
state_assign_.clear();
}
} // namespace parse
} // namespace mindspore

View File

@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// keeps all removable phis which will be removed in one pass.
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
// set state nodes need to insert before function return nodes.
OrderedMap<AnfNodePtr, std::string> state_assign_;
// hold declared global variables in function
std::set<std::string> global_vars_;

View File

@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
return func_graph;
}
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
TypePtr dst_type;
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
return kFloat32;
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
return kFloat16;
} else {
return kNone;
return nullptr;
}
}

View File

@ -359,7 +359,7 @@ class ParseAst {
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
} // namespace parse
} // namespace mindspore

View File

@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
auto value = py::cast<tensor::MetaTensorPtr>(obj);
node->set_default_param(value);
// set_abstract for parameter
constexpr bool broaden = true;
node->set_abstract(abstract::FromValue(value, broaden));
auto abs = value->ToAbstract();
node->set_abstract(abs);
para_node = node;
}
auto iter = func_graph->make_ref_params().find(para_node);
if (iter == func_graph->make_ref_params().end()) {
ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
func_graph->make_ref_params()[para_node] = ref_node;
func_graph->add_parameter_obj_node(ref_node);
return ref_node;
} else {
return iter->second;
}
return para_node;
}
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {

View File

@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
size_t size = op_exec_info->op_inputs.size();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
bool op_mask = py::hasattr(obj, "__parameter__");
bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
(*op_masks).push_back(op_mask);
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
<< grad_flag_;
@ -988,8 +995,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);
free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
free_param->debug_info()->set_name(param_name);
auto value = py::cast<tensor::TensorPtr>(obj);
free_param->set_default_param(value);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
return free_param;
@ -1157,17 +1165,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
auto param_name = py::cast<std::string>(name_attr);
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);
free_param->set_default_param(py::cast<tensor::TensorPtr>(param));
auto value = py::cast<tensor::TensorPtr>(param);
free_param->set_default_param(value);
free_param->debug_info()->set_name(param_name);
para_node = free_param;
}
ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
AnfNodePtr ref_key_node = NewValueNode(refkey);
AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
w_args.push_back(ref_node);
w_args.push_back(para_node);
}
} else {
MS_LOG(DEBUG) << "training not paramter_tuple";
@ -1195,7 +1198,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
ValuePtr value = param_node->default_param();
AbstractBasePtr ptr = abstract::FromValue(value, true);
auto ptr = value->ToAbstract();
if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Args convert error";
}

View File

@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());

View File

@ -21,7 +21,7 @@ namespace mindspore {
namespace py = pybind11;
REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
(void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
(void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
.def(py::init())
.def("clone", &ParamInfo::Clone)
.def_property("name", &ParamInfo::name, &ParamInfo::set_name)
@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
if (t.size() != 6) {
std::runtime_error("Invalid state for ParamInfo!");
}
ParamValuePtr p = std::make_shared<ParamInfo>();
ParamInfoPtr p = std::make_shared<ParamInfo>();
p->set_name(t[1].cast<std::string>());
p->set_requires_grad(t[2].cast<bool>());
p->set_layerwise_parallel(t[3].cast<bool>());

View File

@ -213,6 +213,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */

View File

@ -42,7 +42,7 @@ class Parameter(MetaTensor):
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
Note:
Each parameter of Cell is represented by Parameter class.
@ -108,7 +108,7 @@ class Parameter(MetaTensor):
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
self._value = ParamInfo()
self._param_info = ParamInfo()
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
@ -156,13 +156,13 @@ class Parameter(MetaTensor):
value_str = MetaTensor.__str__(self)
if isinstance(self, Tensor):
value_str = Tensor.__str__(self)
return f'Parameter (name={self._value.name}, value={value_str})'
return f'Parameter (name={self._param_info.name}, value={value_str})'
def __repr__(self):
value_str = MetaTensor.__repr__(self)
if isinstance(self, Tensor):
value_str = Tensor.__repr__(self)
return f'Parameter (name={self._value.name}, value={value_str})'
return f'Parameter (name={self._param_info.name}, value={value_str})'
def __parameter__(self):
"""For parse check."""
@ -181,7 +181,7 @@ class Parameter(MetaTensor):
@property
def name(self):
"""Get the name of the parameter."""
return self._value.name
return self._param_info.name
@name.setter
def name(self, name_):
@ -203,7 +203,7 @@ class Parameter(MetaTensor):
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
else:
raise ValueError("The type of the name should be `str` or `None`.")
self._value.name = name_
self._param_info.name = name_
@property
def cast_type(self):
@ -254,8 +254,8 @@ class Parameter(MetaTensor):
_check_str_by_regular(prefix)
x = copy(self)
# pylint: disable=protected-access
x._value = self._value.clone()
x._value.name = prefix + '.' + self._value.name
x._param_info = self._param_info.clone()
x._param_info.name = prefix + '.' + self._param_info.name
x.is_init = False
if init != 'same':
shape = self.shape
@ -265,24 +265,24 @@ class Parameter(MetaTensor):
@property
def layerwise_parallel(self):
return self._value.layerwise_parallel
return self._param_info.layerwise_parallel
@layerwise_parallel.setter
def layerwise_parallel(self, value=True):
if not isinstance(value, bool):
raise TypeError("`layerwise_parallel` parameter must be bool type")
self._value.layerwise_parallel = value
self._param_info.layerwise_parallel = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""
return self._value.requires_grad
return self._param_info.requires_grad
@requires_grad.setter
def requires_grad(self, value=True):
if not isinstance(value, bool):
raise TypeError("`requires_grad` parameter must be bool type")
self._value.requires_grad = value
self._param_info.requires_grad = value
@property
def data(self):

View File

@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
auto other_tensor = dyn_cast<AbstractTensor>(other);
if (other_tensor == nullptr) {
auto ref_tensor = dyn_cast<AbstractRef>(other);
if (ref_tensor != nullptr) {
return this->Join(ref_tensor->ref());
}
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {
@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
return std::make_shared<AbstractTensor>(element, shape);
}
bool AbstractTensor::operator==(const AbstractTensor &other) const {
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
if (&other == this) {
return true;
}
@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
}
bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
bool AbstractTensor::operator==(const AbstractBase &other) const {
if (&other == this) {
return true;
}
if (other.isa<AbstractTensor>()) {
if (other.tid() == tid()) {
auto other_tensor = static_cast<const AbstractTensor *>(&other);
return *this == *other_tensor;
} else {
@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
return buffer.str();
}
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
TypePtr cast_target)
: ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
: AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
set_type(std::make_shared<RefType>());
auto origin_type = ref_value->BuildType();
if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
if (cast_target != tensor_dtype) {
need_cast_ = true;
target_type_ = cast_target;
}
}
}
if (ref_key && ref_key->isa<AbstractRefKey>()) {
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
}
}
BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
TypePtr AbstractRef::BuildType() const {
TypePtr subtype = ref_->BuildType();
TypePtr subtype_origin = subtype;
if (need_cast_) {
subtype_origin = std::make_shared<TensorType>(target_type_);
}
return std::make_shared<RefType>(subtype, subtype_origin);
auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>();
return std::make_shared<RefType>(subtype);
}
bool AbstractRef::operator==(const AbstractRef &other) const {
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
(!need_cast_ || (*target_type_ == *other.target_type_));
return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
}
bool AbstractRef::operator==(const AbstractBase &other) const {
@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) {
auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref);
return AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
}
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
return shared_from_base<AbstractBase>();
}
auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = ref_->Join(other_ref->ref());
auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>();
return std::make_shared<AbstractRef>(ref_key, ref);
}
std::string AbstractRef::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
if (need_cast_) {
buffer << " cast to: " << target_type_->ToString();
}
<< "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
auto value = GetValueTrack();
if (value) {
buffer << ", value: " << value->ToString();

View File

@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other) final;
AbstractBasePtr Join(const AbstractBasePtr &other);
bool operator==(const AbstractTensor &other) const;
bool operator==(const AbstractBase &other) const override;
std::string ToString() const override;
std::size_t hash() const override {
auto value = GetValueTrack();
@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
}
return hash_sum;
}
protected:
bool equal_to(const AbstractTensor &other) const;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
};
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
class AbstractRef : public AbstractBase {
class AbstractRef : public AbstractTensor {
public:
AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
TypePtr cast_target = nullptr);
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
~AbstractRef() override = default;
MS_DECLARE_PARENT(AbstractRef, AbstractBase)
MS_DECLARE_PARENT(AbstractRef, AbstractTensor)
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
bool operator==(const AbstractRef &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override {
return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
if (abs_tensor == nullptr) {
return nullptr;
}
return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
}
std::string ToString() const override;
inline AbstractBasePtr ref() const { return ref_; }
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
inline AbstractBasePtr ref_key() const { return ref_key_; }
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline TypePtr target_type() const { return target_type_; }
inline bool need_cast() const { return need_cast_; }
AbstractBasePtr Broaden(uint8_t config = 0) const override {
// always broaden for ref
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
if (abs_tensor == nullptr) {
return nullptr;
}
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
}
AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override {
return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
}
private:
AbstractBasePtr ref_key_;
AbstractBasePtr ref_;
// For mix presicion, only float type need to cast to float16 of float32
bool need_cast_;
TypePtr target_type_;
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_;
};

View File

@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
<< ".";
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
}
auto need_cast = !tensor_target_v->isa<None>();
if (need_cast && !tensor_target_v->isa<Type>()) {
MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
}
TypePtr cast_target = tensor_target_v->cast<TypePtr>();
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
}
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,

View File

@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
return buffer.str();
}
ParamInfoPtr Parameter::param_info() const {
if (!has_default()) {
return nullptr;
}
auto tensor = default_param()->cast<tensor::MetaTensorPtr>();
if (tensor == nullptr || !tensor->is_parameter()) {
return nullptr;
}
return tensor->param_info();
}
std::string ValueNode::ToString() const {
MS_EXCEPTION_IF_NULL(value_);
if (value_->isa<FuncGraph>()) {

View File

@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
class AnfIrVisitor;
class ParamInfo;
using ParamValuePtr = std::shared_ptr<ParamInfo>;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
// AnfNode is the basic class of the IR definition derived from Base.
// Only two types of nodes are derived: CNode and ANode.
@ -288,6 +288,7 @@ class Parameter : public ANode {
has_default_ = true;
}
ValuePtr default_param() const { return default_param_; }
ParamInfoPtr param_info() const;
bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {

View File

@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
std::string Slice::DumpText() const { return ToString(); }
TypePtr UndeterminedType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<UndeterminedType>();
}
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}
std::string UndeterminedType::ToReprString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToReprString() + "]";
}
std::string UndeterminedType::ToString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToString() + "]";
}
std::string UndeterminedType::DumpText() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->DumpText() + "]";
}
bool UndeterminedType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<TensorType>();
}
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
std::string TensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "tensor";
}
return "tensor[" + element_type_->ToReprString() + "]";
}
std::string TensorType::ToString() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor[" + element_type_->ToString() + "]";
}
std::string TensorType::DumpText() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor(" + element_type_->DumpText() + ")";
}
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr RowTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<RowTensorType>();
}
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
}
std::string RowTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToReprString() + "]";
}
std::string RowTensorType::ToString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToString() + "]";
}
std::string RowTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->DumpText() + "]";
}
bool RowTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
Function::Function() : Object(kObjectTypeFunction) {
args_ = std::vector<TypePtr>();
retval_ = nullptr;
@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
os << problem->ToString();
return os;
}
const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
} // namespace mindspore

View File

@ -32,10 +32,11 @@
#include "ir/named.h"
#include "ir/dtype/type.h"
#include "ir/dtype/ref.h"
#include "ir/dtype/number.h"
#include "ir/dtype/container.h"
#include "ir/dtype/empty.h"
#include "ir/dtype/tensor_type.h"
#include "ir/dtype/ref.h"
/* namespace to support intermediate representation definition */
namespace mindspore {
@ -108,98 +109,6 @@ class Slice : public Object {
};
using SlicePtr = std::shared_ptr<Slice>;
class UndeterminedType : public Object {
public:
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
explicit UndeterminedType(const TypePtr &ele)
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType, Object)
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
protected:
TypePtr element_type_;
};
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
explicit TensorType(const TypePtr &ele)
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using TensorTypePtr = std::shared_ptr<TensorType>;
class RowTensorType : public Object {
public:
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
explicit RowTensorType(const TypePtr &ele)
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~RowTensorType() override = default;
MS_DECLARE_PARENT(RowTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
class Function : public Object {
public:
Function();
@ -353,6 +262,9 @@ extern const TypePtr kDict;
extern const TypePtr kSlice;
extern const TypePtr kKeyword;
extern const TypePtr kTensorType;
extern const TypePtr kTensorTypeFP16;
extern const TypePtr kTensorTypeFP32;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_H_

View File

@ -68,6 +68,8 @@ class Number : public Object {
const int nbits_;
};
using NumberPtr = std::shared_ptr<Number>;
// Bool
class Bool : public Number {
public:

View File

@ -19,15 +19,15 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "ir/dtype/tensor_type.h"
namespace mindspore {
TypePtr RefType::DeepCopy() const {
if (IsGeneric()) {
return std::make_shared<RefType>();
} else {
auto subtype = subtype_->DeepCopy();
auto subtype_origin = subtype_origin_->DeepCopy();
return std::make_shared<RefType>(subtype, subtype_origin);
auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>();
return std::make_shared<RefType>(subtype);
}
}
@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
buffer << "Ref";
} else {
buffer << "Ref[";
buffer << subtype_->DumpText() << "]";
buffer << TensorType::DumpText() << "]";
}
return buffer.str();
}

View File

@ -17,21 +17,13 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
#define MINDSPORE_CORE_IR_DTYPE_REF_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>
#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
#include "ir/dtype/tensor_type.h"
namespace mindspore {
// TypeRefKey type
@ -48,23 +40,16 @@ class RefKeyType : public Object {
};
// TypeRef type
class RefType : public Object {
class RefType : public TensorType {
public:
RefType() : Object(kObjectTypeRef) {}
RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
RefType() : TensorType() {}
explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}
~RefType() override {}
MS_DECLARE_PARENT(RefType, Object)
MS_DECLARE_PARENT(RefType, TensorType)
TypePtr subtype() const { return subtype_; }
TypeId generic_type_id() const override { return kObjectTypeRef; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
private:
TypePtr subtype_;
TypePtr subtype_origin_;
};
using RefTypePtr = std::shared_ptr<RefType>;

View File

@ -0,0 +1,194 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype/tensor_type.h"
#include <string>
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
TypePtr UndeterminedType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<UndeterminedType>();
}
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}
std::string UndeterminedType::ToReprString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToReprString() + "]";
}
std::string UndeterminedType::ToString() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->ToString() + "]";
}
std::string UndeterminedType::DumpText() const {
if (element_type_ == nullptr) {
return "Undetermined";
}
return "Undetermined[" + element_type_->DumpText() + "]";
}
bool UndeterminedType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<TensorType>();
}
return std::make_shared<TensorType>(element_type_->DeepCopy());
}
std::string TensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "tensor";
}
return "tensor[" + element_type_->ToReprString() + "]";
}
std::string TensorType::ToString() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor[" + element_type_->ToString() + "]";
}
std::string TensorType::DumpText() const {
if (element_type_ == nullptr) {
return "Tensor";
}
return "Tensor(" + element_type_->DumpText() + ")";
}
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr RowTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<RowTensorType>();
}
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
}
std::string RowTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToReprString() + "]";
}
std::string RowTensorType::ToString() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->ToString() + "]";
}
std::string RowTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "RowTensor";
}
return "RowTensor[" + element_type_->DumpText() + "]";
}
bool RowTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
} // namespace mindspore

View File

@ -0,0 +1,132 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <map>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <type_traits>
#include <unordered_map>
#include <algorithm>
#include "base/base.h"
#include "ir/named.h"
#include "ir/dtype/type.h"
namespace mindspore {
class UndeterminedType : public Object {
public:
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
explicit UndeterminedType(const TypePtr &ele)
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType, Object)
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
protected:
TypePtr element_type_;
};
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
explicit TensorType(const TypePtr &ele)
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using TensorTypePtr = std::shared_ptr<TensorType>;
class RowTensorType : public Object {
public:
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
explicit RowTensorType(const TypePtr &ele)
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~RowTensorType() override = default;
MS_DECLARE_PARENT(RowTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_

View File

@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
void add_parameter_obj_node(const AnfNodePtr &p);
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
std::unordered_map<std::string, ValuePtr> attrs_;
std::vector<BaseShapePtr> joined_shapes_;
std::unordered_map<std::string, FuncGraphTransform> transforms_;
// parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;
std::list<CNodePtr> GetOrderedCnodes();

View File

@ -23,6 +23,7 @@
#include <string>
#include "base/base.h"
#include "ir/param_info.h"
#include "ir/dtype.h"
#include "utils/convert_utils_base.h"
#include "utils/hashing.h"
@ -163,6 +164,15 @@ class MetaTensor : public Value {
return false;
}
}
// Get tensor's param_info info.
ParamInfoPtr param_info() const { return param_info_; }
bool is_parameter() const { return is_parameter_; }
// Set tensor's param_info info.
void set_param_info(const ParamInfoPtr &param_info) {
is_parameter_ = true;
param_info_ = param_info;
}
protected:
// brief Data type of the tensor.
@ -184,6 +194,9 @@ class MetaTensor : public Value {
//
// Includes the format and data type of a tensor on device.
DeviceInfo device_info_;
bool is_parameter_{false};
ParamInfoPtr param_info_{nullptr};
};
using MetaTensorPtr = std::shared_ptr<MetaTensor>;

View File

@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
}
auto tensor_shape = tens->shape();
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
abs_tensor->set_value(shared_from_base<MetaTensor>());
// if is parameter always no value.
if (is_parameter()) {
auto param_name = param_info()->name();
auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
} else {
abs_tensor->set_value(shared_from_base<MetaTensor>());
}
return abs_tensor;
}

View File

@ -62,6 +62,21 @@ class Named : public Value {
};
using NamedPtr = std::shared_ptr<Named>;
struct NamedHasher {
std::size_t operator()(NamedPtr const &name) const {
std::size_t hash = name->Hash();
return hash;
}
};
struct NamedEqual {
bool operator()(NamedPtr const &t1, NamedPtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return *t1 == *t2;
}
};
class None : public Named {
public:
None() : Named("None") {}

View File

@ -21,10 +21,13 @@
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/tensor.h"
#include "ir/dtype.h"
namespace mindspore {
class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
class ParamInfo {
public:
ParamInfo() {}
@ -55,7 +58,7 @@ class ParamInfo {
int32_t cloned_index() const { return cloned_index_; }
// Make a cloned parameter and update clone info.
ParamValuePtr Clone() {
ParamInfoPtr Clone() {
static std::atomic<int32_t> parameter_cloned_index{1};
int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
auto clone = std::make_shared<ParamInfo>(*this);

View File

@ -461,6 +461,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
}
return *this;
}
abstract::AbstractBasePtr Tensor::ToAbstract() {
auto tens = shared_from_base<Tensor>();
auto dtype = tens->Dtype();
@ -469,7 +470,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
}
auto tensor_shape = tens->shape();
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
abs_tensor->set_value(shared_from_base<Tensor>());
// if is parameter always no value.
if (is_parameter()) {
auto param_name = param_info()->name();
auto ref_key = std::make_shared<RefKey>(param_name);
auto abs_ref_key = ref_key->ToAbstract();
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
} else {
abs_tensor->set_value(shared_from_base<Tensor>());
}
return abs_tensor;
}

View File

@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
}
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
bool RefKey::operator==(const Value &other) const {
if (other.isa<RefKey>()) {
auto other_ = static_cast<const RefKey &>(other);
return *this == other_;
} else {
return false;
}
}
bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
bool AnyValue::operator==(const Value &other) const {
if (other.isa<AnyValue>()) {
return true;

View File

@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
IMM_TRAITS(StringImmPtr, std::string)
IMM_TRAITS(StringImmPtr, const char *)
class RefKey : public Value {
class RefKey : public Named {
public:
explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
explicit RefKey(const std::string &tag) : Named(tag) {}
~RefKey() override = default;
MS_DECLARE_PARENT(RefKey, Value)
std::size_t hash() const override { return hash_; }
const std::string &tag() const { return tag_; }
bool operator==(const Value &other) const override;
bool operator==(const RefKey &other) const;
MS_DECLARE_PARENT(RefKey, Named)
const std::string &tag() const { return name(); }
abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
std::string ToString() const override { return "RefKey[" + name() + "]"; }
std::string DumpText() const override {
std::ostringstream oss;
oss << "RefKey[\"" << tag_ << "\"]";
oss << "RefKey[\"" << name() << "\"]";
return oss.str();
}
private:
std::string tag_;
std::size_t hash_ = 0;
};
using RefKeyPtr = std::shared_ptr<RefKey>;

View File

@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc

View File

@ -29,6 +29,8 @@ set(ANF_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc

View File

@ -23,7 +23,7 @@ from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
class Assign(PrimitiveWithInfer):
class Assign(Primitive):
"""
Assign `Parameter` with a value.

View File

@ -18,7 +18,6 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = list()
for arg in args:
if isinstance(arg, Parameter):
if arg.cast_type:
cast_args.append(cast(arg, arg.cast_type))
else:
cast_args.append(arg)
else:
cast_args.append(arg)
output = real_run_op(obj, op_name, tuple(cast_args))
cast_args = args
for idx, arg in enumerate(args):
cast_type = getattr(arg, "cast_type", None)
if cast_type:
cast_args[idx] = cast(arg, cast_type)
output = real_run_op(obj, op_name, cast_args)
if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1:

View File

@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
def construct(self, x, y, z, c2, c4):
out = self.assign(self.var, c4)
out = c4
self.assign(self.var, c4)
while x < c2:
y = self.assign(self.var, c4)
y = c4
self.assign(self.var, c4)
while y < c2 and x < c2:
if 2 * y < c2:
y = y + 2
else:
y = y + 1
out = out + y
z = self.assign(self.var, c4)
z = c4
self.assign(self.var, c4)
while z < c2:
z = z + 1
out = out + z
x = x + 1
out = out + x
while x < 2 * c2:
y = self.assign(self.var, c4)
y = c4
self.assign(self.var, c4)
x = x + 1
while y < c2:
z = self.assign(self.var, c4)
z = c4
self.assign(self.var, c4)
while z < c2:
z = z + 1
if x < c2:

View File

@ -27,6 +27,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.api import ms_function, _executor
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
net = BpropWithWrongOutputShapeCell()
net.set_grad()
grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
class AssignWhenInsertGrad(nn.Cell):
""" NetWithNDarray definition """
def __init__(self):
super(AssignWhenInsertGrad, self).__init__()
self.gather = P.GatherV2()
self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
self.freq = Tensor(278, ms.int32)
self.getG = P.InsertGradientOf(self.save_gradient)
def save_gradient(self, dout):
self.cov_step = self.cov_step + self.freq
return dout
def construct(self, x):
self.gather(self.damping, self.cov_step, 0)
out = P.ReLU()(x)
out = self.getG(out)
return out
grad_all = C.GradOperation('get_all', get_all=True)
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
out = self.net(*inputs)
return out, grad_all(self.net)(*inputs)
def test_assign_in_insert_grad():
context.set_context(mode=context.GRAPH_MODE)
net = AssignWhenInsertGrad().to_float(ms.float16)
input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
net_back = GradNet(net)
net_back(ms.Tensor(input_data))
class Assign(nn.Cell):
""" NetWithNDarray definition """
def __init__(self):
super(Assign, self).__init__()
self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)
def construct(self, x):
self.cov_step = self.cov_step + x
return self.cov_step
def test_assign():
context.set_context(mode=context.GRAPH_MODE)
net = Assign()
input_data = ms.Tensor(np.array(1).astype(np.int32))
net_back = GradNet(net)
net_back(input_data)

View File

@ -0,0 +1,144 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
import numpy as np
import mindspore as ms
from mindspore import Tensor, context, nn, ms_function
from mindspore.nn import Cell
from mindspore.ops import operations as P
class WhileSubGraphParam(Cell):
def __init__(self):
super().__init__()
self.update = ms.Parameter(Tensor(1, ms.float32), "update")
def construct(self, x, y, z):
out1 = z
while x < y:
self.update = self.update + 1
out1 = out1 + 1
x = x + 1
return out1, self.update
def test_while_loop_phi():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(0, ms.float32)
y = Tensor(10, ms.float32)
z = Tensor(100, ms.float32)
net = WhileSubGraphParam()
net(x, y, z)
class WhileSubGraphParam2(Cell):
def __init__(self):
super().__init__()
self.update = ms.Parameter(Tensor(1, ms.float32), "update")
def construct(self, x, y, z):
out1 = z
i = self.update
while x < y:
i = i + 1
out1 = out1 + 1
x = x + 1
return out1, self.update
def test_while_loop_phi_2():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(0, ms.float32)
y = Tensor(10, ms.float32)
z = Tensor(100, ms.float32)
net = WhileSubGraphParam2()
net(x, y, z)
class WhileSubGraphParam3(Cell):
def __init__(self, initial_input_x):
super().__init__()
self.initial_input_x = initial_input_x
self.X = ms.Parameter(initial_input_x, name="parameter_x")
self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")
def construct(self):
a = 0
while a < 3:
self.X = self.X + self.Y
a += 1
return self.X
def test_while_loop_phi_3():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(0, ms.float32)
net = WhileSubGraphParam3(x)
net()
class ControlMixedWhileIf(nn.Cell):
def __init__(self):
super().__init__()
self.assign = P.Assign()
self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")
@ms_function
def construct(self, x, y, z, c2, c4):
out = self.assign(self.var, c4)
while x < c2:
y = self.assign(self.var, c4)
while y < c2 and x < c2:
if 2 * y < c2:
y = y + 2
else:
y = y + 1
out = out + y
z = self.assign(self.var, c4)
while z < c2:
z = z + 1
out = out + z
x = x + 1
out = out + x
while x < 2 * c2:
y = self.assign(self.var, c4)
x = x + 1
while y < c2:
z = self.assign(self.var, c4)
while z < c2:
z = z + 1
if x < c2:
y = y - 1
else:
y = y + 1
out = out + z
out = out + y
out = out + x
return out
def test_mixed_while_if():
context.set_context(mode=context.PYNATIVE_MODE)
x = np.array(2).astype(np.int32)
y = np.array(14).astype(np.int32)
z = np.array(1).astype(np.int32)
c2 = Tensor([14], ms.int32)
c4 = Tensor([0], ms.int32)
net = ControlMixedWhileIf()
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
expect = np.array(3318).astype(np.int32)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
context.set_context(mode=context.GRAPH_MODE)

View File

@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
# pylint: disable=unused-argument
@vm_impl_getters.register(P.Assign)
def vm_impl_assign(self):
"""Generate vm_impl function for Assign"""
def vm_impl(x, value):
x.assign_value(value)
return x
return vm_impl
@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):