!49621 fix ops.eye bug

Merge pull request !49621 from zlq2020/master
This commit is contained in:
i-robot 2023-03-07 10:52:37 +00:00 committed by Gitee
commit a9b46ffb98
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 20 additions and 14 deletions

View File

@ -36,6 +36,7 @@ mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
mindspore/mindspore/python/mindspore/common/parameter.py:set_data

View File

@ -63,20 +63,6 @@ abstract::ShapePtr EyeInferShape(const PrimitivePtr &primitive, const std::vecto
return std::make_shared<abstract::Shape>(state_shape);
}
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto dtype_value = input_args[2]->BuildValue();
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
}
auto output_type = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
return dtype_ret;
}
void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
if (!input_args[0]->isa<abstract::AbstractScalar>()) {
MS_EXCEPTION(TypeError) << "For Eye, 'n' must be int, but got AnyValue!";
@ -86,6 +72,8 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
}
auto n_ptr_ = input_args[0]->BuildValue();
auto m_ptr_ = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(n_ptr_);
MS_EXCEPTION_IF_NULL(m_ptr_);
if (!n_ptr_->isa<Int64Imm>() && !n_ptr_->isa<Int32Imm>()) {
MS_EXCEPTION(TypeError) << "For Eye, the dtype of n is invalid!";
}
@ -93,11 +81,27 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION(TypeError) << "For Eye, the dtype of m is invalid!";
}
auto dtype_value_c = input_args[2]->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value_c);
if (!dtype_value_c->isa<Type>()) {
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
}
}
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
EyeCheck(input_args);
auto dtype_value = input_args[2]->BuildValue();
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
}
auto output_type = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
return dtype_ret;
}
ValuePtr EyeInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();

View File

@ -9463,6 +9463,7 @@ def einsum(equation, *operands):
>>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
>>> y = mindspore.Tensor([1, 2], mindspore.float32)
>>> output = ops.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
>>> print(output)
[[1. 2.]
[2. 4.]
[3. 6.]