commit
a9b46ffb98
|
@ -36,6 +36,7 @@ mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d
|
|||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
|
||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum
|
||||
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
||||
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
||||
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
||||
|
|
|
@ -63,20 +63,6 @@ abstract::ShapePtr EyeInferShape(const PrimitivePtr &primitive, const std::vecto
|
|||
return std::make_shared<abstract::Shape>(state_shape);
|
||||
}
|
||||
|
||||
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
auto dtype_value = input_args[2]->BuildValue();
|
||||
if (!dtype_value->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
||||
}
|
||||
auto output_type = dtype_value->cast<TypePtr>();
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
||||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
|
||||
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
||||
return dtype_ret;
|
||||
}
|
||||
|
||||
void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (!input_args[0]->isa<abstract::AbstractScalar>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Eye, 'n' must be int, but got AnyValue!";
|
||||
|
@ -86,6 +72,8 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
|||
}
|
||||
auto n_ptr_ = input_args[0]->BuildValue();
|
||||
auto m_ptr_ = input_args[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(n_ptr_);
|
||||
MS_EXCEPTION_IF_NULL(m_ptr_);
|
||||
if (!n_ptr_->isa<Int64Imm>() && !n_ptr_->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of n is invalid!";
|
||||
}
|
||||
|
@ -93,11 +81,27 @@ void EyeCheck(const std::vector<AbstractBasePtr> &input_args) {
|
|||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of m is invalid!";
|
||||
}
|
||||
auto dtype_value_c = input_args[2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(dtype_value_c);
|
||||
if (!dtype_value_c->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
EyeCheck(input_args);
|
||||
auto dtype_value = input_args[2]->BuildValue();
|
||||
if (!dtype_value->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!";
|
||||
}
|
||||
auto output_type = dtype_value->cast<TypePtr>();
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
||||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool};
|
||||
auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
||||
return dtype_ret;
|
||||
}
|
||||
|
||||
ValuePtr EyeInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
|
|
|
@ -9463,6 +9463,7 @@ def einsum(equation, *operands):
|
|||
>>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32)
|
||||
>>> y = mindspore.Tensor([1, 2], mindspore.float32)
|
||||
>>> output = ops.einsum(x, [..., 1], y, [..., 2], [..., 1, 2])
|
||||
>>> print(output)
|
||||
[[1. 2.]
|
||||
[2. 4.]
|
||||
[3. 6.]
|
||||
|
|
Loading…
Reference in New Issue