forked from mindspore-Ecosystem/mindspore
implicit type conversion
Signed-off-by: candanzg <zhangshucheng@huawei.com>
This commit is contained in:
parent
4ce1cf4529
commit
2429da19fb
|
@ -33,6 +33,9 @@ namespace mindspore {
|
|||
namespace prim {
|
||||
namespace {
|
||||
using PatternListType = std::initializer_list<BaseRef>;
|
||||
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3},
|
||||
{kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6},
|
||||
{kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}};
|
||||
|
||||
const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
||||
static const auto empty = std::vector<Signature>();
|
||||
|
@ -44,6 +47,16 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
|||
return empty;
|
||||
}
|
||||
|
||||
const std::string GetOpName(const ValuePtr &function) {
|
||||
std::string name = "";
|
||||
if (function->isa<Primitive>()) {
|
||||
name = function->cast<PrimitivePyPtr>()->name();
|
||||
} else if (function->isa<MetaFuncGraph>()) {
|
||||
name = function->cast<MetaFuncGraphPtr>()->name();
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
||||
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) {
|
||||
std::size_t sig_size = signature.size();
|
||||
|
@ -62,10 +75,89 @@ void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &arg
|
|||
}
|
||||
}
|
||||
}
|
||||
bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_number, const TypeId &scalar_type,
|
||||
const size_t &s_type_number) {
|
||||
if (scalar_type == kNumberTypeFloat16 || scalar_type == kNumberTypeFloat32 || scalar_type == kNumberTypeFloat64) {
|
||||
if (tensor_type == kNumberTypeFloat16 || tensor_type == kNumberTypeFloat32 || tensor_type == kNumberTypeFloat64) {
|
||||
return t_type_number >= s_type_number;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
||||
const size_t type_number) {
|
||||
*max_type_id = type_id;
|
||||
*max_type = type;
|
||||
*max_type_number = type_number;
|
||||
}
|
||||
|
||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) {
|
||||
TypeId max_type_id = kTypeUnknown;
|
||||
TypeId max_type = kTypeUnknown;
|
||||
size_t max_type_number = 0;
|
||||
bool has_int8 = false;
|
||||
for (const auto &index : indexs) {
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
TypeId arg_type = kTypeUnknown;
|
||||
AbstractBasePtr arg_value = args_spec_list[index];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
arg_type_id = tensor_type->type_id();
|
||||
arg_type = kObjectTypeTensorType;
|
||||
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
arg_type_id = scalar_type->type_id();
|
||||
arg_type = kObjectTypeNumber;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
auto it = type_map.find(arg_type_id);
|
||||
if (it == type_map.end()) {
|
||||
continue;
|
||||
}
|
||||
if (arg_type_id == kNumberTypeInt8) {
|
||||
has_int8 = true;
|
||||
}
|
||||
if (max_type_id == kTypeUnknown) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (max_type == arg_type) {
|
||||
if (it->second > max_type_number) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
} else {
|
||||
if (arg_type == kObjectTypeTensorType) {
|
||||
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
} else {
|
||||
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (max_type_id == kNumberTypeUInt8 && has_int8 == true) {
|
||||
max_type_id = kNumberTypeInt16;
|
||||
}
|
||||
return max_type_id;
|
||||
}
|
||||
|
||||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes,
|
||||
const abstract::AbstractBasePtrList &args_spec_list) {
|
||||
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||
const abstract::AbstractBasePtrList &args_spec_list) {
|
||||
// record index for signature.dtypes of the same type
|
||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||
|
@ -77,10 +169,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
|
|||
it->second.push_back(i);
|
||||
}
|
||||
}
|
||||
// example:sig_dtype:[T, T1, T, T2, T, T1, T3, T4, T4]
|
||||
// and args type: [int, Tensor, Tensor, float, Tensor, int, Tensor, int, float]
|
||||
// result:{{T:2},{T1:1}}
|
||||
std::map<SignatureEnumDType, size_t> dst_type;
|
||||
std::map<SignatureEnumDType, TypeId> dst_type;
|
||||
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
||||
auto type = it->first;
|
||||
auto indexs = it->second;
|
||||
|
@ -88,36 +177,36 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
|
|||
if (indexs.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool has_tensor = false;
|
||||
for (const auto &index : indexs) {
|
||||
AbstractBasePtr arg_value = args_spec_list[index];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
(void)dst_type.insert(std::make_pair(type, index));
|
||||
has_tensor = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!has_tensor) {
|
||||
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
||||
continue;
|
||||
}
|
||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs)));
|
||||
}
|
||||
return dst_type;
|
||||
}
|
||||
|
||||
AnfNodePtr DoCast(const AnfNodePtr ¶m, const AnfNodePtr &source_param, const FuncGraphPtr &graph) {
|
||||
// op and module import path
|
||||
auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional");
|
||||
MS_EXCEPTION_IF_NULL(prim_dtype);
|
||||
// op and module import path
|
||||
AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGraphPtr &graph) {
|
||||
auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations");
|
||||
MS_EXCEPTION_IF_NULL(prim_cast_class);
|
||||
auto dtype_node = NewCNode({NewValueNode(prim_dtype), source_param}, graph);
|
||||
auto dtype_node = NewValueNode(TypeIdToType(type_id));
|
||||
auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph);
|
||||
return NewCNode({cast_node, param, dtype_node}, graph);
|
||||
}
|
||||
|
||||
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) {
|
||||
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs, const std::set<size_t> &write_indexs) {
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature &sig) { return sig.dtype; });
|
||||
|
@ -126,33 +215,49 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
|||
return;
|
||||
}
|
||||
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
||||
std::map<SignatureEnumDType, size_t> dst_type = GetMaxDtypeIndex(dtypes, args_spec_list);
|
||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list);
|
||||
// Identify which arg requires auto cast
|
||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
AbstractBasePtr arg_value = args_spec_list[i];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it == dst_type.end() || it->second == i || !arg_value->isa<abstract::AbstractScalar>()) {
|
||||
continue;
|
||||
}
|
||||
// When scalar is of bool type, the type of tensor must also be of bool type,
|
||||
// otherwise the cast operator will not be added.
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
if (scalar_type->type_id() == kNumberTypeBool) {
|
||||
auto tensor = args_spec_list[it->second]->cast<abstract::AbstractTensorPtr>();
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
if (tensor_type->type_id() != kNumberTypeBool) {
|
||||
continue;
|
||||
}
|
||||
arg_type_id = tensor_type->type_id();
|
||||
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
arg_type_id = scalar_type->type_id();
|
||||
}
|
||||
// get source node for cast
|
||||
AnfNodePtr source_node = (*op_inputs)[it->second + 1];
|
||||
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], source_node, graph);
|
||||
auto it_map = type_map.find(arg_type_id);
|
||||
if (it_map == type_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto rw_it = write_indexs.find(i);
|
||||
if (rw_it != write_indexs.end()) {
|
||||
if (arg_type_id != it->second) {
|
||||
MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i]
|
||||
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
|
||||
<< TypeIdLabel(it->second) << "' automatically.";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
||||
continue;
|
||||
}
|
||||
if ((arg_type_id == kNumberTypeBool || it->second == kNumberTypeBool) && arg_type_id != it->second) {
|
||||
continue;
|
||||
}
|
||||
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,10 +278,10 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
}
|
||||
}
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
std::set<size_t> write_indexs;
|
||||
op_inputs.push_back(NewValueNode(function));
|
||||
// Assume, the write input of op is always the first input. We check if any write op,
|
||||
// and add cast op on other inputs to keep the same type with assigned parameter.
|
||||
AnfNodePtr assign_source = nullptr;
|
||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||
AnfNodePtr param = params_list[i];
|
||||
SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
|
||||
|
@ -191,22 +296,18 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
if (sig == SignatureEnumRW::kRWRead) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||
assign_source = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
||||
write_indexs.insert(i);
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
|
||||
}
|
||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
||||
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
|
||||
}
|
||||
// add cast op here
|
||||
if (assign_source != nullptr && sig != SignatureEnumRW::kRWWrite) {
|
||||
param = DoCast(param, assign_source, func_graph);
|
||||
}
|
||||
op_inputs.push_back(param);
|
||||
}
|
||||
// process default
|
||||
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
||||
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs);
|
||||
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
||||
return func_graph->NewCNode(op_inputs);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -321,8 +321,8 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
|
||||
|
||||
Returns:
|
||||
Union[Tensor, Initialized], When `init` is Tensor, the return is Tensor object,
|
||||
otherwise the return is Initialize object.
|
||||
Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
|
||||
otherwise the return is Initialize object.
|
||||
|
||||
Examples:
|
||||
>>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Parameter for cell."""
|
||||
import numbers
|
||||
from copy import copy, deepcopy
|
||||
from . import dtype as mstype
|
||||
from .initializer import initializer, Initializer
|
||||
from .tensor import Tensor, MetaTensor
|
||||
from .._checkparam import _check_str_by_regular
|
||||
|
@ -199,6 +200,10 @@ class Parameter:
|
|||
elif isinstance(data, Initializer):
|
||||
self.init_mode = data
|
||||
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
|
||||
elif isinstance(data, int):
|
||||
data = Tensor(data, dtype=mstype.int32)
|
||||
elif isinstance(data, float):
|
||||
data = Tensor(data, dtype=mstype.float32)
|
||||
else:
|
||||
data = Tensor(data)
|
||||
data.init_flag = False
|
||||
|
|
|
@ -145,8 +145,8 @@ class AssignAdd(PrimitiveWithInfer):
|
|||
>>> net(value)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -189,8 +189,8 @@ class AssignSub(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -24,6 +24,7 @@ import numpy as np
|
|||
from ... import context
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
@ -1489,11 +1490,13 @@ class ApplyMomentum(PrimitiveWithInfer):
|
|||
Please refer to the usage in nn.ApplyMomentum.
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T),
|
||||
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
@prim_attr_register
|
||||
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Other operators."""
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
@ -46,8 +47,8 @@ class Assign(PrimitiveWithInfer):
|
|||
>>> net(x)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
|
||||
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""multitype_ops directory test case"""
|
||||
import numpy as np
|
||||
from functools import partial, reduce
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
import mindspore.context as context
|
||||
import pytest
|
||||
|
||||
class TensorIntAutoCast(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(TensorIntAutoCast, self).__init__()
|
||||
self.i = 2
|
||||
def construct(self, t):
|
||||
z = F.tensor_mul(t, self.i)
|
||||
return z
|
||||
|
||||
|
||||
class TensorFPAutoCast(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(TensorFPAutoCast, self).__init__()
|
||||
self.f = 1.2
|
||||
def construct(self, t):
|
||||
z = F.tensor_mul(t, self.f)
|
||||
return z
|
||||
|
||||
|
||||
class TensorBoolAutoCast(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(TensorBoolAutoCast, self).__init__()
|
||||
self.f = True
|
||||
def construct(self, t):
|
||||
z = F.tensor_mul(t, self.f)
|
||||
return z
|
||||
|
||||
class TensorAutoCast(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(TensorAutoCast, self).__init__()
|
||||
def construct(self, t1, t2):
|
||||
z = F.tensor_mul(t1, t2)
|
||||
return z
|
||||
|
||||
|
||||
def test_tensor_auto_cast():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
t0 = Tensor([True, False], mstype.bool_)
|
||||
t_uint8 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint8)
|
||||
t_int8 = Tensor(np.ones([2, 1, 2, 2]), mstype.int8)
|
||||
t_int16 = Tensor(np.ones([2, 1, 2, 2]), mstype.int16)
|
||||
t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
|
||||
t_int64 = Tensor(np.ones([2, 1, 2, 2]), mstype.int64)
|
||||
t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
|
||||
t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
|
||||
t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64)
|
||||
net = TensorAutoCast()
|
||||
rs = net(t_uint8, t_int8)
|
||||
assert rs.dtype() == mstype.int16
|
||||
rs = net(t_uint8, t_int16)
|
||||
assert rs.dtype() == mstype.int16
|
||||
rs = net(t_uint8, t_int32)
|
||||
assert rs.dtype() == mstype.int32
|
||||
rs = net(t_uint8, t_int64)
|
||||
assert rs.dtype() == mstype.int64
|
||||
rs = net(t_int8, t_int16)
|
||||
assert rs.dtype() == mstype.int16
|
||||
rs = net(t_int8, t_int32)
|
||||
assert rs.dtype() == mstype.int32
|
||||
rs = net(t_int8, t_int64)
|
||||
assert rs.dtype() == mstype.int64
|
||||
rs = net(t_int16, t_int32)
|
||||
assert rs.dtype() == mstype.int32
|
||||
rs = net(t_int16, t_int64)
|
||||
assert rs.dtype() == mstype.int64
|
||||
rs = net(t_int32, t_int64)
|
||||
assert rs.dtype() == mstype.int64
|
||||
|
||||
rs = net(t_fp16, t_fp32)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = net(t_fp16, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
rs = net(t_fp32, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
|
||||
rs = net(t_uint8, t_fp16)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = net(t_uint8, t_fp32)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = net(t_uint8, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
rs = net(t_int8, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
rs = net(t_int16, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
rs = net(t_int32, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
rs = net(t_int64, t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
|
||||
rs = net(t_fp16, t_int8)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = net(t_fp16, t_uint8)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = net(t_fp16, t_int16)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = net(t_fp16, t_int32)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = net(t_fp16, t_int64)
|
||||
assert rs.dtype() == mstype.float16
|
||||
|
||||
tint = TensorIntAutoCast()
|
||||
rs = tint(t_uint8)
|
||||
assert rs.dtype() == mstype.uint8
|
||||
rs = tint(t_int8)
|
||||
assert rs.dtype() == mstype.int8
|
||||
rs = tint(t_int16)
|
||||
assert rs.dtype() == mstype.int16
|
||||
rs = tint(t_int32)
|
||||
assert rs.dtype() == mstype.int32
|
||||
rs = tint(t_int64)
|
||||
assert rs.dtype() == mstype.int64
|
||||
rs = tint(t_fp16)
|
||||
assert rs.dtype() == mstype.float16
|
||||
rs = tint(t_fp32)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tint(t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
tfp = TensorFPAutoCast()
|
||||
rs = tfp(t_uint8)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_int8)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_int16)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_int32)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_int64)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_fp16)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_fp32)
|
||||
assert rs.dtype() == mstype.float32
|
||||
rs = tfp(t_fp64)
|
||||
assert rs.dtype() == mstype.float64
|
||||
|
||||
t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16)
|
||||
t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32)
|
||||
t_uint64 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_uint8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_int8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_int16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_int32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_int64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_uint8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_int8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_int16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_int32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_int64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_uint8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_int8)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_int16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_int32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_int64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_fp16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_fp32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint16, t_fp64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_fp16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_fp32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint32, t_fp64)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_fp16)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_fp32)
|
||||
with pytest.raises(TypeError):
|
||||
net(t_uint64, t_fp64)
|
||||
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
tfp(t_uint16)
|
||||
with pytest.raises(TypeError):
|
||||
tfp(t_uint32)
|
||||
with pytest.raises(TypeError):
|
||||
tfp(t_uint64)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
tint(t_uint16)
|
||||
with pytest.raises(TypeError):
|
||||
tint(t_uint32)
|
||||
with pytest.raises(TypeError):
|
||||
tint(t_uint64)
|
||||
|
||||
bnet = TensorBoolAutoCast()
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_uint8)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_int8)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_int16)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_int32)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_int64)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_fp16)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_fp32)
|
||||
with pytest.raises(TypeError):
|
||||
bnet(t_fp64)
|
|
@ -64,7 +64,7 @@ def test_parameter_update_int32_and_tensor():
|
|||
param_step = train_network.parameters_dict()['global_step']
|
||||
update_global_step = ParameterUpdate(param_step)
|
||||
|
||||
input_step = Tensor(np.array([1000]), mstype.float32)
|
||||
input_step = Tensor(np.array([1000]), mstype.int32)
|
||||
_executor.compile(update_global_step, input_step)
|
||||
|
||||
|
||||
|
|
|
@ -463,7 +463,7 @@ raise_set = [
|
|||
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
|
||||
'desc_inputs': [0]}),
|
||||
('AssignAdd_Error', {
|
||||
'block': (P.AssignAdd(), {'exception': TypeError}),
|
||||
'block': (P.AssignAdd(), {'exception': IndexError}),
|
||||
'desc_inputs': [[1]]}),
|
||||
]
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test ops """
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
@ -22,7 +23,8 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.ops import operations as P
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception, \
|
||||
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
class AssignAddNet(nn.Cell):
|
||||
|
@ -77,11 +79,6 @@ class CumSumNet(nn.Cell):
|
|||
|
||||
|
||||
raise_set = [
|
||||
# input two tensors, but element types are not same
|
||||
('TensorAdd1', {
|
||||
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('TensorAdd2', {
|
||||
'block': (P.TensorAdd(), {'exception': ValueError, 'error_keywords': ['TensorAdd']}),
|
||||
|
@ -256,22 +253,12 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Sub1', {
|
||||
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Sub2', {
|
||||
'block': (P.Sub(), {'exception': ValueError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Mul1', {
|
||||
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Mul2', {
|
||||
'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}),
|
||||
|
@ -327,55 +314,30 @@ raise_set = [
|
|||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Minimum1', {
|
||||
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Minimum2', {
|
||||
'block': (P.Minimum(), {'exception': ValueError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Maximum1', {
|
||||
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Maximum2', {
|
||||
'block': (P.Maximum(), {'exception': ValueError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('RealDiv1', {
|
||||
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('RealDiv2', {
|
||||
'block': (P.RealDiv(), {'exception': ValueError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Div1', {
|
||||
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Div2', {
|
||||
'block': (P.Div(), {'exception': ValueError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('FloorDiv1', {
|
||||
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('FloorDiv2', {
|
||||
'block': (P.FloorDiv(), {'exception': ValueError, 'error_keywords': ['FloorDiv']}),
|
||||
|
@ -389,11 +351,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('FloorMod1', {
|
||||
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('FFloorMod2', {
|
||||
'block': (P.FloorMod(), {'exception': ValueError, 'error_keywords': ['FloorMod']}),
|
||||
|
@ -407,11 +364,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('Equal1', {
|
||||
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Equal2', {
|
||||
'block': (P.Equal(), {'exception': ValueError, 'error_keywords': ['Equal']}),
|
||||
|
@ -430,55 +382,30 @@ raise_set = [
|
|||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
|
||||
# type of x and y not match
|
||||
('NotEqual1', {
|
||||
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('NotEqual2', {
|
||||
'block': (P.NotEqual(), {'exception': ValueError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('Greater1', {
|
||||
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Greater2', {
|
||||
'block': (P.Greater(), {'exception': ValueError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('GreaterEqual1', {
|
||||
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('GreaterEqual2', {
|
||||
'block': (P.GreaterEqual(), {'exception': ValueError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('Less1', {
|
||||
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Less2', {
|
||||
'block': (P.Less(), {'exception': ValueError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('LessEqual1', {
|
||||
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('LessEqual2', {
|
||||
'block': (P.LessEqual(), {'exception': ValueError, 'error_keywords': ['LessEqual']}),
|
||||
|
@ -643,11 +570,6 @@ raise_set = [
|
|||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input two tensors, but element types are not same
|
||||
('Atan21', {
|
||||
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Atan22', {
|
||||
'block': (P.Atan2(), {'exception': ValueError, 'error_keywords': ['Atan2']}),
|
||||
|
@ -655,7 +577,96 @@ raise_set = [
|
|||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
test_case_math_ops = [
|
||||
# input two tensors, but element types are not same
|
||||
('TensorAdd1', {
|
||||
'block': P.TensorAdd(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Sub1', {
|
||||
'block': P.Sub(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Mul1', {
|
||||
'block': P.Mul(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Minimum1', {
|
||||
'block': P.Minimum(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Maximum1', {
|
||||
'block': P.Maximum(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('RealDiv1', {
|
||||
'block': P.RealDiv(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Div1', {
|
||||
'block': P.Div(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorDiv1', {
|
||||
'block': P.FloorDiv(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorMod1', {
|
||||
'block': P.FloorMod(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Equal1', {
|
||||
'block': P.Equal(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('NotEqual1', {
|
||||
'block': P.NotEqual(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Greater1', {
|
||||
'block': P.Greater(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('GreaterEqual1', {
|
||||
'block': P.GreaterEqual(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Less1', {
|
||||
'block': P.Less(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('LessEqual1', {
|
||||
'block': P.LessEqual(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Atan21', {
|
||||
'block': P.Atan2(),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||
def test_check_exception():
|
||||
return raise_set
|
||||
|
||||
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||
def test_exec():
|
||||
import mindspore.context as context
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
return functools.reduce(lambda x, y: x + y, [test_case_math_ops])
|
||||
|
|
Loading…
Reference in New Issue