forked from mindspore-Ecosystem/mindspore
commit
ab4f27ed71
|
@ -66,6 +66,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
|
|||
static const auto &kRaggedRange = prim::kPrimRaggedRange->name();
|
||||
static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
|
||||
static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
|
||||
static const auto &kUnsortedSegmentProd = prim::kPrimUnsortedSegmentProd->name();
|
||||
static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
|
||||
static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
|
||||
static const auto &kGather = prim::kPrimGather->name();
|
||||
|
@ -88,6 +89,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
|
|||
{kFractionalAvgPoolGrad, ShapeSet{0}},
|
||||
{kUnsortedSegmentMin, ShapeSet{2}},
|
||||
{kUnsortedSegmentMax, ShapeSet{2}},
|
||||
{kUnsortedSegmentProd, ShapeSet{2}},
|
||||
{kMatrixDiagV3, ShapeSet{1, 2, 3, 4}},
|
||||
{kMatrixDiagPartV3, ShapeSet{1, 2}},
|
||||
{kMatrixSetDiagV3, ShapeSet{2}},
|
||||
|
|
|
@ -321,6 +321,7 @@ GVAR_DEF(PrimitivePtr, kPrimUnstack, std::make_shared<Primitive>(kUnstack));
|
|||
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMax, std::make_shared<Primitive>("UnsortedSegmentMax"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSum, std::make_shared<Primitive>("UnsortedSegmentSum"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMin, std::make_shared<Primitive>("UnsortedSegmentMin"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentProd, std::make_shared<Primitive>("UnsortedSegmenProd"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConcatOffset, std::make_shared<Primitive>("ConcatOffset"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimReshape, std::make_shared<Primitive>("Reshape"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSubAndFilter, std::make_shared<Primitive>("SubAndFilter"));
|
||||
|
@ -365,6 +366,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>("Sca
|
|||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared<Primitive>("TensorScatterAdd"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterSub, std::make_shared<Primitive>("TensorScatterSub"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMul, std::make_shared<Primitive>("TensorScatterMul"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterDiv, std::make_shared<Primitive>("TensorScatterDiv"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMax, std::make_shared<Primitive>("TensorScatterMax"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMin, std::make_shared<Primitive>("TensorScatterMin"));
|
||||
|
|
|
@ -247,6 +247,7 @@ constexpr auto kSplitStride = "split_stride";
|
|||
constexpr auto kExtendTop = "extend_top";
|
||||
constexpr auto kExtendBottom = "extend_bottom";
|
||||
constexpr auto kNumberSplit = "number_split";
|
||||
constexpr auto kNumSegments = "num_segments";
|
||||
constexpr auto kSplitDim = "split_dim";
|
||||
constexpr auto kPadTop = "pad_top";
|
||||
constexpr auto kTransFormat = "trans_format";
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
#include "ops/tensor_scatter_sub.h"
|
||||
#include "ops/tensor_scatter_max.h"
|
||||
#include "ops/tensor_scatter_min.h"
|
||||
#include "ops/tensor_scatter_mul.h"
|
||||
#include "ops/tensor_scatter_div.h"
|
||||
#include "ops/tensor_scatter_update.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -98,6 +101,9 @@ MIND_API_BASE_IMPL(TensorScatterAdd, PrimitiveC, BaseOperator);
|
|||
MIND_API_BASE_IMPL(TensorScatterSub, PrimitiveC, BaseOperator);
|
||||
MIND_API_BASE_IMPL(TensorScatterMax, PrimitiveC, BaseOperator);
|
||||
MIND_API_BASE_IMPL(TensorScatterMin, PrimitiveC, BaseOperator);
|
||||
MIND_API_BASE_IMPL(TensorScatterDiv, PrimitiveC, BaseOperator);
|
||||
MIND_API_BASE_IMPL(TensorScatterMul, PrimitiveC, BaseOperator);
|
||||
MIND_API_BASE_IMPL(TensorScatterUpdate, PrimitiveC, BaseOperator);
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterAdd, prim::kPrimTensorScatterAdd, TensorScatterArithmeticInfer, nullptr,
|
||||
true);
|
||||
|
@ -107,5 +113,11 @@ REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMax, prim::kPrimTensorScatterMax, Tens
|
|||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMin, prim::kPrimTensorScatterMin, TensorScatterArithmeticInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterDiv, prim::kPrimTensorScatterDiv, TensorScatterArithmeticInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMul, prim::kPrimTensorScatterMul, TensorScatterArithmeticInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterUpdate, prim::kPrimTensorScatterUpdate, TensorScatterArithmeticInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_
|
||||
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameTensorScatterDiv = "TensorScatterDiv";
|
||||
/// \brief By division the value at the position indicated by the index in input_x with the value in the update, the
|
||||
/// value at the index will eventually be equal to the largest one to create a new tensor.
|
||||
class MIND_API TensorScatterDiv : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(TensorScatterDiv);
|
||||
/// \brief Constructor.
|
||||
TensorScatterDiv() : BaseOperator(kNameTensorScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"y"}); }
|
||||
};
|
||||
|
||||
using kPrimTensorScatterDivPtr = std::shared_ptr<TensorScatterDiv>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_
|
||||
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameTensorScatterMul = "TensorScatterMul";
|
||||
/// \brief By multiple the value at the position indicated by the index in input_x with the value in the update, the
|
||||
/// value at the index will eventually be equal to the largest one to create a new tensor.
|
||||
class MIND_API TensorScatterMul : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(TensorScatterMul);
|
||||
/// \brief Constructor.
|
||||
TensorScatterMul() : BaseOperator(kNameTensorScatterMul) { InitIOName({"input_x", "indices", "updates"}, {"y"}); }
|
||||
};
|
||||
using kPrimTensorScatterMulPtr = std::shared_ptr<TensorScatterMul>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_
|
||||
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameTensorScatterUpdate = "TensorScatterUpdate";
|
||||
/// \brief By update the value at the position indicated by the index in input_x with the value in the update, the
|
||||
/// value at the index will eventually be equal to the largest one to create a new tensor.
|
||||
class MIND_API TensorScatterUpdate : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(TensorScatterUpdate);
|
||||
/// \brief Constructor.
|
||||
TensorScatterUpdate() : BaseOperator(kNameTensorScatterUpdate) {
|
||||
InitIOName({"input_x", "indices", "updates"}, {"y"});
|
||||
}
|
||||
};
|
||||
using kPrimTensorScatterUpdatePtr = std::shared_ptr<TensorScatterUpdate>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/unsorted_segment_arithmetic.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/unsorted_segment_max.h"
|
||||
#include "ops/unsorted_segment_min.h"
|
||||
#include "ops/unsorted_segment_prod.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x_shape_ptr);
|
||||
auto segment_ids_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
|
||||
auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_shape_ptr);
|
||||
|
||||
auto num_segments = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
|
||||
int64_t num_segments_value = 0;
|
||||
|
||||
if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) {
|
||||
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
|
||||
primitive->AddAttr(kNumSegments, MakeValue(num_segments_value));
|
||||
}
|
||||
|
||||
std::vector<int64_t> invalid_out_shape = {-1};
|
||||
if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) {
|
||||
return std::make_shared<abstract::Shape>(invalid_out_shape);
|
||||
}
|
||||
|
||||
if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) {
|
||||
auto value_ptr = primitive->GetAttr(kNumSegments);
|
||||
if (value_ptr != nullptr) {
|
||||
num_segments_value = GetValue<int64_t>(value_ptr);
|
||||
} else {
|
||||
return std::make_shared<abstract::Shape>(invalid_out_shape);
|
||||
}
|
||||
}
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
if (x_shape.size() < ids_shape.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape size";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids_shape.size(); i++) {
|
||||
if (x_shape[i] != ids_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape[" << i
|
||||
<< "]: " << x_shape[i] << ", " << ids_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> out_shape;
|
||||
out_shape.push_back(num_segments_value);
|
||||
for (size_t i = ids_shape.size(); i < x_shape.size(); i++) {
|
||||
out_shape.push_back(x_shape.at(i));
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr UnsortedSegmentArithmeticInferType(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(in_type_ptr);
|
||||
std::set<TypePtr> in_type_set = {kFloat16, kFloat32, kInt32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("x", in_type_ptr, in_type_set, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
auto infer_type = UnsortedSegmentArithmeticInferType(primitive, input_args);
|
||||
auto infer_shape = UnsortedSegmentArithmeticInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(UnsortedSegmentMax, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(UnsortedSegmentMin, BaseOperator);
|
||||
MIND_API_OPERATOR_IMPL(UnsortedSegmentProd, BaseOperator);
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMax, prim::kPrimUnsortedSegmentMax, UnsortedSegmentArithmeticInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMin, prim::kPrimUnsortedSegmentMin, UnsortedSegmentArithmeticInfer, nullptr,
|
||||
true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentProd, prim::kPrimUnsortedSegmentProd, UnsortedSegmentArithmeticInfer,
|
||||
nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_
|
||||
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_
|
||||
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameUnsortedSegmentMax = "UnsortedSegmentMax";
|
||||
/// \brief Computes the max of a tensor along segments.
|
||||
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMax for more details.
|
||||
class MIND_API UnsortedSegmentMax : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(UnsortedSegmentMax);
|
||||
/// \brief Constructor.
|
||||
UnsortedSegmentMax() : BaseOperator(kNameUnsortedSegmentMax) {
|
||||
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
|
||||
}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_
|
||||
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameUnsortedSegmentMin = "UnsortedSegmentMin";
|
||||
/// \brief Computes the min of a tensor along segments.
|
||||
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMin for more details.
|
||||
class MIND_API UnsortedSegmentMin : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(UnsortedSegmentMin);
|
||||
/// \brief Constructor.
|
||||
UnsortedSegmentMin() : BaseOperator(kNameUnsortedSegmentMin) {
|
||||
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
|
||||
}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_
|
||||
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameUnsortedSegmentProd = "UnsortedSegmentProd";
|
||||
/// \brief Computes the prod of a tensor along segments.
|
||||
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentProd for more details.
|
||||
class MIND_API UnsortedSegmentProd : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(UnsortedSegmentProd);
|
||||
/// \brief Constructor.
|
||||
UnsortedSegmentProd() : BaseOperator(kNameUnsortedSegmentProd) {
|
||||
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
|
||||
}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_
|
Loading…
Reference in New Issue