!3271 make reftype a subtype of MetaTensor and try to mark ref in node input
Merge pull request !3271 from vlne-v1/ref_demo
This commit is contained in:
commit
95212b55a0
|
@ -185,14 +185,23 @@ class Validator:
|
||||||
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_subclass(arg_name, type_, template_type, prim_name):
|
def check_subclass(arg_name, type_, template_types, prim_name):
|
||||||
"""Checks whether some type is subclass of another type"""
|
"""Checks whether some type is subclass of another type"""
|
||||||
if not isinstance(template_type, Iterable):
|
if not isinstance(template_types, Iterable):
|
||||||
template_type = (template_type,)
|
template_types = (template_types,)
|
||||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
hit = False
|
||||||
|
for template_type in template_types:
|
||||||
|
if isinstance(template_type, mstype.Type):
|
||||||
|
if mstype.issubclass_(type_, template_type):
|
||||||
|
hit = True
|
||||||
|
break
|
||||||
|
elif type_ is template_type:
|
||||||
|
hit = True
|
||||||
|
break
|
||||||
|
if not hit:
|
||||||
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
||||||
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
||||||
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_const_input(arg_name, arg_value, prim_name):
|
def check_const_input(arg_name, arg_value, prim_name):
|
||||||
|
@ -206,13 +215,7 @@ class Validator:
|
||||||
def _check_tensor_type(arg):
|
def _check_tensor_type(arg):
|
||||||
arg_key, arg_val = arg
|
arg_key, arg_val = arg
|
||||||
elem_type = arg_val
|
elem_type = arg_val
|
||||||
if not elem_type in valid_values:
|
Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
|
||||||
type_names = []
|
|
||||||
for t in valid_values:
|
|
||||||
type_names.append(str(t))
|
|
||||||
types_info = '[' + ', '.join(type_names) + ']'
|
|
||||||
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
|
|
||||||
f' but got {elem_type}.')
|
|
||||||
return (arg_key, elem_type)
|
return (arg_key, elem_type)
|
||||||
|
|
||||||
def _check_types_same(arg1, arg2):
|
def _check_types_same(arg1, arg2):
|
||||||
|
@ -335,12 +338,6 @@ class Validator:
|
||||||
class ParamValidator:
|
class ParamValidator:
|
||||||
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def equal(arg_name, arg_value, cond_str, cond):
|
|
||||||
"""Judging valid value."""
|
|
||||||
if not cond:
|
|
||||||
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
|
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
|
||||||
"""This method is only used for check int values, since when compare float values,
|
"""This method is only used for check int values, since when compare float values,
|
||||||
|
@ -360,27 +357,6 @@ class ParamValidator:
|
||||||
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_shape_length(arg_name, arg_value, value, rel):
|
|
||||||
"""Shape length judgment."""
|
|
||||||
rel_fn = Rel.get_fns(rel)
|
|
||||||
type_mismatch = not isinstance(arg_value, int)
|
|
||||||
if type_mismatch or not rel_fn(arg_value, value):
|
|
||||||
rel_str = Rel.get_strs(rel).format(value)
|
|
||||||
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
|
|
||||||
return arg_value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
|
|
||||||
"""This method is only used for check int values,
|
|
||||||
since when compare float values, we need consider float error."""
|
|
||||||
rel_fn = Rel.get_fns(rel)
|
|
||||||
type_mismatch = not isinstance(arg_value, int)
|
|
||||||
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
|
||||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
||||||
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
|
|
||||||
return arg_value
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_isinstance(arg_name, arg_value, classes):
|
def check_isinstance(arg_name, arg_value, classes):
|
||||||
"""Check arg isinstance of classes"""
|
"""Check arg isinstance of classes"""
|
||||||
|
@ -388,33 +364,6 @@ class ParamValidator:
|
||||||
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
|
|
||||||
"""Is it necessary to consider error when comparing float values."""
|
|
||||||
rel_fn = Rel.get_fns(rel)
|
|
||||||
if not rel_fn(arg_value, lower_limit, upper_limit):
|
|
||||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
|
||||||
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
|
|
||||||
return arg_value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_subclass(arg_name, type_, template_type, with_type_of=True):
|
|
||||||
"""Check whether some type is subclass of another type"""
|
|
||||||
if not isinstance(template_type, Iterable):
|
|
||||||
template_type = (template_type,)
|
|
||||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
|
||||||
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
|
||||||
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
|
|
||||||
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_args_tensor(args):
|
|
||||||
"""Check whether args are all tensor."""
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
raise TypeError("The args should be a dict.")
|
|
||||||
for arg, value in args.items():
|
|
||||||
ParamValidator.check_subclass(arg, value, mstype.tensor)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_bool(arg_name, arg_value):
|
def check_bool(arg_name, arg_value):
|
||||||
"""Check arg isinstance of bool"""
|
"""Check arg isinstance of bool"""
|
||||||
|
@ -442,113 +391,6 @@ class ParamValidator:
|
||||||
return arg_value
|
return arg_value
|
||||||
raise_error_msg()
|
raise_error_msg()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_typename(arg_name, arg_type, valid_types):
|
|
||||||
"""Does it contain the _name_ attribute."""
|
|
||||||
|
|
||||||
def get_typename(t):
|
|
||||||
return t.__name__ if hasattr(t, '__name__') else str(t)
|
|
||||||
|
|
||||||
if isinstance(arg_type, type(mstype.tensor)):
|
|
||||||
arg_type = arg_type.element_type()
|
|
||||||
|
|
||||||
if arg_type in valid_types:
|
|
||||||
return arg_type
|
|
||||||
type_names = [get_typename(t) for t in valid_types]
|
|
||||||
if len(valid_types) == 1:
|
|
||||||
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
|
|
||||||
f' but got {get_typename(arg_type)}.')
|
|
||||||
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
|
|
||||||
f' but got {get_typename(arg_type)}.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_string(arg_name, arg_value, valid_values):
|
|
||||||
"""String type judgment."""
|
|
||||||
if isinstance(arg_value, str) and arg_value in valid_values:
|
|
||||||
return arg_value
|
|
||||||
if len(valid_values) == 1:
|
|
||||||
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
|
|
||||||
f' but got {arg_value}.')
|
|
||||||
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
|
|
||||||
f' but got {arg_value}.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_type_same(args, valid_values):
|
|
||||||
"""Determine whether the types are the same."""
|
|
||||||
name = list(args.keys())[0]
|
|
||||||
value = list(args.values())[0]
|
|
||||||
if isinstance(value, type(mstype.tensor)):
|
|
||||||
value = value.element_type()
|
|
||||||
for arg_name, arg_value in args.items():
|
|
||||||
if isinstance(arg_value, type(mstype.tensor)):
|
|
||||||
arg_value = arg_value.element_type()
|
|
||||||
|
|
||||||
if arg_value not in valid_values:
|
|
||||||
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
|
|
||||||
f' but `{arg_name}` is {arg_value}.')
|
|
||||||
if arg_value != value:
|
|
||||||
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
|
|
||||||
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
|
|
||||||
"""Determine whether the types of two variables are the same."""
|
|
||||||
if arg1_type != arg2_type:
|
|
||||||
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_value_on_integer(arg_name, arg_value, value, rel):
|
|
||||||
"""Judging integer type."""
|
|
||||||
rel_fn = Rel.get_fns(rel)
|
|
||||||
type_match = isinstance(arg_value, int)
|
|
||||||
if type_match and (not rel_fn(arg_value, value)):
|
|
||||||
rel_str = Rel.get_strs(rel).format(value)
|
|
||||||
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
|
||||||
return arg_value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
|
|
||||||
"""Judging the equality of parameters."""
|
|
||||||
if param1_value != param2_value:
|
|
||||||
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
|
|
||||||
f" but got `{param1_name}` = {param1_value},"
|
|
||||||
f" `{param2_name}` = {param2_value}.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_const_input(arg_name, arg_value):
|
|
||||||
"""Check valid value."""
|
|
||||||
if arg_value is None:
|
|
||||||
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_float_positive(arg_name, arg_value):
|
|
||||||
"""Float type judgment."""
|
|
||||||
if isinstance(arg_value, float):
|
|
||||||
if arg_value > 0:
|
|
||||||
return arg_value
|
|
||||||
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
|
|
||||||
|
|
||||||
raise TypeError(f"`{arg_name}` must be float!")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_pad_value_by_mode(op_name, pad_mode, padding):
|
|
||||||
"""Validate value of padding according to pad_mode"""
|
|
||||||
if pad_mode != 'pad' and padding != 0:
|
|
||||||
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
|
||||||
return padding
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_empty_shape_input(arg_name, arg_value):
|
|
||||||
"""Check zeros value."""
|
|
||||||
if 0 in arg_value:
|
|
||||||
raise ValueError(f"Input `{arg_name}` cannot be empty.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_scalar_shape_input(arg_name, arg_value):
|
|
||||||
"""Check scalar shape input."""
|
|
||||||
if arg_value != []:
|
|
||||||
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
|
|
||||||
|
|
||||||
|
|
||||||
def check_int(input_param):
|
def check_int(input_param):
|
||||||
"""Int type judgment."""
|
"""Int type judgment."""
|
||||||
|
|
|
@ -592,11 +592,6 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
|
||||||
return get_single_type((*tuple_ptr)[output_idx]);
|
return get_single_type((*tuple_ptr)[output_idx]);
|
||||||
};
|
};
|
||||||
TypePtr type_ptr = node->Type();
|
TypePtr type_ptr = node->Type();
|
||||||
if (type_ptr->isa<RefType>()) {
|
|
||||||
auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(ref_type_ptr);
|
|
||||||
return get_tuple_type(ref_type_ptr->subtype(), output_idx);
|
|
||||||
}
|
|
||||||
return get_tuple_type(type_ptr, output_idx);
|
return get_tuple_type(type_ptr, output_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
#include "ir/dtype.h"
|
||||||
#include "abstract/dshape.h"
|
#include "abstract/dshape.h"
|
||||||
#include "abstract/param_validator.h"
|
#include "abstract/param_validator.h"
|
||||||
#include "frontend/operator/cc_implementations.h"
|
#include "frontend/operator/cc_implementations.h"
|
||||||
|
@ -43,15 +44,15 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
|
||||||
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
||||||
std::size_t sig_size = signature.size();
|
std::size_t sig_size = signature.size();
|
||||||
auto positional_size = sig_size;
|
auto positional_size = sig_size;
|
||||||
if (has_var) {
|
if (has_var) {
|
||||||
positional_size = sig_size - 1;
|
positional_size = sig_size - 1;
|
||||||
}
|
}
|
||||||
if (args_spec_list.size() < positional_size) {
|
if (actual_param_number < positional_size) {
|
||||||
for (size_t i = args_spec_list.size(); i < sig_size; ++i) {
|
for (size_t i = actual_param_number; i < sig_size; ++i) {
|
||||||
auto default_value = signature[i].default_value;
|
auto default_value = signature[i].default_value;
|
||||||
if (default_value == nullptr) {
|
if (default_value == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
|
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
|
||||||
|
@ -67,23 +68,11 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
|
||||||
*max_type_number = type_number;
|
*max_type_number = type_number;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
|
bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id,
|
||||||
TypeId *arg_type = nullptr) {
|
TypeId *arg_type = nullptr) {
|
||||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
if (arg_type_origin->isa<TensorType>()) {
|
||||||
auto ref = arg_value->cast<abstract::AbstractRefPtr>();
|
auto tensor = arg_type_origin->cast<TensorTypePtr>();
|
||||||
arg_value = ref->ref();
|
auto tensor_type = tensor->element();
|
||||||
if (!is_write && ref->need_cast()) {
|
|
||||||
auto tensor_type = ref->target_type();
|
|
||||||
*arg_type_id = tensor_type->type_id();
|
|
||||||
if (arg_type != nullptr) {
|
|
||||||
*arg_type = kObjectTypeTensorType;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
||||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
|
||||||
auto tensor_type = tensor->element()->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
*arg_type_id = tensor_type->type_id();
|
*arg_type_id = tensor_type->type_id();
|
||||||
if (arg_type != nullptr) {
|
if (arg_type != nullptr) {
|
||||||
|
@ -91,9 +80,8 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg_value->isa<abstract::AbstractScalar>()) {
|
if (arg_type_origin->isa<Number>()) {
|
||||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
auto scalar_type = arg_type_origin->cast<NumberPtr>();
|
||||||
auto scalar_type = scalar->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||||
*arg_type_id = scalar_type->type_id();
|
*arg_type_id = scalar_type->type_id();
|
||||||
if (arg_type != nullptr) {
|
if (arg_type != nullptr) {
|
||||||
|
@ -104,7 +92,7 @@ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
|
TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> indices,
|
||||||
const std::set<size_t> &write_indices) {
|
const std::set<size_t> &write_indices) {
|
||||||
TypeId max_type_id = kTypeUnknown;
|
TypeId max_type_id = kTypeUnknown;
|
||||||
size_t max_type_number = 0;
|
size_t max_type_number = 0;
|
||||||
|
@ -115,7 +103,7 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
||||||
TypeId arg_type_id = kTypeUnknown;
|
TypeId arg_type_id = kTypeUnknown;
|
||||||
TypeId arg_type = kTypeUnknown;
|
TypeId arg_type = kTypeUnknown;
|
||||||
auto is_write = (write_indices.find(index) != write_indices.end());
|
auto is_write = (write_indices.find(index) != write_indices.end());
|
||||||
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
|
if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (arg_type != kObjectTypeTensorType) {
|
if (arg_type != kObjectTypeTensorType) {
|
||||||
|
@ -161,8 +149,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
||||||
|
|
||||||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||||
using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
|
using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
|
||||||
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
|
||||||
const abstract::AbstractBasePtrList &args_spec_list, const std::set<size_t> &write_indices) {
|
const std::set<size_t> &write_indices) {
|
||||||
// record index for signature.dtypes of the same type
|
// record index for signature.dtypes of the same type
|
||||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
|
||||||
|
@ -184,11 +172,8 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||||
}
|
}
|
||||||
bool has_tensor = false;
|
bool has_tensor = false;
|
||||||
for (const auto &index : indices) {
|
for (const auto &index : indices) {
|
||||||
AbstractBasePtr arg_value = args_spec_list[index];
|
auto arg_value = input_types[index];
|
||||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
if (arg_value->isa<TensorType>()) {
|
||||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
|
||||||
}
|
|
||||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
|
||||||
has_tensor = true;
|
has_tensor = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -197,7 +182,7 @@ MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||||
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices, write_indices)));
|
||||||
}
|
}
|
||||||
return dst_type;
|
return dst_type;
|
||||||
}
|
}
|
||||||
|
@ -211,7 +196,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
||||||
}
|
}
|
||||||
|
|
||||||
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
||||||
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
|
||||||
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
||||||
std::vector<SignatureEnumDType> dtypes;
|
std::vector<SignatureEnumDType> dtypes;
|
||||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||||
|
@ -221,9 +206,9 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
||||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
|
||||||
// Identify which arg requires auto cast
|
// Identify which arg requires auto cast
|
||||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
for (size_t i = 0; i < input_types.size(); ++i) {
|
||||||
auto it = dst_type.find(dtypes[i]);
|
auto it = dst_type.find(dtypes[i]);
|
||||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -232,7 +217,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
||||||
auto is_write = (rw_it != write_indices.end());
|
auto is_write = (rw_it != write_indices.end());
|
||||||
|
|
||||||
TypeId arg_type_id = kTypeUnknown;
|
TypeId arg_type_id = kTypeUnknown;
|
||||||
AbstractBasePtr arg_value = args_spec_list[i];
|
auto arg_value = input_types[i];
|
||||||
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
||||||
auto it_map = type_name_map.find(arg_type_id);
|
auto it_map = type_name_map.find(arg_type_id);
|
||||||
if (it_map == type_name_map.end()) {
|
if (it_map == type_name_map.end()) {
|
||||||
|
@ -248,7 +233,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
|
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
|
||||||
|
@ -275,6 +260,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
std::set<size_t> write_indices;
|
std::set<size_t> write_indices;
|
||||||
|
std::vector<TypePtr> input_types;
|
||||||
op_inputs.push_back(NewValueNode(function));
|
op_inputs.push_back(NewValueNode(function));
|
||||||
// Assume, the write input of op is always the first input. We check if any write op,
|
// Assume, the write input of op is always the first input. We check if any write op,
|
||||||
// and add cast op on other inputs to keep the same type with assigned parameter.
|
// and add cast op on other inputs to keep the same type with assigned parameter.
|
||||||
|
@ -292,30 +278,36 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
||||||
sig = signature[sig_size - 1].rw;
|
sig = signature[sig_size - 1].rw;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr type = args_spec_list[i]->GetTypeTrack();
|
TypePtr type = args_spec_list[i]->BuildType();
|
||||||
if (type && type->type_id() == kObjectTypeRef) {
|
if (type && type->isa<RefType>()) {
|
||||||
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
|
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
|
||||||
if (sig == SignatureEnumRW::kRWRead) {
|
if (sig == SignatureEnumRW::kRWRead) {
|
||||||
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
auto source_tensor_type = type->cast<TensorTypePtr>();
|
||||||
if (ref_abs && ref_abs->need_cast()) {
|
if (source_tensor_type != nullptr) {
|
||||||
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
|
auto source_element = source_tensor_type->element();
|
||||||
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
|
if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
||||||
|
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
|
||||||
|
param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph);
|
||||||
|
type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||||
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
|
||||||
write_indices.insert(i);
|
write_indices.insert(i);
|
||||||
}
|
}
|
||||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
} else if (sig == SignatureEnumRW::kRWWrite &&
|
||||||
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
|
!((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
|
||||||
|
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
|
||||||
|
<< type->ToString();
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
|
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
|
||||||
<< args_spec_list[i]->ToString();
|
<< args_spec_list[i]->ToString();
|
||||||
|
input_types.push_back(type);
|
||||||
op_inputs.push_back(param);
|
op_inputs.push_back(param);
|
||||||
}
|
}
|
||||||
// process default
|
// process default
|
||||||
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
|
||||||
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
|
DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
|
||||||
return func_graph->NewCNode(op_inputs);
|
return func_graph->NewCNode(op_inputs);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -81,12 +81,6 @@ void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
|
||||||
}
|
}
|
||||||
Register(types_name, py_fn);
|
Register(types_name, py_fn);
|
||||||
}
|
}
|
||||||
static TypePtr UnwrapRef(const TypePtr &type) {
|
|
||||||
if (type->isa<RefType>()) {
|
|
||||||
return type->cast<RefTypePtr>()->subtype();
|
|
||||||
}
|
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return Exact match if exists, else return non ambiguous sub class match
|
// Return Exact match if exists, else return non ambiguous sub class match
|
||||||
// Return py::none() if matching is ambiguous
|
// Return py::none() if matching is ambiguous
|
||||||
|
@ -99,7 +93,7 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
||||||
}
|
}
|
||||||
auto match = true;
|
auto match = true;
|
||||||
for (size_t i = 0; i < sign.size(); ++i) {
|
for (size_t i = 0; i < sign.size(); ++i) {
|
||||||
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
|
if (!IsIdentidityOrSubclass(types[i], sign[i])) {
|
||||||
match = false;
|
match = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -627,6 +627,16 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
|
||||||
|
|
||||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tensor
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||||
|
|
||||||
|
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
|
||||||
|
return args_spec_list[0];
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
|
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
|
||||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
|
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
|
||||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
|
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
|
||||||
|
@ -648,5 +658,7 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImpl
|
||||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
|
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
|
||||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
|
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
|
||||||
InferImplBroadcastGradientArgs);
|
InferImplBroadcastGradientArgs);
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
|
||||||
|
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/param_info.h"
|
#include "ir/param_info.h"
|
||||||
|
#include "ir/meta_tensor.h"
|
||||||
#include "pipeline/jit/parse/python_adapter.h"
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -38,8 +39,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
|
||||||
if (!para_ptr->has_default()) {
|
if (!para_ptr->has_default()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto obj = py::cast(para_ptr->default_param());
|
auto param_value = para_ptr->param_info();
|
||||||
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
|
|
||||||
if (param_value == nullptr) {
|
if (param_value == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1356,8 +1356,7 @@ bool ParameterIsCloned(const AnfNodePtr ¶meter_node) {
|
||||||
if (!cloned_parameter->has_default()) {
|
if (!cloned_parameter->has_default()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto obj = py::cast(cloned_parameter->default_param());
|
auto param_value = cloned_parameter->param_info();
|
||||||
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
|
|
||||||
if (param_value == nullptr) {
|
if (param_value == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -1380,8 +1379,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
||||||
if (!ParameterIsCloned(cloned_parameter_node)) {
|
if (!ParameterIsCloned(cloned_parameter_node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto obj = py::cast(cloned_parameter->default_param());
|
auto param_value = cloned_parameter->param_info();
|
||||||
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
|
|
||||||
if (param_value == nullptr) {
|
if (param_value == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -1400,10 +1398,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto ¶m_value_cloned = be_cloned_parameter->default_param();
|
auto param_value_in = be_cloned_parameter->param_info();
|
||||||
|
|
||||||
auto obj_in = py::cast(param_value_cloned);
|
|
||||||
auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
|
|
||||||
if (param_value_in == nullptr) {
|
if (param_value_in == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -233,13 +233,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
||||||
for (const auto ¶m : func_graph->parameters()) {
|
for (const auto ¶m : func_graph->parameters()) {
|
||||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||||
if (param_node->has_default()) {
|
if (param_node->has_default()) {
|
||||||
ValuePtr value = param_node->default_param();
|
auto value = param_node->default_param();
|
||||||
constexpr bool broaden = true;
|
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
|
||||||
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
|
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||||
|
auto abs_ref_key = ref_key->ToAbstract();
|
||||||
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
|
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||||
args_spec.push_back(ptr);
|
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
|
||||||
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
|
args_spec.push_back(abs_ref);
|
||||||
|
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Analyze
|
// Analyze
|
||||||
|
|
|
@ -425,9 +425,6 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
||||||
converted = env;
|
converted = env;
|
||||||
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
|
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
|
||||||
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
||||||
} else if (py::hasattr(obj, "__parameter__")) {
|
|
||||||
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
|
|
||||||
ret = ConvertData(to_convert, &converted);
|
|
||||||
} else {
|
} else {
|
||||||
ret = ConvertOtherObj(obj, &converted);
|
ret = ConvertOtherObj(obj, &converted);
|
||||||
}
|
}
|
||||||
|
@ -555,9 +552,6 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name)
|
||||||
|
|
||||||
ValuePtr PyDataToValue(const py::object &obj) {
|
ValuePtr PyDataToValue(const py::object &obj) {
|
||||||
py::object to_convert = obj;
|
py::object to_convert = obj;
|
||||||
if (py::hasattr(obj, "__parameter__")) {
|
|
||||||
to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
|
|
||||||
}
|
|
||||||
ValuePtr value = nullptr;
|
ValuePtr value = nullptr;
|
||||||
(void)ConvertData(to_convert, &value);
|
(void)ConvertData(to_convert, &value);
|
||||||
return value;
|
return value;
|
||||||
|
|
|
@ -306,7 +306,14 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
|
||||||
}
|
}
|
||||||
|
|
||||||
void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
|
void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) {
|
||||||
state_assign_[target] = readid;
|
const std::string primitive_name("assign");
|
||||||
|
const std::string module_name("mindspore.ops.functional");
|
||||||
|
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
|
||||||
|
auto source = ReadVariable(readid);
|
||||||
|
auto assign = func_graph()->NewCNode({assign_op, target, source});
|
||||||
|
WriteVariable(readid, assign);
|
||||||
|
MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid;
|
||||||
|
AddAutoDepend(assign);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
|
void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); }
|
||||||
|
@ -321,21 +328,13 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
||||||
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
||||||
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
||||||
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
||||||
const std::string primitive_name("assign");
|
|
||||||
const std::string module_name("mindspore.ops.functional");
|
if (auto_depends_.size() == 0) {
|
||||||
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
|
|
||||||
if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AnfNodePtr state = nullptr;
|
AnfNodePtr state = nullptr;
|
||||||
std::vector<AnfNodePtr> vec_states;
|
std::vector<AnfNodePtr> vec_states;
|
||||||
vec_states.emplace_back(make_tuple_op);
|
vec_states.emplace_back(make_tuple_op);
|
||||||
for (auto &item : state_assign_) {
|
|
||||||
auto source = ReadVariable(item.second);
|
|
||||||
auto assign = func_graph()->NewCNode({assign_op, item.first, source});
|
|
||||||
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
|
|
||||||
vec_states.emplace_back(assign);
|
|
||||||
}
|
|
||||||
for (auto &item : auto_depends_) {
|
for (auto &item : auto_depends_) {
|
||||||
MS_LOG(DEBUG) << "auto_depends " << item->ToString();
|
MS_LOG(DEBUG) << "auto_depends " << item->ToString();
|
||||||
vec_states.emplace_back(item);
|
vec_states.emplace_back(item);
|
||||||
|
@ -361,7 +360,6 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
||||||
AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
|
AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state});
|
||||||
AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
|
AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped});
|
||||||
func_graph()->set_output(ret, true);
|
func_graph()->set_output(ret, true);
|
||||||
state_assign_.clear();
|
|
||||||
}
|
}
|
||||||
} // namespace parse
|
} // namespace parse
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -101,9 +101,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
||||||
// keeps all removable phis which will be removed in one pass.
|
// keeps all removable phis which will be removed in one pass.
|
||||||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
|
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
|
||||||
|
|
||||||
// set state nodes need to insert before function return nodes.
|
|
||||||
OrderedMap<AnfNodePtr, std::string> state_assign_;
|
|
||||||
|
|
||||||
// hold declared global variables in function
|
// hold declared global variables in function
|
||||||
std::set<std::string> global_vars_;
|
std::set<std::string> global_vars_;
|
||||||
|
|
||||||
|
|
|
@ -59,14 +59,13 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) {
|
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
|
||||||
TypePtr dst_type;
|
|
||||||
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
|
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
|
||||||
return kFloat32;
|
return kFloat32;
|
||||||
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
|
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
|
||||||
return kFloat16;
|
return kFloat16;
|
||||||
} else {
|
} else {
|
||||||
return kNone;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -364,7 +364,7 @@ class ParseAst {
|
||||||
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
|
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m);
|
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m);
|
||||||
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m);
|
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
} // namespace parse
|
} // namespace parse
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -105,24 +105,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
||||||
auto value = py::cast<tensor::MetaTensorPtr>(obj);
|
auto value = py::cast<tensor::MetaTensorPtr>(obj);
|
||||||
node->set_default_param(value);
|
node->set_default_param(value);
|
||||||
// set_abstract for parameter
|
// set_abstract for parameter
|
||||||
constexpr bool broaden = true;
|
auto abs = value->ToAbstract();
|
||||||
node->set_abstract(abstract::FromValue(value, broaden));
|
node->set_abstract(abs);
|
||||||
para_node = node;
|
para_node = node;
|
||||||
}
|
}
|
||||||
auto iter = func_graph->make_ref_params().find(para_node);
|
|
||||||
if (iter == func_graph->make_ref_params().end()) {
|
|
||||||
ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
|
|
||||||
|
|
||||||
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
|
return para_node;
|
||||||
AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
|
|
||||||
AnfNodePtr target_type_node = NewValueNode(target_type);
|
|
||||||
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
|
|
||||||
func_graph->make_ref_params()[para_node] = ref_node;
|
|
||||||
func_graph->add_parameter_obj_node(ref_node);
|
|
||||||
return ref_node;
|
|
||||||
} else {
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
|
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
|
||||||
|
|
|
@ -640,7 +640,14 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
||||||
size_t size = op_exec_info->op_inputs.size();
|
size_t size = op_exec_info->op_inputs.size();
|
||||||
for (size_t i = 0; i < size; i++) {
|
for (size_t i = 0; i < size; i++) {
|
||||||
auto obj = op_exec_info->op_inputs[i];
|
auto obj = op_exec_info->op_inputs[i];
|
||||||
bool op_mask = py::hasattr(obj, "__parameter__");
|
bool op_mask = false;
|
||||||
|
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||||
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||||
|
if (meta_tensor) {
|
||||||
|
op_mask = meta_tensor->is_parameter();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
(*op_masks).push_back(op_mask);
|
(*op_masks).push_back(op_mask);
|
||||||
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
|
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ "
|
||||||
<< grad_flag_;
|
<< grad_flag_;
|
||||||
|
@ -990,8 +997,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
|
||||||
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
|
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
|
||||||
auto free_param = df_builder_->add_parameter();
|
auto free_param = df_builder_->add_parameter();
|
||||||
free_param->set_name(param_name);
|
free_param->set_name(param_name);
|
||||||
free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
|
|
||||||
free_param->debug_info()->set_name(param_name);
|
free_param->debug_info()->set_name(param_name);
|
||||||
|
auto value = py::cast<tensor::TensorPtr>(obj);
|
||||||
|
free_param->set_default_param(value);
|
||||||
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
|
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
|
||||||
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
|
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
|
||||||
return free_param;
|
return free_param;
|
||||||
|
@ -1159,17 +1167,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
|
||||||
auto param_name = py::cast<std::string>(name_attr);
|
auto param_name = py::cast<std::string>(name_attr);
|
||||||
auto free_param = df_builder_->add_parameter();
|
auto free_param = df_builder_->add_parameter();
|
||||||
free_param->set_name(param_name);
|
free_param->set_name(param_name);
|
||||||
free_param->set_default_param(py::cast<tensor::TensorPtr>(param));
|
auto value = py::cast<tensor::TensorPtr>(param);
|
||||||
|
free_param->set_default_param(value);
|
||||||
free_param->debug_info()->set_name(param_name);
|
free_param->debug_info()->set_name(param_name);
|
||||||
para_node = free_param;
|
para_node = free_param;
|
||||||
}
|
}
|
||||||
ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
|
w_args.push_back(para_node);
|
||||||
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
|
|
||||||
auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
|
|
||||||
AnfNodePtr ref_key_node = NewValueNode(refkey);
|
|
||||||
AnfNodePtr target_type_node = NewValueNode(target_type);
|
|
||||||
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
|
|
||||||
w_args.push_back(ref_node);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "training not paramter_tuple";
|
MS_LOG(DEBUG) << "training not paramter_tuple";
|
||||||
|
@ -1197,7 +1200,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
|
||||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||||
if (param_node->has_default()) {
|
if (param_node->has_default()) {
|
||||||
ValuePtr value = param_node->default_param();
|
ValuePtr value = param_node->default_param();
|
||||||
AbstractBasePtr ptr = abstract::FromValue(value, true);
|
auto ptr = value->ToAbstract();
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Args convert error";
|
MS_LOG(EXCEPTION) << "Args convert error";
|
||||||
}
|
}
|
||||||
|
|
|
@ -147,7 +147,7 @@ REGISTER_PYBIND_DEFINE(
|
||||||
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
|
(void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
|
||||||
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
(void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
|
||||||
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
||||||
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
(void)py::class_<RefType, TensorType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
||||||
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||||
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
|
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
|
||||||
(void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
|
(void)py::class_<TypeEllipsis, Type, std::shared_ptr<TypeEllipsis>>(m_sub, "TypeEllipsis").def(py::init());
|
||||||
|
|
|
@ -21,7 +21,7 @@ namespace mindspore {
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
||||||
(void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
|
(void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def("clone", &ParamInfo::Clone)
|
.def("clone", &ParamInfo::Clone)
|
||||||
.def_property("name", &ParamInfo::name, &ParamInfo::set_name)
|
.def_property("name", &ParamInfo::name, &ParamInfo::set_name)
|
||||||
|
@ -36,7 +36,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
||||||
if (t.size() != 6) {
|
if (t.size() != 6) {
|
||||||
std::runtime_error("Invalid state for ParamInfo!");
|
std::runtime_error("Invalid state for ParamInfo!");
|
||||||
}
|
}
|
||||||
ParamValuePtr p = std::make_shared<ParamInfo>();
|
ParamInfoPtr p = std::make_shared<ParamInfo>();
|
||||||
p->set_name(t[1].cast<std::string>());
|
p->set_name(t[1].cast<std::string>());
|
||||||
p->set_requires_grad(t[2].cast<bool>());
|
p->set_requires_grad(t[2].cast<bool>());
|
||||||
p->set_layerwise_parallel(t[3].cast<bool>());
|
p->set_layerwise_parallel(t[3].cast<bool>());
|
||||||
|
|
|
@ -291,6 +291,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
||||||
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
|
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
|
||||||
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
|
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
|
||||||
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
|
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
|
||||||
|
.def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const MetaTensor &t) { // __getstate__
|
[](const MetaTensor &t) { // __getstate__
|
||||||
/* Return a tuple that fully encodes the state of the object */
|
/* Return a tuple that fully encodes the state of the object */
|
||||||
|
|
|
@ -42,7 +42,7 @@ class Parameter(MetaTensor):
|
||||||
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
|
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
|
||||||
an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
|
an `Initializer`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor`
|
||||||
only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
|
only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
|
||||||
compile for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
|
compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Each parameter of Cell is represented by Parameter class.
|
Each parameter of Cell is represented by Parameter class.
|
||||||
|
@ -108,7 +108,7 @@ class Parameter(MetaTensor):
|
||||||
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
|
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
|
||||||
|
|
||||||
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
|
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
|
||||||
self._value = ParamInfo()
|
self._param_info = ParamInfo()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.requires_grad = requires_grad
|
self.requires_grad = requires_grad
|
||||||
self.layerwise_parallel = layerwise_parallel
|
self.layerwise_parallel = layerwise_parallel
|
||||||
|
@ -156,13 +156,13 @@ class Parameter(MetaTensor):
|
||||||
value_str = MetaTensor.__str__(self)
|
value_str = MetaTensor.__str__(self)
|
||||||
if isinstance(self, Tensor):
|
if isinstance(self, Tensor):
|
||||||
value_str = Tensor.__str__(self)
|
value_str = Tensor.__str__(self)
|
||||||
return f'Parameter (name={self._value.name}, value={value_str})'
|
return f'Parameter (name={self._param_info.name}, value={value_str})'
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
value_str = MetaTensor.__repr__(self)
|
value_str = MetaTensor.__repr__(self)
|
||||||
if isinstance(self, Tensor):
|
if isinstance(self, Tensor):
|
||||||
value_str = Tensor.__repr__(self)
|
value_str = Tensor.__repr__(self)
|
||||||
return f'Parameter (name={self._value.name}, value={value_str})'
|
return f'Parameter (name={self._param_info.name}, value={value_str})'
|
||||||
|
|
||||||
def __parameter__(self):
|
def __parameter__(self):
|
||||||
"""For parse check."""
|
"""For parse check."""
|
||||||
|
@ -181,7 +181,7 @@ class Parameter(MetaTensor):
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
"""Get the name of the parameter."""
|
"""Get the name of the parameter."""
|
||||||
return self._value.name
|
return self._param_info.name
|
||||||
|
|
||||||
@name.setter
|
@name.setter
|
||||||
def name(self, name_):
|
def name(self, name_):
|
||||||
|
@ -203,7 +203,7 @@ class Parameter(MetaTensor):
|
||||||
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
|
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
|
||||||
else:
|
else:
|
||||||
raise ValueError("The type of the name should be `str` or `None`.")
|
raise ValueError("The type of the name should be `str` or `None`.")
|
||||||
self._value.name = name_
|
self._param_info.name = name_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cast_type(self):
|
def cast_type(self):
|
||||||
|
@ -254,8 +254,8 @@ class Parameter(MetaTensor):
|
||||||
_check_str_by_regular(prefix)
|
_check_str_by_regular(prefix)
|
||||||
x = copy(self)
|
x = copy(self)
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
x._value = self._value.clone()
|
x._param_info = self._param_info.clone()
|
||||||
x._value.name = prefix + '.' + self._value.name
|
x._param_info.name = prefix + '.' + self._param_info.name
|
||||||
x.is_init = False
|
x.is_init = False
|
||||||
if init != 'same':
|
if init != 'same':
|
||||||
shape = self.shape
|
shape = self.shape
|
||||||
|
@ -265,24 +265,24 @@ class Parameter(MetaTensor):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layerwise_parallel(self):
|
def layerwise_parallel(self):
|
||||||
return self._value.layerwise_parallel
|
return self._param_info.layerwise_parallel
|
||||||
|
|
||||||
@layerwise_parallel.setter
|
@layerwise_parallel.setter
|
||||||
def layerwise_parallel(self, value=True):
|
def layerwise_parallel(self, value=True):
|
||||||
if not isinstance(value, bool):
|
if not isinstance(value, bool):
|
||||||
raise TypeError("`layerwise_parallel` parameter must be bool type")
|
raise TypeError("`layerwise_parallel` parameter must be bool type")
|
||||||
self._value.layerwise_parallel = value
|
self._param_info.layerwise_parallel = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def requires_grad(self):
|
def requires_grad(self):
|
||||||
"""Return whether the parameter requires gradient."""
|
"""Return whether the parameter requires gradient."""
|
||||||
return self._value.requires_grad
|
return self._param_info.requires_grad
|
||||||
|
|
||||||
@requires_grad.setter
|
@requires_grad.setter
|
||||||
def requires_grad(self, value=True):
|
def requires_grad(self, value=True):
|
||||||
if not isinstance(value, bool):
|
if not isinstance(value, bool):
|
||||||
raise TypeError("`requires_grad` parameter must be bool type")
|
raise TypeError("`requires_grad` parameter must be bool type")
|
||||||
self._value.requires_grad = value
|
self._param_info.requires_grad = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
|
|
|
@ -459,10 +459,6 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
||||||
}
|
}
|
||||||
auto other_tensor = dyn_cast<AbstractTensor>(other);
|
auto other_tensor = dyn_cast<AbstractTensor>(other);
|
||||||
if (other_tensor == nullptr) {
|
if (other_tensor == nullptr) {
|
||||||
auto ref_tensor = dyn_cast<AbstractRef>(other);
|
|
||||||
if (ref_tensor != nullptr) {
|
|
||||||
return this->Join(ref_tensor->ref());
|
|
||||||
}
|
|
||||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||||
}
|
}
|
||||||
if (*this == *other) {
|
if (*this == *other) {
|
||||||
|
@ -473,7 +469,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
||||||
return std::make_shared<AbstractTensor>(element, shape);
|
return std::make_shared<AbstractTensor>(element, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AbstractTensor::operator==(const AbstractTensor &other) const {
|
bool AbstractTensor::equal_to(const AbstractTensor &other) const {
|
||||||
if (&other == this) {
|
if (&other == this) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -491,12 +487,14 @@ bool AbstractTensor::operator==(const AbstractTensor &other) const {
|
||||||
return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
|
return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AbstractTensor::operator==(const AbstractTensor &other) const { return equal_to(other); }
|
||||||
|
|
||||||
bool AbstractTensor::operator==(const AbstractBase &other) const {
|
bool AbstractTensor::operator==(const AbstractBase &other) const {
|
||||||
if (&other == this) {
|
if (&other == this) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (other.isa<AbstractTensor>()) {
|
if (other.tid() == tid()) {
|
||||||
auto other_tensor = static_cast<const AbstractTensor *>(&other);
|
auto other_tensor = static_cast<const AbstractTensor *>(&other);
|
||||||
return *this == *other_tensor;
|
return *this == *other_tensor;
|
||||||
} else {
|
} else {
|
||||||
|
@ -822,39 +820,21 @@ std::string AbstractJTagged::ToString() const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
|
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value)
|
||||||
TypePtr cast_target)
|
: AbstractTensor(*ref_value), ref_key_(ref_key), ref_key_value_(nullptr) {
|
||||||
: ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
|
|
||||||
set_type(std::make_shared<RefType>());
|
set_type(std::make_shared<RefType>());
|
||||||
auto origin_type = ref_value->BuildType();
|
|
||||||
if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
|
|
||||||
auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
|
|
||||||
if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
|
|
||||||
if (cast_target != tensor_dtype) {
|
|
||||||
need_cast_ = true;
|
|
||||||
target_type_ = cast_target;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (ref_key && ref_key->isa<AbstractRefKey>()) {
|
if (ref_key && ref_key->isa<AbstractRefKey>()) {
|
||||||
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
|
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
|
|
||||||
|
|
||||||
TypePtr AbstractRef::BuildType() const {
|
TypePtr AbstractRef::BuildType() const {
|
||||||
TypePtr subtype = ref_->BuildType();
|
auto subtype = AbstractTensor::BuildType()->cast<TensorTypePtr>();
|
||||||
TypePtr subtype_origin = subtype;
|
return std::make_shared<RefType>(subtype);
|
||||||
if (need_cast_) {
|
|
||||||
subtype_origin = std::make_shared<TensorType>(target_type_);
|
|
||||||
}
|
|
||||||
return std::make_shared<RefType>(subtype, subtype_origin);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AbstractRef::operator==(const AbstractRef &other) const {
|
bool AbstractRef::operator==(const AbstractRef &other) const {
|
||||||
return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && (*ref_key_ == *other.ref_key_) &&
|
return AbstractTensor::equal_to(other) && (*ref_key_ == *other.ref_key_);
|
||||||
(!need_cast_ || (*target_type_ == *other.target_type_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AbstractRef::operator==(const AbstractBase &other) const {
|
bool AbstractRef::operator==(const AbstractBase &other) const {
|
||||||
|
@ -886,24 +866,20 @@ AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
|
||||||
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
|
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
|
||||||
auto other_ref = other->cast<AbstractRefPtr>();
|
auto other_ref = other->cast<AbstractRefPtr>();
|
||||||
if (other_ref == nullptr) {
|
if (other_ref == nullptr) {
|
||||||
auto new_ref = ref_->Join(other);
|
return AbstractTensor::Join(other)->cast<AbstractTensorPtr>();
|
||||||
return std::make_shared<AbstractRef>(ref_key_, new_ref);
|
|
||||||
}
|
}
|
||||||
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
|
if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
|
||||||
return shared_from_base<AbstractBase>();
|
return shared_from_base<AbstractBase>();
|
||||||
}
|
}
|
||||||
auto ref_key = ref_key_->Join(other_ref->ref_key_);
|
auto ref_key = ref_key_->Join(other_ref->ref_key_);
|
||||||
auto ref = ref_->Join(other_ref->ref());
|
auto ref = AbstractTensor::Join(other_ref->ref())->cast<AbstractTensorPtr>();
|
||||||
return std::make_shared<AbstractRef>(ref_key, ref);
|
return std::make_shared<AbstractRef>(ref_key, ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AbstractRef::ToString() const {
|
std::string AbstractRef::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << type_name() << "("
|
buffer << type_name() << "("
|
||||||
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
|
<< "key: " << ref_key_->ToString() << " ref_value: " << AbstractTensor::ToString();
|
||||||
if (need_cast_) {
|
|
||||||
buffer << " cast to: " << target_type_->ToString();
|
|
||||||
}
|
|
||||||
auto value = GetValueTrack();
|
auto value = GetValueTrack();
|
||||||
if (value) {
|
if (value) {
|
||||||
buffer << ", value: " << value->ToString();
|
buffer << ", value: " << value->ToString();
|
||||||
|
|
|
@ -284,11 +284,9 @@ class AbstractTensor : public AbstractUndetermined {
|
||||||
AbstractBasePtr Clone() const override;
|
AbstractBasePtr Clone() const override;
|
||||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||||
AbstractBasePtr BroadenWithShape() const;
|
AbstractBasePtr BroadenWithShape() const;
|
||||||
AbstractBasePtr Join(const AbstractBasePtr &other) final;
|
AbstractBasePtr Join(const AbstractBasePtr &other);
|
||||||
|
|
||||||
bool operator==(const AbstractTensor &other) const;
|
bool operator==(const AbstractTensor &other) const;
|
||||||
bool operator==(const AbstractBase &other) const override;
|
bool operator==(const AbstractBase &other) const override;
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::size_t hash() const override {
|
std::size_t hash() const override {
|
||||||
auto value = GetValueTrack();
|
auto value = GetValueTrack();
|
||||||
|
@ -301,6 +299,9 @@ class AbstractTensor : public AbstractUndetermined {
|
||||||
}
|
}
|
||||||
return hash_sum;
|
return hash_sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool equal_to(const AbstractTensor &other) const;
|
||||||
};
|
};
|
||||||
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
|
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
|
||||||
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
|
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
|
||||||
|
@ -575,42 +576,42 @@ class AbstractRefKey : public AbstractBase {
|
||||||
};
|
};
|
||||||
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
|
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
|
||||||
|
|
||||||
class AbstractRef : public AbstractBase {
|
class AbstractRef : public AbstractTensor {
|
||||||
public:
|
public:
|
||||||
AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
|
AbstractRef(const AbstractBasePtr &ref_key, const AbstractTensorPtr &ref_value);
|
||||||
TypePtr cast_target = nullptr);
|
|
||||||
|
|
||||||
~AbstractRef() override = default;
|
~AbstractRef() override = default;
|
||||||
MS_DECLARE_PARENT(AbstractRef, AbstractBase)
|
MS_DECLARE_PARENT(AbstractRef, AbstractTensor)
|
||||||
|
|
||||||
TypePtr BuildType() const override;
|
TypePtr BuildType() const override;
|
||||||
BaseShapePtr BuildShape() const override;
|
|
||||||
bool operator==(const AbstractRef &other) const;
|
bool operator==(const AbstractRef &other) const;
|
||||||
bool operator==(const AbstractBase &other) const override;
|
bool operator==(const AbstractBase &other) const override;
|
||||||
AbstractBasePtr Clone() const override {
|
AbstractBasePtr Clone() const override {
|
||||||
return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
|
auto abs_tensor = AbstractTensor::Clone()->cast<AbstractTensorPtr>();
|
||||||
|
if (abs_tensor == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractRef>(ref_key_->Clone(), abs_tensor);
|
||||||
}
|
}
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
inline AbstractBasePtr ref() const { return ref_; }
|
inline AbstractTensorPtr ref() { return shared_from_base<AbstractTensor>(); }
|
||||||
inline AbstractBasePtr ref_key() const { return ref_key_; }
|
inline AbstractBasePtr ref_key() const { return ref_key_; }
|
||||||
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
|
||||||
inline TypePtr target_type() const { return target_type_; }
|
|
||||||
inline bool need_cast() const { return need_cast_; }
|
|
||||||
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
AbstractBasePtr Broaden(uint8_t config = 0) const override {
|
||||||
// always broaden for ref
|
// always broaden for ref
|
||||||
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_);
|
auto abs_tensor = AbstractTensor::Broaden()->cast<AbstractTensorPtr>();
|
||||||
|
if (abs_tensor == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), abs_tensor);
|
||||||
}
|
}
|
||||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||||
std::size_t hash() const override {
|
std::size_t hash() const override {
|
||||||
return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
|
return AbstractTensor::hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AbstractBasePtr ref_key_;
|
AbstractBasePtr ref_key_;
|
||||||
AbstractBasePtr ref_;
|
|
||||||
// For mix presicion, only float type need to cast to float16 of float32
|
|
||||||
bool need_cast_;
|
|
||||||
TypePtr target_type_;
|
|
||||||
// cache for ref_key after build value, when value is null, return nullptr.
|
// cache for ref_key after build value, when value is null, return nullptr.
|
||||||
RefKeyPtr ref_key_value_;
|
RefKeyPtr ref_key_value_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -113,17 +113,8 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
|
||||||
MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
|
MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
|
||||||
<< ".";
|
<< ".";
|
||||||
}
|
}
|
||||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
auto tensor = args_spec_list[1]->cast<abstract::AbstractTensorPtr>();
|
||||||
ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
|
return std::make_shared<AbstractRef>(args_spec_list[0], tensor);
|
||||||
if (type->type_id() != kObjectTypeRefKey) {
|
|
||||||
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
|
|
||||||
}
|
|
||||||
auto need_cast = !tensor_target_v->isa<None>();
|
|
||||||
if (need_cast && !tensor_target_v->isa<Type>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
|
|
||||||
}
|
|
||||||
TypePtr cast_target = tensor_target_v->cast<TypePtr>();
|
|
||||||
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
|
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||||
|
|
|
@ -88,6 +88,17 @@ std::string Parameter::DebugString(int recursive_level) const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParamInfoPtr Parameter::param_info() const {
|
||||||
|
if (!has_default()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto tensor = default_param()->cast<tensor::MetaTensorPtr>();
|
||||||
|
if (tensor == nullptr || !tensor->is_parameter()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensor->param_info();
|
||||||
|
}
|
||||||
|
|
||||||
std::string ValueNode::ToString() const {
|
std::string ValueNode::ToString() const {
|
||||||
MS_EXCEPTION_IF_NULL(value_);
|
MS_EXCEPTION_IF_NULL(value_);
|
||||||
if (value_->isa<FuncGraph>()) {
|
if (value_->isa<FuncGraph>()) {
|
||||||
|
|
|
@ -75,7 +75,7 @@ using VarPtr = std::shared_ptr<Var>;
|
||||||
class AnfIrVisitor;
|
class AnfIrVisitor;
|
||||||
|
|
||||||
class ParamInfo;
|
class ParamInfo;
|
||||||
using ParamValuePtr = std::shared_ptr<ParamInfo>;
|
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
|
||||||
|
|
||||||
// AnfNode is the basic class of the IR definition derived from Base.
|
// AnfNode is the basic class of the IR definition derived from Base.
|
||||||
// Only two types of nodes are derived: CNode and ANode.
|
// Only two types of nodes are derived: CNode and ANode.
|
||||||
|
@ -288,6 +288,7 @@ class Parameter : public ANode {
|
||||||
has_default_ = true;
|
has_default_ = true;
|
||||||
}
|
}
|
||||||
ValuePtr default_param() const { return default_param_; }
|
ValuePtr default_param() const { return default_param_; }
|
||||||
|
ParamInfoPtr param_info() const;
|
||||||
|
|
||||||
bool operator==(const AnfNode &other) const override {
|
bool operator==(const AnfNode &other) const override {
|
||||||
if (!other.isa<Parameter>()) {
|
if (!other.isa<Parameter>()) {
|
||||||
|
|
|
@ -94,175 +94,6 @@ bool Slice::operator==(const Type &other) const {
|
||||||
|
|
||||||
std::string Slice::DumpText() const { return ToString(); }
|
std::string Slice::DumpText() const { return ToString(); }
|
||||||
|
|
||||||
TypePtr UndeterminedType::DeepCopy() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(element_type_);
|
|
||||||
if (IsGeneric()) {
|
|
||||||
return std::make_shared<UndeterminedType>();
|
|
||||||
}
|
|
||||||
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string UndeterminedType::ToReprString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "Undetermined";
|
|
||||||
}
|
|
||||||
return "Undetermined[" + element_type_->ToReprString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string UndeterminedType::ToString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "Undetermined";
|
|
||||||
}
|
|
||||||
return "Undetermined[" + element_type_->ToString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string UndeterminedType::DumpText() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "Undetermined";
|
|
||||||
}
|
|
||||||
return "Undetermined[" + element_type_->DumpText() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool UndeterminedType::operator==(const Type &other) const {
|
|
||||||
if (!IsSameObjectType(*this, other)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
|
|
||||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
|
||||||
return true;
|
|
||||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return *element_type_ == *other_elem_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr TensorType::DeepCopy() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(element_type_);
|
|
||||||
if (IsGeneric()) {
|
|
||||||
return std::make_shared<TensorType>();
|
|
||||||
}
|
|
||||||
return std::make_shared<TensorType>(element_type_->DeepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string TensorType::ToReprString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "tensor";
|
|
||||||
}
|
|
||||||
return "tensor[" + element_type_->ToReprString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string TensorType::ToString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "Tensor";
|
|
||||||
}
|
|
||||||
return "Tensor[" + element_type_->ToString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string TensorType::DumpText() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "Tensor";
|
|
||||||
}
|
|
||||||
return "Tensor(" + element_type_->DumpText() + ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool TensorType::operator==(const Type &other) const {
|
|
||||||
if (!IsSameObjectType(*this, other)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
|
|
||||||
// When element_type_ = nullptr, which means any type of Array.
|
|
||||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
|
||||||
return true;
|
|
||||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return *element_type_ == *other_elem_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr RowTensorType::DeepCopy() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(element_type_);
|
|
||||||
if (IsGeneric()) {
|
|
||||||
return std::make_shared<RowTensorType>();
|
|
||||||
}
|
|
||||||
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string RowTensorType::ToReprString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "RowTensor";
|
|
||||||
}
|
|
||||||
return "RowTensor[" + element_type_->ToReprString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string RowTensorType::ToString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "RowTensor";
|
|
||||||
}
|
|
||||||
return "RowTensor[" + element_type_->ToString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string RowTensorType::DumpText() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "RowTensor";
|
|
||||||
}
|
|
||||||
return "RowTensor[" + element_type_->DumpText() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool RowTensorType::operator==(const Type &other) const {
|
|
||||||
if (!IsSameObjectType(*this, other)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
|
|
||||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
|
||||||
return true;
|
|
||||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return *element_type_ == *other_elem_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
TypePtr SparseTensorType::DeepCopy() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(element_type_);
|
|
||||||
if (IsGeneric()) {
|
|
||||||
return std::make_shared<SparseTensorType>();
|
|
||||||
}
|
|
||||||
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string SparseTensorType::ToReprString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "SparseTensor";
|
|
||||||
}
|
|
||||||
return "SparseTensor[" + element_type_->ToReprString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string SparseTensorType::ToString() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "SparseTensor";
|
|
||||||
}
|
|
||||||
return "SparseTensor[" + element_type_->ToString() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string SparseTensorType::DumpText() const {
|
|
||||||
if (element_type_ == nullptr) {
|
|
||||||
return "SparseTensor";
|
|
||||||
}
|
|
||||||
return "SparseTensor[" + element_type_->DumpText() + "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool SparseTensorType::operator==(const Type &other) const {
|
|
||||||
if (!IsSameObjectType(*this, other)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
|
|
||||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
|
||||||
return true;
|
|
||||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return *element_type_ == *other_elem_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
Function::Function() : Object(kObjectTypeFunction) {
|
Function::Function() : Object(kObjectTypeFunction) {
|
||||||
args_ = std::vector<TypePtr>();
|
args_ = std::vector<TypePtr>();
|
||||||
retval_ = nullptr;
|
retval_ = nullptr;
|
||||||
|
@ -372,4 +203,8 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble
|
||||||
os << problem->ToString();
|
os << problem->ToString();
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const TypePtr kTensorTypeFP16 = std::make_shared<TensorType>(std::make_shared<Float>(16));
|
||||||
|
const TypePtr kTensorTypeFP32 = std::make_shared<TensorType>(std::make_shared<Float>(32));
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,10 +32,11 @@
|
||||||
#include "ir/named.h"
|
#include "ir/named.h"
|
||||||
|
|
||||||
#include "ir/dtype/type.h"
|
#include "ir/dtype/type.h"
|
||||||
#include "ir/dtype/ref.h"
|
|
||||||
#include "ir/dtype/number.h"
|
#include "ir/dtype/number.h"
|
||||||
#include "ir/dtype/container.h"
|
#include "ir/dtype/container.h"
|
||||||
#include "ir/dtype/empty.h"
|
#include "ir/dtype/empty.h"
|
||||||
|
#include "ir/dtype/tensor_type.h"
|
||||||
|
#include "ir/dtype/ref.h"
|
||||||
|
|
||||||
/* namespace to support intermediate representation definition */
|
/* namespace to support intermediate representation definition */
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -108,98 +109,6 @@ class Slice : public Object {
|
||||||
};
|
};
|
||||||
using SlicePtr = std::shared_ptr<Slice>;
|
using SlicePtr = std::shared_ptr<Slice>;
|
||||||
|
|
||||||
class UndeterminedType : public Object {
|
|
||||||
public:
|
|
||||||
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
|
|
||||||
explicit UndeterminedType(const TypePtr &ele)
|
|
||||||
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
|
|
||||||
~UndeterminedType() override = default;
|
|
||||||
MS_DECLARE_PARENT(UndeterminedType, Object)
|
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
|
|
||||||
const TypePtr element() const { return element_type_; }
|
|
||||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
|
||||||
std::string ToString() const override;
|
|
||||||
std::string ToReprString() const override;
|
|
||||||
std::string DumpText() const override;
|
|
||||||
bool operator==(const Type &other) const override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
TypePtr element_type_;
|
|
||||||
};
|
|
||||||
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
|
|
||||||
|
|
||||||
class TensorType : public Object {
|
|
||||||
public:
|
|
||||||
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
|
|
||||||
explicit TensorType(const TypePtr &ele)
|
|
||||||
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
|
||||||
~TensorType() override = default;
|
|
||||||
MS_DECLARE_PARENT(TensorType, Object)
|
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
|
|
||||||
const TypePtr element() const { return element_type_; }
|
|
||||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
|
||||||
std::string ToString() const override;
|
|
||||||
std::string ToReprString() const override;
|
|
||||||
std::string DumpText() const override;
|
|
||||||
bool operator==(const Type &other) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
TypePtr element_type_;
|
|
||||||
};
|
|
||||||
using TensorTypePtr = std::shared_ptr<TensorType>;
|
|
||||||
|
|
||||||
class RowTensorType : public Object {
|
|
||||||
public:
|
|
||||||
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
|
|
||||||
explicit RowTensorType(const TypePtr &ele)
|
|
||||||
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
|
||||||
~RowTensorType() override = default;
|
|
||||||
MS_DECLARE_PARENT(RowTensorType, Object)
|
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
|
|
||||||
const TypePtr element() const { return element_type_; }
|
|
||||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
|
||||||
std::string ToString() const override;
|
|
||||||
std::string ToReprString() const override;
|
|
||||||
std::string DumpText() const override;
|
|
||||||
bool operator==(const Type &other) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
TypePtr element_type_;
|
|
||||||
};
|
|
||||||
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
|
|
||||||
|
|
||||||
class SparseTensorType : public Object {
|
|
||||||
public:
|
|
||||||
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
|
|
||||||
explicit SparseTensorType(const TypePtr &ele)
|
|
||||||
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
|
||||||
~SparseTensorType() override = default;
|
|
||||||
MS_DECLARE_PARENT(SparseTensorType, Object)
|
|
||||||
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
|
|
||||||
const TypePtr element() const { return element_type_; }
|
|
||||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
|
||||||
|
|
||||||
TypePtr DeepCopy() const override;
|
|
||||||
std::string ToString() const override;
|
|
||||||
std::string ToReprString() const override;
|
|
||||||
std::string DumpText() const override;
|
|
||||||
bool operator==(const Type &other) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
TypePtr element_type_;
|
|
||||||
};
|
|
||||||
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
|
||||||
|
|
||||||
class Function : public Object {
|
class Function : public Object {
|
||||||
public:
|
public:
|
||||||
Function();
|
Function();
|
||||||
|
@ -353,6 +262,9 @@ extern const TypePtr kDict;
|
||||||
extern const TypePtr kSlice;
|
extern const TypePtr kSlice;
|
||||||
extern const TypePtr kKeyword;
|
extern const TypePtr kKeyword;
|
||||||
extern const TypePtr kTensorType;
|
extern const TypePtr kTensorType;
|
||||||
|
extern const TypePtr kTensorTypeFP16;
|
||||||
|
extern const TypePtr kTensorTypeFP32;
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_IR_DTYPE_H_
|
#endif // MINDSPORE_CORE_IR_DTYPE_H_
|
||||||
|
|
|
@ -68,6 +68,8 @@ class Number : public Object {
|
||||||
const int nbits_;
|
const int nbits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using NumberPtr = std::shared_ptr<Number>;
|
||||||
|
|
||||||
// Bool
|
// Bool
|
||||||
class Bool : public Number {
|
class Bool : public Number {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -19,15 +19,15 @@
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "ir/dtype/tensor_type.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
TypePtr RefType::DeepCopy() const {
|
TypePtr RefType::DeepCopy() const {
|
||||||
if (IsGeneric()) {
|
if (IsGeneric()) {
|
||||||
return std::make_shared<RefType>();
|
return std::make_shared<RefType>();
|
||||||
} else {
|
} else {
|
||||||
auto subtype = subtype_->DeepCopy();
|
auto subtype = TensorType::DeepCopy()->cast<TensorTypePtr>();
|
||||||
auto subtype_origin = subtype_origin_->DeepCopy();
|
return std::make_shared<RefType>(subtype);
|
||||||
return std::make_shared<RefType>(subtype, subtype_origin);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ std::string RefType::DumpText() const {
|
||||||
buffer << "Ref";
|
buffer << "Ref";
|
||||||
} else {
|
} else {
|
||||||
buffer << "Ref[";
|
buffer << "Ref[";
|
||||||
buffer << subtype_->DumpText() << "]";
|
buffer << TensorType::DumpText() << "]";
|
||||||
}
|
}
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,21 +17,13 @@
|
||||||
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
|
#ifndef MINDSPORE_CORE_IR_DTYPE_REF_H_
|
||||||
#define MINDSPORE_CORE_IR_DTYPE_REF_H_
|
#define MINDSPORE_CORE_IR_DTYPE_REF_H_
|
||||||
|
|
||||||
#include <cstddef>
|
|
||||||
#include <iostream>
|
|
||||||
#include <initializer_list>
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "base/base.h"
|
#include "base/base.h"
|
||||||
#include "ir/named.h"
|
#include "ir/named.h"
|
||||||
#include "ir/dtype/type.h"
|
#include "ir/dtype/type.h"
|
||||||
|
#include "ir/dtype/tensor_type.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// TypeRefKey type
|
// TypeRefKey type
|
||||||
|
@ -48,23 +40,16 @@ class RefKeyType : public Object {
|
||||||
};
|
};
|
||||||
|
|
||||||
// TypeRef type
|
// TypeRef type
|
||||||
class RefType : public Object {
|
class RefType : public TensorType {
|
||||||
public:
|
public:
|
||||||
RefType() : Object(kObjectTypeRef) {}
|
RefType() : TensorType() {}
|
||||||
RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
|
explicit RefType(const TensorTypePtr &subtype) : TensorType(subtype->element()) {}
|
||||||
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
|
|
||||||
~RefType() override {}
|
~RefType() override {}
|
||||||
MS_DECLARE_PARENT(RefType, Object)
|
MS_DECLARE_PARENT(RefType, TensorType)
|
||||||
|
|
||||||
TypePtr subtype() const { return subtype_; }
|
|
||||||
TypeId generic_type_id() const override { return kObjectTypeRef; }
|
|
||||||
TypePtr DeepCopy() const override;
|
TypePtr DeepCopy() const override;
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
std::string DumpText() const override;
|
std::string DumpText() const override;
|
||||||
|
|
||||||
private:
|
|
||||||
TypePtr subtype_;
|
|
||||||
TypePtr subtype_origin_;
|
|
||||||
};
|
};
|
||||||
using RefTypePtr = std::shared_ptr<RefType>;
|
using RefTypePtr = std::shared_ptr<RefType>;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ir/dtype/tensor_type.h"
|
||||||
|
#include <string>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
TypePtr UndeterminedType::DeepCopy() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element_type_);
|
||||||
|
if (IsGeneric()) {
|
||||||
|
return std::make_shared<UndeterminedType>();
|
||||||
|
}
|
||||||
|
return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string UndeterminedType::ToReprString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "Undetermined";
|
||||||
|
}
|
||||||
|
return "Undetermined[" + element_type_->ToReprString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string UndeterminedType::ToString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "Undetermined";
|
||||||
|
}
|
||||||
|
return "Undetermined[" + element_type_->ToString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string UndeterminedType::DumpText() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "Undetermined";
|
||||||
|
}
|
||||||
|
return "Undetermined[" + element_type_->DumpText() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool UndeterminedType::operator==(const Type &other) const {
|
||||||
|
if (!IsSameObjectType(*this, other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
|
||||||
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
|
return true;
|
||||||
|
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *element_type_ == *other_elem_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr TensorType::DeepCopy() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element_type_);
|
||||||
|
if (IsGeneric()) {
|
||||||
|
return std::make_shared<TensorType>();
|
||||||
|
}
|
||||||
|
return std::make_shared<TensorType>(element_type_->DeepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TensorType::ToReprString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "tensor";
|
||||||
|
}
|
||||||
|
return "tensor[" + element_type_->ToReprString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TensorType::ToString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "Tensor";
|
||||||
|
}
|
||||||
|
return "Tensor[" + element_type_->ToString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TensorType::DumpText() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "Tensor";
|
||||||
|
}
|
||||||
|
return "Tensor(" + element_type_->DumpText() + ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TensorType::operator==(const Type &other) const {
|
||||||
|
if (!IsSameObjectType(*this, other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
|
||||||
|
// When element_type_ = nullptr, which means any type of Array.
|
||||||
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
|
return true;
|
||||||
|
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *element_type_ == *other_elem_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr RowTensorType::DeepCopy() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element_type_);
|
||||||
|
if (IsGeneric()) {
|
||||||
|
return std::make_shared<RowTensorType>();
|
||||||
|
}
|
||||||
|
return std::make_shared<RowTensorType>(element_type_->DeepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RowTensorType::ToReprString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "RowTensor";
|
||||||
|
}
|
||||||
|
return "RowTensor[" + element_type_->ToReprString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RowTensorType::ToString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "RowTensor";
|
||||||
|
}
|
||||||
|
return "RowTensor[" + element_type_->ToString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RowTensorType::DumpText() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "RowTensor";
|
||||||
|
}
|
||||||
|
return "RowTensor[" + element_type_->DumpText() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RowTensorType::operator==(const Type &other) const {
|
||||||
|
if (!IsSameObjectType(*this, other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
|
||||||
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
|
return true;
|
||||||
|
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *element_type_ == *other_elem_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr SparseTensorType::DeepCopy() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element_type_);
|
||||||
|
if (IsGeneric()) {
|
||||||
|
return std::make_shared<SparseTensorType>();
|
||||||
|
}
|
||||||
|
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::ToReprString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->ToReprString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::ToString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->ToString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::DumpText() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->DumpText() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SparseTensorType::operator==(const Type &other) const {
|
||||||
|
if (!IsSameObjectType(*this, other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
|
||||||
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
|
return true;
|
||||||
|
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *element_type_ == *other_elem_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,132 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
|
||||||
|
#define MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <iostream>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "base/base.h"
|
||||||
|
#include "ir/named.h"
|
||||||
|
#include "ir/dtype/type.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class UndeterminedType : public Object {
|
||||||
|
public:
|
||||||
|
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
|
||||||
|
explicit UndeterminedType(const TypePtr &ele)
|
||||||
|
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
|
||||||
|
~UndeterminedType() override = default;
|
||||||
|
MS_DECLARE_PARENT(UndeterminedType, Object)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
|
||||||
|
const TypePtr element() const { return element_type_; }
|
||||||
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
|
TypePtr DeepCopy() const override;
|
||||||
|
std::string ToString() const override;
|
||||||
|
std::string ToReprString() const override;
|
||||||
|
std::string DumpText() const override;
|
||||||
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TypePtr element_type_;
|
||||||
|
};
|
||||||
|
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
|
||||||
|
|
||||||
|
class TensorType : public Object {
|
||||||
|
public:
|
||||||
|
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
|
||||||
|
explicit TensorType(const TypePtr &ele)
|
||||||
|
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||||
|
~TensorType() override = default;
|
||||||
|
MS_DECLARE_PARENT(TensorType, Object)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
|
||||||
|
const TypePtr element() const { return element_type_; }
|
||||||
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
|
TypePtr DeepCopy() const override;
|
||||||
|
std::string ToString() const override;
|
||||||
|
std::string ToReprString() const override;
|
||||||
|
std::string DumpText() const override;
|
||||||
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TypePtr element_type_;
|
||||||
|
};
|
||||||
|
using TensorTypePtr = std::shared_ptr<TensorType>;
|
||||||
|
|
||||||
|
class RowTensorType : public Object {
|
||||||
|
public:
|
||||||
|
RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
|
||||||
|
explicit RowTensorType(const TypePtr &ele)
|
||||||
|
: Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||||
|
~RowTensorType() override = default;
|
||||||
|
MS_DECLARE_PARENT(RowTensorType, Object)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
|
||||||
|
const TypePtr element() const { return element_type_; }
|
||||||
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
|
TypePtr DeepCopy() const override;
|
||||||
|
std::string ToString() const override;
|
||||||
|
std::string ToReprString() const override;
|
||||||
|
std::string DumpText() const override;
|
||||||
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TypePtr element_type_;
|
||||||
|
};
|
||||||
|
using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
|
||||||
|
|
||||||
|
class SparseTensorType : public Object {
|
||||||
|
public:
|
||||||
|
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
|
||||||
|
explicit SparseTensorType(const TypePtr &ele)
|
||||||
|
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||||
|
~SparseTensorType() override = default;
|
||||||
|
MS_DECLARE_PARENT(SparseTensorType, Object)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
|
||||||
|
const TypePtr element() const { return element_type_; }
|
||||||
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
|
TypePtr DeepCopy() const override;
|
||||||
|
std::string ToString() const override;
|
||||||
|
std::string ToReprString() const override;
|
||||||
|
std::string DumpText() const override;
|
||||||
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TypePtr element_type_;
|
||||||
|
};
|
||||||
|
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_IR_DTYPE_TENSORTYPE_H_
|
|
@ -332,14 +332,11 @@ class FuncGraph : public FuncGraphBase {
|
||||||
const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; }
|
const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; }
|
||||||
void add_parameter_obj_node(const AnfNodePtr &p);
|
void add_parameter_obj_node(const AnfNodePtr &p);
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
|
|
||||||
|
|
||||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||||
std::vector<BaseShapePtr> joined_shapes_;
|
std::vector<BaseShapePtr> joined_shapes_;
|
||||||
std::unordered_map<std::string, FuncGraphTransform> transforms_;
|
std::unordered_map<std::string, FuncGraphTransform> transforms_;
|
||||||
// parameter default value
|
// parameter default value
|
||||||
std::map<std::string, AnfNodePtr> parameter_default_value_;
|
std::map<std::string, AnfNodePtr> parameter_default_value_;
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
|
|
||||||
size_t seen_;
|
size_t seen_;
|
||||||
|
|
||||||
std::list<CNodePtr> GetOrderedCnodes();
|
std::list<CNodePtr> GetOrderedCnodes();
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "base/base.h"
|
#include "base/base.h"
|
||||||
|
#include "ir/param_info.h"
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "utils/convert_utils_base.h"
|
#include "utils/convert_utils_base.h"
|
||||||
#include "utils/hashing.h"
|
#include "utils/hashing.h"
|
||||||
|
@ -163,6 +164,15 @@ class MetaTensor : public Value {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Get tensor's param_info info.
|
||||||
|
ParamInfoPtr param_info() const { return param_info_; }
|
||||||
|
bool is_parameter() const { return is_parameter_; }
|
||||||
|
|
||||||
|
// Set tensor's param_info info.
|
||||||
|
void set_param_info(const ParamInfoPtr ¶m_info) {
|
||||||
|
is_parameter_ = true;
|
||||||
|
param_info_ = param_info;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// brief Data type of the tensor.
|
// brief Data type of the tensor.
|
||||||
|
@ -184,6 +194,9 @@ class MetaTensor : public Value {
|
||||||
//
|
//
|
||||||
// Includes the format and data type of a tensor on device.
|
// Includes the format and data type of a tensor on device.
|
||||||
DeviceInfo device_info_;
|
DeviceInfo device_info_;
|
||||||
|
|
||||||
|
bool is_parameter_{false};
|
||||||
|
ParamInfoPtr param_info_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
using MetaTensorPtr = std::shared_ptr<MetaTensor>;
|
using MetaTensorPtr = std::shared_ptr<MetaTensor>;
|
||||||
|
|
|
@ -34,7 +34,16 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
|
||||||
}
|
}
|
||||||
auto tensor_shape = tens->shape();
|
auto tensor_shape = tens->shape();
|
||||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
||||||
abs_tensor->set_value(shared_from_base<MetaTensor>());
|
|
||||||
|
// if is parameter always no value.
|
||||||
|
if (is_parameter()) {
|
||||||
|
auto param_name = param_info()->name();
|
||||||
|
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||||
|
auto abs_ref_key = ref_key->ToAbstract();
|
||||||
|
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||||
|
} else {
|
||||||
|
abs_tensor->set_value(shared_from_base<MetaTensor>());
|
||||||
|
}
|
||||||
return abs_tensor;
|
return abs_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,21 @@ class Named : public Value {
|
||||||
};
|
};
|
||||||
using NamedPtr = std::shared_ptr<Named>;
|
using NamedPtr = std::shared_ptr<Named>;
|
||||||
|
|
||||||
|
struct NamedHasher {
|
||||||
|
std::size_t operator()(NamedPtr const &name) const {
|
||||||
|
std::size_t hash = name->Hash();
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NamedEqual {
|
||||||
|
bool operator()(NamedPtr const &t1, NamedPtr const &t2) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(t1);
|
||||||
|
MS_EXCEPTION_IF_NULL(t2);
|
||||||
|
return *t1 == *t2;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class None : public Named {
|
class None : public Named {
|
||||||
public:
|
public:
|
||||||
None() : Named("None") {}
|
None() : Named("None") {}
|
||||||
|
|
|
@ -21,10 +21,13 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "ir/anf.h"
|
|
||||||
#include "ir/tensor.h"
|
#include "ir/dtype.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
class ParamInfo;
|
||||||
|
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
|
||||||
|
|
||||||
class ParamInfo {
|
class ParamInfo {
|
||||||
public:
|
public:
|
||||||
ParamInfo() {}
|
ParamInfo() {}
|
||||||
|
@ -55,7 +58,7 @@ class ParamInfo {
|
||||||
int32_t cloned_index() const { return cloned_index_; }
|
int32_t cloned_index() const { return cloned_index_; }
|
||||||
|
|
||||||
// Make a cloned parameter and update clone info.
|
// Make a cloned parameter and update clone info.
|
||||||
ParamValuePtr Clone() {
|
ParamInfoPtr Clone() {
|
||||||
static std::atomic<int32_t> parameter_cloned_index{1};
|
static std::atomic<int32_t> parameter_cloned_index{1};
|
||||||
int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
|
int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
|
||||||
auto clone = std::make_shared<ParamInfo>(*this);
|
auto clone = std::make_shared<ParamInfo>(*this);
|
||||||
|
|
|
@ -467,6 +467,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtr Tensor::ToAbstract() {
|
abstract::AbstractBasePtr Tensor::ToAbstract() {
|
||||||
auto tens = shared_from_base<Tensor>();
|
auto tens = shared_from_base<Tensor>();
|
||||||
auto dtype = tens->Dtype();
|
auto dtype = tens->Dtype();
|
||||||
|
@ -475,7 +476,15 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
|
||||||
}
|
}
|
||||||
auto tensor_shape = tens->shape();
|
auto tensor_shape = tens->shape();
|
||||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
||||||
abs_tensor->set_value(shared_from_base<Tensor>());
|
// if is parameter always no value.
|
||||||
|
if (is_parameter()) {
|
||||||
|
auto param_name = param_info()->name();
|
||||||
|
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||||
|
auto abs_ref_key = ref_key->ToAbstract();
|
||||||
|
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||||
|
} else {
|
||||||
|
abs_tensor->set_value(shared_from_base<Tensor>());
|
||||||
|
}
|
||||||
return abs_tensor;
|
return abs_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -200,16 +200,6 @@ bool StringImm::operator==(const Value &other) const {
|
||||||
}
|
}
|
||||||
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
|
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
|
||||||
|
|
||||||
bool RefKey::operator==(const Value &other) const {
|
|
||||||
if (other.isa<RefKey>()) {
|
|
||||||
auto other_ = static_cast<const RefKey &>(other);
|
|
||||||
return *this == other_;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
|
|
||||||
|
|
||||||
bool AnyValue::operator==(const Value &other) const {
|
bool AnyValue::operator==(const Value &other) const {
|
||||||
if (other.isa<AnyValue>()) {
|
if (other.isa<AnyValue>()) {
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -224,28 +224,21 @@ using StringImmPtr = std::shared_ptr<StringImm>;
|
||||||
IMM_TRAITS(StringImmPtr, std::string)
|
IMM_TRAITS(StringImmPtr, std::string)
|
||||||
IMM_TRAITS(StringImmPtr, const char *)
|
IMM_TRAITS(StringImmPtr, const char *)
|
||||||
|
|
||||||
class RefKey : public Value {
|
class RefKey : public Named {
|
||||||
public:
|
public:
|
||||||
explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
|
explicit RefKey(const std::string &tag) : Named(tag) {}
|
||||||
|
|
||||||
~RefKey() override = default;
|
~RefKey() override = default;
|
||||||
MS_DECLARE_PARENT(RefKey, Value)
|
MS_DECLARE_PARENT(RefKey, Named)
|
||||||
std::size_t hash() const override { return hash_; }
|
const std::string &tag() const { return name(); }
|
||||||
const std::string &tag() const { return tag_; }
|
|
||||||
bool operator==(const Value &other) const override;
|
|
||||||
bool operator==(const RefKey &other) const;
|
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
|
std::string ToString() const override { return "RefKey[" + name() + "]"; }
|
||||||
|
|
||||||
std::string DumpText() const override {
|
std::string DumpText() const override {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "RefKey[\"" << tag_ << "\"]";
|
oss << "RefKey[\"" << name() << "\"]";
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
std::string tag_;
|
|
||||||
std::size_t hash_ = 0;
|
|
||||||
};
|
};
|
||||||
using RefKeyPtr = std::shared_ptr<RefKey>;
|
using RefKeyPtr = std::shared_ptr<RefKey>;
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,8 @@ if(BUILD_CONVERTER)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/tensor_type.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc
|
||||||
|
|
|
@ -29,6 +29,8 @@ set(ANF_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/tensor_type.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc
|
||||||
|
|
|
@ -23,7 +23,7 @@ from ...common import dtype as mstype
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
|
|
||||||
|
|
||||||
class Assign(PrimitiveWithInfer):
|
class Assign(Primitive):
|
||||||
"""
|
"""
|
||||||
Assign `Parameter` with a value.
|
Assign `Parameter` with a value.
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
import copy
|
import copy
|
||||||
from mindspore.common.api import _wrap_func
|
from mindspore.common.api import _wrap_func
|
||||||
from mindspore.common import Parameter
|
|
||||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||||
|
@ -410,16 +409,12 @@ def _run_op(obj, op_name, args):
|
||||||
if op_name == "Cast" or obj.update_parameter:
|
if op_name == "Cast" or obj.update_parameter:
|
||||||
cast_args = args
|
cast_args = args
|
||||||
else:
|
else:
|
||||||
cast_args = list()
|
cast_args = args
|
||||||
for arg in args:
|
for idx, arg in enumerate(args):
|
||||||
if isinstance(arg, Parameter):
|
cast_type = getattr(arg, "cast_type", None)
|
||||||
if arg.cast_type:
|
if cast_type:
|
||||||
cast_args.append(cast(arg, arg.cast_type))
|
cast_args[idx] = cast(arg, cast_type)
|
||||||
else:
|
output = real_run_op(obj, op_name, cast_args)
|
||||||
cast_args.append(arg)
|
|
||||||
else:
|
|
||||||
cast_args.append(arg)
|
|
||||||
output = real_run_op(obj, op_name, tuple(cast_args))
|
|
||||||
if not output:
|
if not output:
|
||||||
raise RuntimeError("Pynative run op %s failed!" % op_name)
|
raise RuntimeError("Pynative run op %s failed!" % op_name)
|
||||||
if len(output) == 1:
|
if len(output) == 1:
|
||||||
|
|
|
@ -118,26 +118,31 @@ class ControlMixedWhileIf(nn.Cell):
|
||||||
self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
|
self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
|
||||||
|
|
||||||
def construct(self, x, y, z, c2, c4):
|
def construct(self, x, y, z, c2, c4):
|
||||||
out = self.assign(self.var, c4)
|
out = c4
|
||||||
|
self.assign(self.var, c4)
|
||||||
while x < c2:
|
while x < c2:
|
||||||
y = self.assign(self.var, c4)
|
y = c4
|
||||||
|
self.assign(self.var, c4)
|
||||||
while y < c2 and x < c2:
|
while y < c2 and x < c2:
|
||||||
if 2 * y < c2:
|
if 2 * y < c2:
|
||||||
y = y + 2
|
y = y + 2
|
||||||
else:
|
else:
|
||||||
y = y + 1
|
y = y + 1
|
||||||
out = out + y
|
out = out + y
|
||||||
z = self.assign(self.var, c4)
|
z = c4
|
||||||
|
self.assign(self.var, c4)
|
||||||
while z < c2:
|
while z < c2:
|
||||||
z = z + 1
|
z = z + 1
|
||||||
out = out + z
|
out = out + z
|
||||||
x = x + 1
|
x = x + 1
|
||||||
out = out + x
|
out = out + x
|
||||||
while x < 2 * c2:
|
while x < 2 * c2:
|
||||||
y = self.assign(self.var, c4)
|
y = c4
|
||||||
|
self.assign(self.var, c4)
|
||||||
x = x + 1
|
x = x + 1
|
||||||
while y < c2:
|
while y < c2:
|
||||||
z = self.assign(self.var, c4)
|
z = c4
|
||||||
|
self.assign(self.var, c4)
|
||||||
while z < c2:
|
while z < c2:
|
||||||
z = z + 1
|
z = z + 1
|
||||||
if x < c2:
|
if x < c2:
|
||||||
|
|
|
@ -27,6 +27,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common.api import ms_function, _executor
|
from mindspore.common.api import ms_function, _executor
|
||||||
from mindspore.ops._grad.grad_base import bprop_getters
|
from mindspore.ops._grad.grad_base import bprop_getters
|
||||||
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
|
||||||
|
@ -254,3 +255,60 @@ def test_bprop_with_wrong_output_shape():
|
||||||
net = BpropWithWrongOutputShapeCell()
|
net = BpropWithWrongOutputShapeCell()
|
||||||
net.set_grad()
|
net.set_grad()
|
||||||
grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
|
grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
|
||||||
|
|
||||||
|
class AssignWhenInsertGrad(nn.Cell):
|
||||||
|
""" NetWithNDarray definition """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(AssignWhenInsertGrad, self).__init__()
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
|
||||||
|
self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
|
||||||
|
self.freq = Tensor(278, ms.int32)
|
||||||
|
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||||
|
|
||||||
|
def save_gradient(self, dout):
|
||||||
|
self.cov_step = self.cov_step + self.freq
|
||||||
|
return dout
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
self.gather(self.damping, self.cov_step, 0)
|
||||||
|
out = P.ReLU()(x)
|
||||||
|
out = self.getG(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
grad_all = C.GradOperation('get_all', get_all=True)
|
||||||
|
|
||||||
|
class GradNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
|
||||||
|
def construct(self, *inputs):
|
||||||
|
out = self.net(*inputs)
|
||||||
|
return out, grad_all(self.net)(*inputs)
|
||||||
|
|
||||||
|
def test_assign_in_insert_grad():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = AssignWhenInsertGrad().to_float(ms.float16)
|
||||||
|
input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
|
||||||
|
net_back = GradNet(net)
|
||||||
|
net_back(ms.Tensor(input_data))
|
||||||
|
|
||||||
|
class Assign(nn.Cell):
|
||||||
|
""" NetWithNDarray definition """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Assign, self).__init__()
|
||||||
|
self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
self.cov_step = self.cov_step + x
|
||||||
|
return self.cov_step
|
||||||
|
|
||||||
|
def test_assign():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Assign()
|
||||||
|
input_data = ms.Tensor(np.array(1).astype(np.int32))
|
||||||
|
net_back = GradNet(net)
|
||||||
|
net_back(input_data)
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test_cont_break """
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import Tensor, context, nn, ms_function
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class WhileSubGraphParam(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.update = ms.Parameter(Tensor(1, ms.float32), "update")
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
out1 = z
|
||||||
|
while x < y:
|
||||||
|
self.update = self.update + 1
|
||||||
|
out1 = out1 + 1
|
||||||
|
x = x + 1
|
||||||
|
return out1, self.update
|
||||||
|
|
||||||
|
|
||||||
|
def test_while_loop_phi():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
x = Tensor(0, ms.float32)
|
||||||
|
y = Tensor(10, ms.float32)
|
||||||
|
z = Tensor(100, ms.float32)
|
||||||
|
|
||||||
|
net = WhileSubGraphParam()
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
class WhileSubGraphParam2(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.update = ms.Parameter(Tensor(1, ms.float32), "update")
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
out1 = z
|
||||||
|
i = self.update
|
||||||
|
while x < y:
|
||||||
|
i = i + 1
|
||||||
|
out1 = out1 + 1
|
||||||
|
x = x + 1
|
||||||
|
return out1, self.update
|
||||||
|
|
||||||
|
|
||||||
|
def test_while_loop_phi_2():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
x = Tensor(0, ms.float32)
|
||||||
|
y = Tensor(10, ms.float32)
|
||||||
|
z = Tensor(100, ms.float32)
|
||||||
|
|
||||||
|
net = WhileSubGraphParam2()
|
||||||
|
net(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
class WhileSubGraphParam3(Cell):
|
||||||
|
def __init__(self, initial_input_x):
|
||||||
|
super().__init__()
|
||||||
|
self.initial_input_x = initial_input_x
|
||||||
|
self.X = ms.Parameter(initial_input_x, name="parameter_x")
|
||||||
|
self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
a = 0
|
||||||
|
while a < 3:
|
||||||
|
self.X = self.X + self.Y
|
||||||
|
a += 1
|
||||||
|
return self.X
|
||||||
|
|
||||||
|
|
||||||
|
def test_while_loop_phi_3():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
x = Tensor(0, ms.float32)
|
||||||
|
|
||||||
|
net = WhileSubGraphParam3(x)
|
||||||
|
net()
|
||||||
|
|
||||||
|
class ControlMixedWhileIf(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.assign = P.Assign()
|
||||||
|
self.var = ms.Parameter(ms.Tensor([1], ms.float32), name="var")
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def construct(self, x, y, z, c2, c4):
|
||||||
|
out = self.assign(self.var, c4)
|
||||||
|
while x < c2:
|
||||||
|
y = self.assign(self.var, c4)
|
||||||
|
while y < c2 and x < c2:
|
||||||
|
if 2 * y < c2:
|
||||||
|
y = y + 2
|
||||||
|
else:
|
||||||
|
y = y + 1
|
||||||
|
out = out + y
|
||||||
|
z = self.assign(self.var, c4)
|
||||||
|
while z < c2:
|
||||||
|
z = z + 1
|
||||||
|
out = out + z
|
||||||
|
x = x + 1
|
||||||
|
out = out + x
|
||||||
|
while x < 2 * c2:
|
||||||
|
y = self.assign(self.var, c4)
|
||||||
|
x = x + 1
|
||||||
|
while y < c2:
|
||||||
|
z = self.assign(self.var, c4)
|
||||||
|
while z < c2:
|
||||||
|
z = z + 1
|
||||||
|
if x < c2:
|
||||||
|
y = y - 1
|
||||||
|
else:
|
||||||
|
y = y + 1
|
||||||
|
out = out + z
|
||||||
|
out = out + y
|
||||||
|
out = out + x
|
||||||
|
return out
|
||||||
|
|
||||||
|
def test_mixed_while_if():
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE)
|
||||||
|
x = np.array(2).astype(np.int32)
|
||||||
|
y = np.array(14).astype(np.int32)
|
||||||
|
z = np.array(1).astype(np.int32)
|
||||||
|
c2 = Tensor([14], ms.int32)
|
||||||
|
c4 = Tensor([0], ms.int32)
|
||||||
|
net = ControlMixedWhileIf()
|
||||||
|
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
|
||||||
|
expect = np.array(3318).astype(np.int32)
|
||||||
|
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
|
@ -22,7 +22,13 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
|
||||||
from .vm_interface import vm
|
from .vm_interface import vm
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
|
@vm_impl_getters.register(P.Assign)
|
||||||
|
def vm_impl_assign(self):
|
||||||
|
"""Generate vm_impl function for Assign"""
|
||||||
|
def vm_impl(x, value):
|
||||||
|
x.assign_value(value)
|
||||||
|
return x
|
||||||
|
return vm_impl
|
||||||
|
|
||||||
@vm_impl_getters.register(P.ExpandDims)
|
@vm_impl_getters.register(P.ExpandDims)
|
||||||
def vm_impl_expand_dims(self):
|
def vm_impl_expand_dims(self):
|
||||||
|
|
Loading…
Reference in New Issue