diff --git a/mindspore/core/ops/dtype.cc b/mindspore/core/ops/dtype.cc new file mode 100644 index 00000000000..f75976f2ddb --- /dev/null +++ b/mindspore/core/ops/dtype.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/dtype.h" +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/abstract_value.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector &input_args, + const AbstractBasePtr &infer) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + const std::set valid_types = {kTensorType}; + auto type = + CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); + return type; +} + +AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto value = DTypeInferValue(primitive, input_args, nullptr); + MS_EXCEPTION_IF_NULL(value); + auto type = value->cast(); + MS_EXCEPTION_IF_NULL(type); + auto abstract = std::make_shared(type); + return abstract; +} + +REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/dtype.h b/mindspore/core/ops/dtype.h new file mode 100644 index 00000000000..e7029a6676c --- /dev/null +++ b/mindspore/core/ops/dtype.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_DTYPE_H_ +#define MINDSPORE_CORE_OPS_DTYPE_H_ +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +class DType : public PrimitiveC { + public: + DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); } + ~DType() = default; + MS_DECLARE_PARENT(DType, PrimitiveC); + void Init() {} +}; +using PrimDTypePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_DTYPE_H_ diff --git a/mindspore/core/ops/shape.cc b/mindspore/core/ops/shape.cc index 5fbbe02fd70..35795daf240 100644 --- a/mindspore/core/ops/shape.cc +++ b/mindspore/core/ops/shape.cc @@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP // infer shape MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); + CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + auto in_shape = shape_map[kShape]; // infer type AbstractBasePtrList abs_list; std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), @@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP return std::make_shared(item); }); auto abs = std::make_shared(abs_list); - abs->set_value(MakeValue(in_shape)); return abs; } -REGISTER_PRIMITIVE_C(kNameShape, Shape); + +ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector &input_args, + const AbstractBasePtr &infer) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + auto inshape = shape_map[kShape]; + auto value = MakeValue(inshape); + return value; +} +REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/shape.h b/mindspore/core/ops/shape.h index 8fdfea9d9a2..d359eebcd40 100644 --- a/mindspore/core/ops/shape.h +++ b/mindspore/core/ops/shape.h @@ -26,17 +26,13 @@ namespace mindspore { namespace ops { -constexpr auto kNameShape = "Shape"; class Shape : public PrimitiveC { public: - Shape() : PrimitiveC(kNameShape) {} + Shape() : PrimitiveC(prim::kPrimShape->name()) {} ~Shape() = default; MS_DECLARE_PARENT(Shape, PrimitiveC); void Init() {} }; - -AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); using PrimShapePtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 314af566e5d..a935ec0534f 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer): return out -class DType(PrimitiveWithInfer): +class DType(Primitive): """ Returns the data type of the input tensor as mindspore.dtype. @@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer): def __init__(self): """Initialize DType""" - def __infer__(self, x): - addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.' - validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info) - out = {'shape': (), - 'dtype': mstype.type_type, - 'value': x['dtype'].element_type()} - return out - class SameTypeShape(PrimitiveWithInfer): """ @@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer): return out -class Shape(PrimitiveWithInfer): +class Shape(Primitive): """ Returns the shape of the input tensor. @@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer): def __init__(self): """Initialize Shape""" - def __infer__(self, x): - validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) - out = {'shape': (), - 'dtype': mstype.tuple_, - 'value': tuple(x['shape'])} - return out - class DynamicShape(Primitive): """ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 4d3e0b2460b..17fc4ee8b3a 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer): return input_x -class Softmax(PrimitiveWithInfer): +class Softmax(Primitive): r""" Softmax operation.