!34215 [MSLITE] add op info

Merge pull request !34215 from ling/op
This commit is contained in:
i-robot 2022-05-11 09:06:59 +00:00 committed by Gitee
commit ab4f27ed71
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 410 additions and 0 deletions

View File

@ -66,6 +66,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const auto &kRaggedRange = prim::kPrimRaggedRange->name();
static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
static const auto &kUnsortedSegmentProd = prim::kPrimUnsortedSegmentProd->name();
static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
static const auto &kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
static const auto &kGather = prim::kPrimGather->name();
@ -88,6 +89,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
{kFractionalAvgPoolGrad, ShapeSet{0}},
{kUnsortedSegmentMin, ShapeSet{2}},
{kUnsortedSegmentMax, ShapeSet{2}},
{kUnsortedSegmentProd, ShapeSet{2}},
{kMatrixDiagV3, ShapeSet{1, 2, 3, 4}},
{kMatrixDiagPartV3, ShapeSet{1, 2}},
{kMatrixSetDiagV3, ShapeSet{2}},

View File

@ -321,6 +321,7 @@ GVAR_DEF(PrimitivePtr, kPrimUnstack, std::make_shared<Primitive>(kUnstack));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMax, std::make_shared<Primitive>("UnsortedSegmentMax"));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSum, std::make_shared<Primitive>("UnsortedSegmentSum"));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMin, std::make_shared<Primitive>("UnsortedSegmentMin"));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentProd, std::make_shared<Primitive>("UnsortedSegmenProd"));
GVAR_DEF(PrimitivePtr, kPrimConcatOffset, std::make_shared<Primitive>("ConcatOffset"));
GVAR_DEF(PrimitivePtr, kPrimReshape, std::make_shared<Primitive>("Reshape"));
GVAR_DEF(PrimitivePtr, kPrimSubAndFilter, std::make_shared<Primitive>("SubAndFilter"));
@ -365,6 +366,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>("Sca
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterAdd, std::make_shared<Primitive>("TensorScatterAdd"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterSub, std::make_shared<Primitive>("TensorScatterSub"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMul, std::make_shared<Primitive>("TensorScatterMul"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterDiv, std::make_shared<Primitive>("TensorScatterDiv"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMax, std::make_shared<Primitive>("TensorScatterMax"));
GVAR_DEF(PrimitivePtr, kPrimTensorScatterMin, std::make_shared<Primitive>("TensorScatterMin"));

View File

@ -247,6 +247,7 @@ constexpr auto kSplitStride = "split_stride";
constexpr auto kExtendTop = "extend_top";
constexpr auto kExtendBottom = "extend_bottom";
constexpr auto kNumberSplit = "number_split";
constexpr auto kNumSegments = "num_segments";
constexpr auto kSplitDim = "split_dim";
constexpr auto kPadTop = "pad_top";
constexpr auto kTransFormat = "trans_format";

View File

@ -24,6 +24,9 @@
#include "ops/tensor_scatter_sub.h"
#include "ops/tensor_scatter_max.h"
#include "ops/tensor_scatter_min.h"
#include "ops/tensor_scatter_mul.h"
#include "ops/tensor_scatter_div.h"
#include "ops/tensor_scatter_update.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
@ -98,6 +101,9 @@ MIND_API_BASE_IMPL(TensorScatterAdd, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterSub, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterMax, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterMin, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterDiv, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterMul, PrimitiveC, BaseOperator);
MIND_API_BASE_IMPL(TensorScatterUpdate, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterAdd, prim::kPrimTensorScatterAdd, TensorScatterArithmeticInfer, nullptr,
true);
@ -107,5 +113,11 @@ REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMax, prim::kPrimTensorScatterMax, Tens
true);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMin, prim::kPrimTensorScatterMin, TensorScatterArithmeticInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterDiv, prim::kPrimTensorScatterDiv, TensorScatterArithmeticInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterMul, prim::kPrimTensorScatterMul, TensorScatterArithmeticInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorScatterUpdate, prim::kPrimTensorScatterUpdate, TensorScatterArithmeticInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTensorScatterDiv = "TensorScatterDiv";
/// \brief By division the value at the position indicated by the index in input_x with the value in the update, the
/// value at the index will eventually be equal to the largest one to create a new tensor.
class MIND_API TensorScatterDiv : public BaseOperator {
public:
MIND_API_BASE_MEMBER(TensorScatterDiv);
/// \brief Constructor.
TensorScatterDiv() : BaseOperator(kNameTensorScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"y"}); }
};
using kPrimTensorScatterDivPtr = std::shared_ptr<TensorScatterDiv>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_DIV_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTensorScatterMul = "TensorScatterMul";
/// \brief By multiple the value at the position indicated by the index in input_x with the value in the update, the
/// value at the index will eventually be equal to the largest one to create a new tensor.
class MIND_API TensorScatterMul : public BaseOperator {
public:
MIND_API_BASE_MEMBER(TensorScatterMul);
/// \brief Constructor.
TensorScatterMul() : BaseOperator(kNameTensorScatterMul) { InitIOName({"input_x", "indices", "updates"}, {"y"}); }
};
using kPrimTensorScatterMulPtr = std::shared_ptr<TensorScatterMul>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_MUL_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_
#define MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTensorScatterUpdate = "TensorScatterUpdate";
/// \brief By update the value at the position indicated by the index in input_x with the value in the update, the
/// value at the index will eventually be equal to the largest one to create a new tensor.
class MIND_API TensorScatterUpdate : public BaseOperator {
public:
MIND_API_BASE_MEMBER(TensorScatterUpdate);
/// \brief Constructor.
TensorScatterUpdate() : BaseOperator(kNameTensorScatterUpdate) {
InitIOName({"input_x", "indices", "updates"}, {"y"});
}
};
using kPrimTensorScatterUpdatePtr = std::shared_ptr<TensorScatterUpdate>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TENSOR_SCATTER_UPDATE_H_

View File

@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/unsorted_segment_arithmetic.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
#include "ops/unsorted_segment_max.h"
#include "ops/unsorted_segment_min.h"
#include "ops/unsorted_segment_prod.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr UnsortedSegmentArithmeticInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x_shape_ptr);
auto segment_ids_shape_ptr = input_args[kInputIndex1]->BuildShape();
MS_EXCEPTION_IF_NULL(segment_ids_shape_ptr);
auto num_segments_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(num_segments_shape_ptr);
auto num_segments = input_args[kInputIndex2]->cast<abstract::AbstractScalarPtr>();
int64_t num_segments_value = 0;
if (num_segments != nullptr && num_segments->BuildValue() != kAnyValue) {
num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
primitive->AddAttr(kNumSegments, MakeValue(num_segments_value));
}
std::vector<int64_t> invalid_out_shape = {-1};
if (x_shape_ptr->IsDynamic() || segment_ids_shape_ptr->IsDynamic() || num_segments_shape_ptr->IsDynamic()) {
return std::make_shared<abstract::Shape>(invalid_out_shape);
}
if (num_segments == nullptr || num_segments->BuildValue() == kAnyValue) {
auto value_ptr = primitive->GetAttr(kNumSegments);
if (value_ptr != nullptr) {
num_segments_value = GetValue<int64_t>(value_ptr);
} else {
return std::make_shared<abstract::Shape>(invalid_out_shape);
}
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
if (x_shape.size() < ids_shape.size()) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape size";
}
for (size_t i = 0; i < ids_shape.size(); i++) {
if (x_shape[i] != ids_shape[i]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", invalid input_args and segment_ids shape[" << i
<< "]: " << x_shape[i] << ", " << ids_shape[i];
}
}
std::vector<int64_t> out_shape;
out_shape.push_back(num_segments_value);
for (size_t i = ids_shape.size(); i < x_shape.size(); i++) {
out_shape.push_back(x_shape.at(i));
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr UnsortedSegmentArithmeticInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_type_ptr = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(in_type_ptr);
std::set<TypePtr> in_type_set = {kFloat16, kFloat32, kInt32};
return CheckAndConvertUtils::CheckTensorTypeValid("x", in_type_ptr, in_type_set, prim_name);
}
} // namespace
AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
auto infer_type = UnsortedSegmentArithmeticInferType(primitive, input_args);
auto infer_shape = UnsortedSegmentArithmeticInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(UnsortedSegmentMax, BaseOperator);
MIND_API_OPERATOR_IMPL(UnsortedSegmentMin, BaseOperator);
MIND_API_OPERATOR_IMPL(UnsortedSegmentProd, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMax, prim::kPrimUnsortedSegmentMax, UnsortedSegmentArithmeticInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentMin, prim::kPrimUnsortedSegmentMin, UnsortedSegmentArithmeticInfer, nullptr,
true);
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentProd, prim::kPrimUnsortedSegmentProd, UnsortedSegmentArithmeticInfer,
nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr UnsortedSegmentArithmeticInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_ARITHMETIC_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameUnsortedSegmentMax = "UnsortedSegmentMax";
/// \brief Computes the max of a tensor along segments.
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMax for more details.
class MIND_API UnsortedSegmentMax : public BaseOperator {
public:
MIND_API_BASE_MEMBER(UnsortedSegmentMax);
/// \brief Constructor.
UnsortedSegmentMax() : BaseOperator(kNameUnsortedSegmentMax) {
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MAX_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameUnsortedSegmentMin = "UnsortedSegmentMin";
/// \brief Computes the min of a tensor along segments.
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentMin for more details.
class MIND_API UnsortedSegmentMin : public BaseOperator {
public:
MIND_API_BASE_MEMBER(UnsortedSegmentMin);
/// \brief Constructor.
UnsortedSegmentMin() : BaseOperator(kNameUnsortedSegmentMin) {
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_MIN_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_
#define MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameUnsortedSegmentProd = "UnsortedSegmentProd";
/// \brief Computes the prod of a tensor along segments.
/// Refer to Python API @ref mindspore.ops.UnsortedSegmentProd for more details.
class MIND_API UnsortedSegmentProd : public BaseOperator {
public:
MIND_API_BASE_MEMBER(UnsortedSegmentProd);
/// \brief Constructor.
UnsortedSegmentProd() : BaseOperator(kNameUnsortedSegmentProd) {
InitIOName({"x", "segment_ids", "num_segments"}, {"y"});
}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_UNSORTED_SEGMENT_PROD_H_