forked from mindspore-Ecosystem/mindspore
commit
a9b46ffb98
|
@ -36,6 +36,7 @@ mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d
|
||||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
||||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
||||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
|
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
|
||||||
|
mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum
|
||||||
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
||||||
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
||||||
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
||||||
|
|
|
@ -63,20 +63,6 @@ abstract::ShapePtr EyeInferShape(const PrimitivePtr &primitive, const std::vecto
|
||||||
return std::make_shared<abstract::Shape>(state_shape);
|
return std::make_shared<abstract::Shape>(state_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
auto prim_name = prim->name();
|
|
||||||
auto dtype_value = input_args[2]->BuildValue();
|
|
||||||
if (!dtype_value->isa<Type>()) {
|
|
||||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
|
||||||
}
|
|
||||||
auto output_type = dtype_value->cast<TypePtr>();
|
|
||||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
|
||||||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
|
|
||||||
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
|
||||||
return dtype_ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
||||||
if (!input_args[0]->isa<abstract::AbstractScalar>()) {
|
if (!input_args[0]->isa<abstract::AbstractScalar>()) {
|
||||||
MS_EXCEPTION(TypeError) << "For Eye, 'n' must be int, but got AnyValue!";
|
MS_EXCEPTION(TypeError) << "For Eye, 'n' must be int, but got AnyValue!";
|
||||||
|
@ -86,6 +72,8 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
||||||
}
|
}
|
||||||
auto n_ptr_ = input_args[0]->BuildValue();
|
auto n_ptr_ = input_args[0]->BuildValue();
|
||||||
auto m_ptr_ = input_args[1]->BuildValue();
|
auto m_ptr_ = input_args[1]->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(n_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(m_ptr_);
|
||||||
if (!n_ptr_->isa<Int64Imm>() && !n_ptr_->isa<Int32Imm>()) {
|
if (!n_ptr_->isa<Int64Imm>() && !n_ptr_->isa<Int32Imm>()) {
|
||||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of n is invalid!";
|
MS_EXCEPTION(TypeError) << "For Eye, the dtype of n is invalid!";
|
||||||
}
|
}
|
||||||
|
@ -93,11 +81,27 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of m is invalid!";
|
MS_EXCEPTION(TypeError) << "For Eye, the dtype of m is invalid!";
|
||||||
}
|
}
|
||||||
auto dtype_value_c = input_args[2]->BuildValue();
|
auto dtype_value_c = input_args[2]->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(dtype_value_c);
|
||||||
if (!dtype_value_c->isa<Type>()) {
|
if (!dtype_value_c->isa<Type>()) {
|
||||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
EyeCheck(input_args);
|
||||||
|
auto dtype_value = input_args[2]->BuildValue();
|
||||||
|
if (!dtype_value->isa<Type>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
||||||
|
}
|
||||||
|
auto output_type = dtype_value->cast<TypePtr>();
|
||||||
|
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
||||||
|
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
|
||||||
|
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
||||||
|
return dtype_ret;
|
||||||
|
}
|
||||||
|
|
||||||
ValuePtr EyeInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
ValuePtr EyeInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto prim_name = prim->name();
|
auto prim_name = prim->name();
|
||||||
|
|
|
@ -9463,6 +9463,7 @@ def einsum(equation, *operands):
|
||||||
>>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
|
>>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
|
||||||
>>> y = mindspore.Tensor([1, 2], mindspore.float32)
|
>>> y = mindspore.Tensor([1, 2], mindspore.float32)
|
||||||
>>> output = ops.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
|
>>> output = ops.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
|
||||||
|
>>> print(output)
|
||||||
[[1. 2.]
|
[[1. 2.]
|
||||||
[2. 4.]
|
[2. 4.]
|
||||||
[3. 6.]
|
[3. 6.]
|
||||||
|
|
Loading…
Reference in New Issue