diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index 380758c122c..bf9c215e5b4 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -66,6 +66,7 @@ std::set GetDependsFormMap(const std::string &prim_name, size_t input_n static const auto &kRaggedRange = prim::kPrimRaggedRange->name(); static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name(); static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name(); + static const auto &kUnsortedSegmentProd = prim::kPrimUnsortedSegmentProd->name(); static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name(); static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name(); static const auto &kGather = prim::kPrimGather->name(); @@ -88,6 +89,7 @@ std::set GetDependsFormMap(const std::string &prim_name, size_t input_n {kFractionalAvgPoolGrad, ShapeSet{0}}, {kUnsortedSegmentMin, ShapeSet{2}}, {kUnsortedSegmentMax, ShapeSet{2}}, + {kUnsortedSegmentProd, ShapeSet{2}}, {kMatrixDiagV3, ShapeSet{1, 2, 3, 4}}, {kMatrixDiagPartV3, ShapeSet{1, 2}}, {kMatrixSetDiagV3, ShapeSet{2}}, diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 08948915e6f..0f013877544 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -321,6 +321,7 @@ GVAR_DEF(PrimitivePtr, kPrimUnstack, std::make_shared(kUnstack)); GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMax, std::make_shared("UnsortedSegmentMax")); GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSum, std::make_shared("UnsortedSegmentSum")); GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMin, std::make_shared("UnsortedSegmentMin")); +GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentProd, std::make_shared("UnsortedSegmenProd")); GVAR_DEF(PrimitivePtr, kPrimConcatOffset, std::make_shared("ConcatOffset")); GVAR_DEF(PrimitivePtr, kPrimReshape, std::make_shared("Reshape")); GVAR_DEF(PrimitivePtr, kPrimSubAndFilter, std::make_shared("SubAndFilter")); @@ -365,6 +366,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared("Sca GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared("TensorScatterUpdate")); GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared("TensorScatterAdd")); GVAR_DEF(PrimitivePtr, kPrimTensorScatterSub, std::make_shared("TensorScatterSub")); +GVAR_DEF(PrimitivePtr, kPrimTensorScatterMul, std::make_shared("TensorScatterMul")); GVAR_DEF(PrimitivePtr, kPrimTensorScatterDiv, std::make_shared("TensorScatterDiv")); GVAR_DEF(PrimitivePtr, kPrimTensorScatterMax, std::make_shared("TensorScatterMax")); GVAR_DEF(PrimitivePtr, kPrimTensorScatterMin, std::make_shared("TensorScatterMin")); diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h index dddac1d5cf7..52a93ea0615 100644 --- a/mindspore/core/ops/op_name.h +++ b/mindspore/core/ops/op_name.h @@ -247,6 +247,7 @@ constexpr auto kSplitStride = "split_stride"; constexpr auto kExtendTop = "extend_top"; constexpr auto kExtendBottom = "extend_bottom"; constexpr auto kNumberSplit = "number_split"; +constexpr auto kNumSegments = "num_segments"; constexpr auto kSplitDim = "split_dim"; constexpr auto kPadTop = "pad_top"; constexpr auto kTransFormat = "trans_format"; diff --git a/mindspore/core/ops/tensor_scatter_arithmetic.cc b/mindspore/core/ops/tensor_scatter_arithmetic.cc index 210135019e7..ef3b1d51cb7 100644 --- a/mindspore/core/ops/tensor_scatter_arithmetic.cc +++ b/mindspore/core/ops/tensor_scatter_arithmetic.cc @@ -24,6 +24,9 @@ #include "ops/tensor_scatter_sub.h" #include "ops/tensor_scatter_max.h" #include "ops/tensor_scatter_min.h" +#include "ops/tensor_scatter_mul.h" +#include "ops/tensor_scatter_div.h" +#include "ops/tensor_scatter_update.h" #include "abstract/ops/primitive_infer_map.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" @@ -98,6 +101,9 @@ MIND_API_BASE_IMPL(TensorScatterAdd, PrimitiveC, BaseOperator); MIND_API_BASE_IMPL(TensorScatterSub, PrimitiveC, BaseOperator); MIND_API_BASE_IMPL(TensorScatterMax, PrimitiveC, BaseOperator); MIND_API_BASE_IMPL(TensorScatterMin, PrimitiveC, BaseOperator); +MIND_API_BASE_IMPL(TensorScatterDiv, PrimitiveC, BaseOperator); +MIND_API_BASE_IMPL(TensorScatterMul, PrimitiveC, BaseOperator); +MIND_API_BASE_IMPL(TensorScatterUpdate, PrimitiveC, BaseOperator); REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterAdd, prim::kPrimTensorScatterAdd, TensorScatterArithmeticInfer, nullptr, true); @@ -107,5 +113,11 @@ REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMax, prim::kPrimTensorScatterMax, Tens true); REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMin, prim::kPrimTensorScatterMin, TensorScatterArithmeticInfer, nullptr, true); +REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterDiv, prim::kPrimTensorScatterDiv, TensorScatterArithmeticInfer, nullptr, + true); +REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMul, prim::kPrimTensorScatterMul, TensorScatterArithmeticInfer, nullptr, + true); +REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterUpdate, prim::kPrimTensorScatterUpdate, TensorScatterArithmeticInfer, nullptr, + true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/tensor_scatter_div.h b/mindspore/core/ops/tensor_scatter_div.h new file mode 100644 index 00000000000..cf668350516 --- /dev/null +++ b/mindspore/core/ops/tensor_scatter_div.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_ +#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_ +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTensorScatterDiv = "TensorScatterDiv"; +/// \brief By division the value at the position indicated by the index in input_x with the value in the update, the +/// value at the index will eventually be equal to the largest one to create a new tensor. +class MIND_API TensorScatterDiv : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TensorScatterDiv); + /// \brief Constructor. + TensorScatterDiv() : BaseOperator(kNameTensorScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"y"}); } +}; + +using kPrimTensorScatterDivPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_ diff --git a/mindspore/core/ops/tensor_scatter_mul.h b/mindspore/core/ops/tensor_scatter_mul.h new file mode 100644 index 00000000000..733a770e010 --- /dev/null +++ b/mindspore/core/ops/tensor_scatter_mul.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_ +#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_ + +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTensorScatterMul = "TensorScatterMul"; +/// \brief By multiple the value at the position indicated by the index in input_x with the value in the update, the +/// value at the index will eventually be equal to the largest one to create a new tensor. +class MIND_API TensorScatterMul : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TensorScatterMul); + /// \brief Constructor. + TensorScatterMul() : BaseOperator(kNameTensorScatterMul) { InitIOName({"input_x", "indices", "updates"}, {"y"}); } +}; +using kPrimTensorScatterMulPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_ diff --git a/mindspore/core/ops/tensor_scatter_update.h b/mindspore/core/ops/tensor_scatter_update.h new file mode 100644 index 00000000000..de9ddbb6814 --- /dev/null +++ b/mindspore/core/ops/tensor_scatter_update.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_ +#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_ + +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTensorScatterUpdate = "TensorScatterUpdate"; +/// \brief By update the value at the position indicated by the index in input_x with the value in the update, the +/// value at the index will eventually be equal to the largest one to create a new tensor. +class MIND_API TensorScatterUpdate : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TensorScatterUpdate); + /// \brief Constructor. + TensorScatterUpdate() : BaseOperator(kNameTensorScatterUpdate) { + InitIOName({"input_x", "indices", "updates"}, {"y"}); + } +}; +using kPrimTensorScatterUpdatePtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_ diff --git a/mindspore/core/ops/unsorted_segment_arithmetic.cc b/mindspore/core/ops/unsorted_segment_arithmetic.cc new file mode 100644 index 00000000000..99137759845 --- /dev/null +++ b/mindspore/core/ops/unsorted_segment_arithmetic.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/unsorted_segment_arithmetic.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "mindapi/src/helper.h" +#include "ops/unsorted_segment_max.h" +#include "ops/unsorted_segment_min.h" +#include "ops/unsorted_segment_prod.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + auto x_shape_ptr = input_args[kInputIndex0]->BuildShape(); + MS_EXCEPTION_IF_NULL(x_shape_ptr); + auto segment_ids_shape_ptr = input_args[kInputIndex1]->BuildShape(); + MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr); + auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape(); + MS_EXCEPTION_IF_NULL(num_segments_shape_ptr); + + auto num_segments = input_args[kInputIndex2]->cast(); + int64_t num_segments_value = 0; + + if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) { + num_segments_value = GetValue(num_segments->BuildValue()); + primitive->AddAttr(kNumSegments, MakeValue(num_segments_value)); + } + + std::vector invalid_out_shape = {-1}; + if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) { + return std::make_shared(invalid_out_shape); + } + + if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) { + auto value_ptr = primitive->GetAttr(kNumSegments); + if (value_ptr != nullptr) { + num_segments_value = GetValue(value_ptr); + } else { + return std::make_shared(invalid_out_shape); + } + } + + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + if (x_shape.size() < ids_shape.size()) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape size"; + } + + for (size_t i = 0; i < ids_shape.size(); i++) { + if (x_shape[i] != ids_shape[i]) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape[" << i + << "]: " << x_shape[i] << ", " << ids_shape[i]; + } + } + + std::vector out_shape; + out_shape.push_back(num_segments_value); + for (size_t i = ids_shape.size(); i < x_shape.size(); i++) { + out_shape.push_back(x_shape.at(i)); + } + return std::make_shared(out_shape); +} + +TypePtr UnsortedSegmentArithmeticInferType(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + auto in_type_ptr = input_args[kInputIndex0]->BuildType(); + MS_EXCEPTION_IF_NULL(in_type_ptr); + std::set in_type_set = {kFloat16, kFloat32, kInt32}; + return CheckAndConvertUtils::CheckTensorTypeValid("x", in_type_ptr, in_type_set, prim_name); +} +} // namespace + +AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t kInputNum = 3; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name()); + auto infer_type = UnsortedSegmentArithmeticInferType(primitive, input_args); + auto infer_shape = UnsortedSegmentArithmeticInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +MIND_API_OPERATOR_IMPL(UnsortedSegmentMax, BaseOperator); +MIND_API_OPERATOR_IMPL(UnsortedSegmentMin, BaseOperator); +MIND_API_OPERATOR_IMPL(UnsortedSegmentProd, BaseOperator); + +REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMax, prim::kPrimUnsortedSegmentMax, UnsortedSegmentArithmeticInfer, nullptr, + true); +REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMin, prim::kPrimUnsortedSegmentMin, UnsortedSegmentArithmeticInfer, nullptr, + true); +REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentProd, prim::kPrimUnsortedSegmentProd, UnsortedSegmentArithmeticInfer, + nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/unsorted_segment_arithmetic.h b/mindspore/core/ops/unsorted_segment_arithmetic.h new file mode 100644 index 00000000000..8ff8296a669 --- /dev/null +++ b/mindspore/core/ops/unsorted_segment_arithmetic.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_ +#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_ + +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_ diff --git a/mindspore/core/ops/unsorted_segment_max.h b/mindspore/core/ops/unsorted_segment_max.h new file mode 100644 index 00000000000..750f2ef6372 --- /dev/null +++ b/mindspore/core/ops/unsorted_segment_max.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_ +#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_ + +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameUnsortedSegmentMax = "UnsortedSegmentMax"; +/// \brief Computes the max of a tensor along segments. +/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMax for more details. +class MIND_API UnsortedSegmentMax : public BaseOperator { + public: + MIND_API_BASE_MEMBER(UnsortedSegmentMax); + /// \brief Constructor. + UnsortedSegmentMax() : BaseOperator(kNameUnsortedSegmentMax) { + InitIOName({"x", "segment_ids", "num_segments"}, {"y"}); + } +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_ diff --git a/mindspore/core/ops/unsorted_segment_min.h b/mindspore/core/ops/unsorted_segment_min.h new file mode 100644 index 00000000000..18ef0fae09d --- /dev/null +++ b/mindspore/core/ops/unsorted_segment_min.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_ +#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_ + +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameUnsortedSegmentMin = "UnsortedSegmentMin"; +/// \brief Computes the min of a tensor along segments. +/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMin for more details. +class MIND_API UnsortedSegmentMin : public BaseOperator { + public: + MIND_API_BASE_MEMBER(UnsortedSegmentMin); + /// \brief Constructor. + UnsortedSegmentMin() : BaseOperator(kNameUnsortedSegmentMin) { + InitIOName({"x", "segment_ids", "num_segments"}, {"y"}); + } +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_ diff --git a/mindspore/core/ops/unsorted_segment_prod.h b/mindspore/core/ops/unsorted_segment_prod.h new file mode 100644 index 00000000000..cf8afdd1576 --- /dev/null +++ b/mindspore/core/ops/unsorted_segment_prod.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_ +#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_ + +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameUnsortedSegmentProd = "UnsortedSegmentProd"; +/// \brief Computes the prod of a tensor along segments. +/// Refer to Python API @ref mindspore.ops.UnsortedSegmentProd for more details. +class MIND_API UnsortedSegmentProd : public BaseOperator { + public: + MIND_API_BASE_MEMBER(UnsortedSegmentProd); + /// \brief Constructor. + UnsortedSegmentProd() : BaseOperator(kNameUnsortedSegmentProd) { + InitIOName({"x", "segment_ids", "num_segments"}, {"y"}); + } +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_