diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc index 18493ff4fcb..4a9340ebe61 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/gatherv2_gpu_kernel.cc @@ -183,6 +183,18 @@ std::vector> Gather .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat64), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &GatherV2FwdGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeInt32) @@ -201,18 +213,36 @@ std::vector> Gather .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &GatherV2FwdGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &GatherV2FwdGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &GatherV2FwdGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeBool) .AddInputAttr(kNumberTypeInt32) @@ -225,12 +255,42 @@ std::vector> Gather .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeBool), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeBool), + &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeBool), + &GatherV2FwdGpuKernelMod::LaunchKernel}, {KernelAttr() .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt32), &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &GatherV2FwdGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &GatherV2FwdGpuKernelMod::LaunchKernel}, }; MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Gather, GatherV2FwdGpuKernelMod); diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc index d0810c3c728..deab80d0163 100644 --- a/mindspore/core/ops/expand_dims.cc +++ b/mindspore/core/ops/expand_dims.cc @@ -47,7 +47,22 @@ abstract::ShapePtr ExpandDimsInferShape(const PrimitivePtr &primitive, const std constexpr auto kExpandDimsInputsNum = 2; int64_t axis = 0; if (input_args.size() == kExpandDimsInputsNum) { - axis = GetValue(input_args[kInputIndex1]->BuildValue()); + auto input_value = input_args[kInputIndex1]->BuildValue(); + if (input_args[kInputIndex1]->isa()) { + axis = GetValue(input_value); + } else if (input_args[kInputIndex1]->isa()) { + if (input_value->isa()) { + auto axis_vec = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, prim_name); + if (axis_vec.size() != 1) { + MS_LOG(EXCEPTION) << " The input number of ExpandDims axis must be int, but got " << axis_vec; + } + axis = axis_vec[0]; + } else { + ShapeVector out_shape; + (void)out_shape.insert(out_shape.end(), x_shape.size() + 1, -1); + return std::make_shared(out_shape); + } + } } else if (input_args.size() == 1) { axis = GetValue(primitive->GetAttr(kAxis)); } else { diff --git a/mindspore/core/ops/op_utils.cc b/mindspore/core/ops/op_utils.cc index 66a39ded9c7..25f21ac2c90 100644 --- a/mindspore/core/ops/op_utils.cc +++ b/mindspore/core/ops/op_utils.cc @@ -152,22 +152,22 @@ bool CheckAndGetAxisValue(const std::vector &input_ar if (axis_ptr == nullptr) { return is_dynamic; } - *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt(op_name, axis_ptr, "Reduce"); + *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", axis_ptr, op_name); return is_dynamic; } auto input_value = input_args[kInputIndex1]->BuildValue(); if (input_args[kInputIndex1]->isa() || input_args[kInputIndex1]->isa() || input_args[kInputIndex1]->isa()) { - *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt(op_name, input_value, "Reduce"); + *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name); } else if (input_args[kInputIndex1]->isa()) { - (void)CheckAndConvertUtils::CheckTensorTypeValid(op_name, input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, - "Reduce"); + (void)CheckAndConvertUtils::CheckTensorTypeValid("axis", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64}, + op_name); if (input_value->isa()) { - *axis_value = CheckAndConvertUtils::CheckTensorIntValue(op_name, input_value, "Reduce"); + *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, op_name); } else { is_dynamic = true; - auto axis_shape = CheckAndConvertUtils::GetTensorInputShape("Reduce", input_args, 1); + auto axis_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1); if (axis_shape->shape().size() != 1) { MS_EXCEPTION(ValueError) << "For 'Reduce axis', " << op_name << " must be 1-D, but got" << axis_shape->shape().size() << "-D."; @@ -215,13 +215,12 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive, return std::make_shared(out_shape); } -TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector &input_args) { +TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector &input_args, + const std::set &check_list) { MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(input_args[0]); auto x_type = input_args[0]->BuildType(); - std::set valid_types = common_valid_types; - (void)valid_types.insert(kBool); - (void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, check_list, prim->name()); return x_type; } diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index adfb219c0db..b94e1e9c167 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -60,7 +60,8 @@ ShapeVector ReduceFuncCalShapeInferImpl(const PrimitivePtr &primitive, const Sha abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive, const std::vector &input_args, const std::string &prim_name); -TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector &input_args); +TypePtr ReduceBaseInferType(const PrimitivePtr &prim, const std::vector &input_args, + const std::set &check_list); template api::SharedPtr GetOperator(const AnfNodePtr &node) { diff --git a/mindspore/core/ops/reduce_all.cc b/mindspore/core/ops/reduce_all.cc deleted file mode 100644 index 29b6a29a188..00000000000 --- a/mindspore/core/ops/reduce_all.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ops/reduce_all.h" -#include -#include -#include -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceAll, Reduce); -class ReduceAllInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceAll); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceAll, prim::kPrimReduceAll, ReduceAllInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_any.cc b/mindspore/core/ops/reduce_any.cc deleted file mode 100644 index 058e1209d60..00000000000 --- a/mindspore/core/ops/reduce_any.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ops/reduce_any.h" -#include -#include -#include -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceAny, Reduce); -class ReduceAnyInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceAny); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceAny, prim::kPrimReduceAny, ReduceAnyInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_arithmetic.cc b/mindspore/core/ops/reduce_arithmetic.cc new file mode 100644 index 00000000000..a272550a6ce --- /dev/null +++ b/mindspore/core/ops/reduce_arithmetic.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/reduce_all.h" +#include "ops/reduce_any.h" +#include "ops/reduce_max.h" +#include "ops/reduce_min.h" +#include "ops/reduce_sum.h" +#include "ops/reduce_prod.h" +#include "ops/reduce_mean.h" +#include +#include +#include +#include "ops/op_utils.h" +#include "abstract/ops/op_infer.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +MIND_API_OPERATOR_IMPL(ReduceAll, Reduce); +MIND_API_OPERATOR_IMPL(ReduceAny, Reduce); +MIND_API_OPERATOR_IMPL(ReduceMax, Reduce); +MIND_API_OPERATOR_IMPL(ReduceMin, Reduce); +MIND_API_OPERATOR_IMPL(ReduceSum, Reduce); +MIND_API_OPERATOR_IMPL(ReduceProd, Reduce); +MIND_API_OPERATOR_IMPL(ReduceMean, Reduce); +class ReduceArithmeticInfer : public abstract::OpInferBase { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const override { + const int64_t input_num = 1; + MS_EXCEPTION_IF_NULL(primitive); + CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, + primitive->name()); + return ReduceBaseInferShape(primitive, input_args, kNameReduceAll); + } + + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { + std::set bool_types = {kBool}; + const std::string &op_name = primitive->name(); + static const std::map> check_list_map{ + {prim::kPrimReduceAll->name(), bool_types}, + {prim::kPrimReduceAny->name(), bool_types}, + {prim::kPrimReduceMax->name(), common_valid_types_with_complex_and_bool}, + {prim::kPrimReduceMin->name(), common_valid_types_with_complex_and_bool}, + {prim::kPrimReduceSum->name(), common_valid_types_with_complex_and_bool}, + {prim::kPrimReduceProd->name(), common_valid_types_with_complex}, + {prim::kPrimReduceMean->name(), common_valid_types_with_complex}, + }; + if (check_list_map.find(op_name) == check_list_map.end()) { + MS_EXCEPTION(TypeError) << "For Primitive[" << op_name << "], the current ops do not support this operation."; + } + return ReduceBaseInferType(primitive, input_args, check_list_map.at(op_name)); + } + + std::set GetValueDependArgIndices() const override { return {1}; } +}; +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceAll, prim::kPrimReduceAll, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceAny, prim::kPrimReduceAny, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMax, prim::kPrimReduceMax, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMin, prim::kPrimReduceMin, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceSum, prim::kPrimReduceSum, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceProd, prim::kPrimReduceProd, ReduceArithmeticInfer, false); +REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMean, prim::kPrimReduceMean, ReduceArithmeticInfer, false); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/reduce_max.cc b/mindspore/core/ops/reduce_max.cc deleted file mode 100644 index 7d376488c74..00000000000 --- a/mindspore/core/ops/reduce_max.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "ops/reduce_max.h" -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "mindapi/src/helper.h" -#include "utils/check_convert_utils.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceMax, Reduce); -class ReduceMaxInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceMax); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMax, prim::kPrimReduceMax, ReduceMaxInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_mean.cc b/mindspore/core/ops/reduce_mean.cc deleted file mode 100644 index 3a5d3e21eb1..00000000000 --- a/mindspore/core/ops/reduce_mean.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ops/reduce_mean.h" -#include -#include -#include -#include -#include -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceMean, Reduce); -class ReduceMeanInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceMean); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMean, prim::kPrimReduceMean, ReduceMeanInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_min.cc b/mindspore/core/ops/reduce_min.cc deleted file mode 100644 index 89b176daba6..00000000000 --- a/mindspore/core/ops/reduce_min.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ops/reduce_min.h" -#include -#include -#include "ops/primitive_c.h" -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceMin, Reduce); -class ReduceMinInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceMin); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceMin, prim::kPrimReduceMin, ReduceMinInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_prod.cc b/mindspore/core/ops/reduce_prod.cc deleted file mode 100644 index 3173ebf3c60..00000000000 --- a/mindspore/core/ops/reduce_prod.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "ops/reduce_prod.h" -#include -#include -#include -#include -#include -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceProd, Reduce); -class ReduceProdInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceProd); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceProd, prim::kPrimReduceProd, ReduceProdInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/reduce_sum.cc b/mindspore/core/ops/reduce_sum.cc deleted file mode 100644 index 0741030fd91..00000000000 --- a/mindspore/core/ops/reduce_sum.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "ops/reduce_sum.h" -#include "ops/op_utils.h" -#include "abstract/ops/op_infer.h" -#include "utils/check_convert_utils.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(ReduceSum, Reduce); -class ReduceSumInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - const int64_t input_num = 1; - MS_EXCEPTION_IF_NULL(primitive); - CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, - primitive->name()); - return ReduceBaseInferShape(primitive, input_args, kNameReduceSum); - } - - TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { - return ReduceBaseInferType(prim, input_args); - } - - std::set GetValueDependArgIndices() const override { return {1}; } -}; -REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceSum, prim::kPrimReduceSum, ReduceSumInfer, false); -} // namespace ops -} // namespace mindspore