add op for ocr
This commit is contained in:
parent
159142a1f4
commit
9d5f70d97a
|
@ -19,6 +19,7 @@
|
|||
"mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_quant_ops.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/nn_ops.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "dangerous-default-value"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "dangerous-default-value"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "redefined-outer-name"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "unused-import"
|
||||
|
|
|
@ -101,6 +101,7 @@ constexpr const char kNameElu[] = "Elu";
|
|||
constexpr const char kNameEluGrad[] = "EluGrad";
|
||||
constexpr const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
|
||||
constexpr const char kNameTensorScatterElements[] = "TensorScatterElements";
|
||||
constexpr const char kNameScatterElements[] = "ScatterElements";
|
||||
constexpr const char kNameNonZero[] = "NonZero";
|
||||
constexpr const char kNameNonZeroWithValue[] = "NonZeroWithValue";
|
||||
constexpr const char kNameNonZeroWithValueShape[] = "NonZeroWithValueShape";
|
||||
|
@ -356,7 +357,6 @@ constexpr const char kNameWhile[] = "While";
|
|||
constexpr const char kNameKMeansCentroids[] = "KMeansCentroids";
|
||||
constexpr const char kNameIsNan[] = "IsNan";
|
||||
constexpr const char kNameKLDiv[] = "KLDivLoss";
|
||||
constexpr const char kNameStringLength[] = "StringLength";
|
||||
constexpr const char kNameGetShape[] = "GetShape";
|
||||
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
|
||||
constexpr const char kNameRandomStandardNormal[] = "RandomStandardNormal";
|
||||
|
@ -364,6 +364,21 @@ constexpr const char kNameUnsortedSegmentSum[] = "UnsortedSegmentSum";
|
|||
constexpr const char kNameSpaceToBatchTF[] = "SpaceToBatchTF";
|
||||
constexpr const char kNameBatchToSpaceTF[] = "BatchToSpaceTF";
|
||||
constexpr const char kNameMaskedSelect[] = "MaskedSelect";
|
||||
constexpr const char kNamePartitionedCall[] = "PartitionedCall";
|
||||
constexpr const char kNameRangeV2[] = "RangeV2";
|
||||
constexpr const char kNameOCRDetectionPreHandle[] = "OCRDetectionPreHandle";
|
||||
constexpr const char kNameOCRFindContours[] = "OCRFindContours";
|
||||
constexpr const char kNameBatchDilatePolys[] = "BatchDilatePolys";
|
||||
constexpr const char kNameResizeAndClipPolys[] = "ResizeAndClipPolys";
|
||||
constexpr const char kNameOCRDetectionPostHandle[] = "OCRDetectionPostHandle";
|
||||
constexpr const char kNameOCRIdentifyPreHandle[] = "OCRIdentifyPreHandle";
|
||||
constexpr const char kNameBatchEnqueue[] = "BatchEnqueue";
|
||||
constexpr const char kNameDequeue[] = "Dequeue";
|
||||
constexpr const char kNameOCRRecognitionPreHandle[] = "OCRRecognitionPreHandle";
|
||||
constexpr const char kNameStringUpper[] = "StringUpper";
|
||||
constexpr const char kNameStringLength[] = "StringLength";
|
||||
constexpr const char kNameDecodeImage[] = "DecodeImage";
|
||||
constexpr const char kNameDecodeBase64[] = "DecodeBase64";
|
||||
|
||||
class OpAdapterDesc;
|
||||
|
||||
|
|
|
@ -189,6 +189,13 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
|
|||
auto v = GetValue<int32_t>(value);
|
||||
desc.SetRealDimCnt(0);
|
||||
return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
|
||||
} else if (value->isa<UInt32Imm>()) {
|
||||
// convert scalar UInt to GeTensor
|
||||
MS_LOG(INFO) << "Convert scalar to tensor with data type = UInt32";
|
||||
GeTensorDesc desc(GeShape(), ::ge::FORMAT_NCHW, ::ge::DT_UINT32);
|
||||
auto v = GetValue<uint32_t>(value);
|
||||
desc.SetRealDimCnt(0);
|
||||
return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(uint32_t));
|
||||
} else if (value->isa<Int64Imm>()) {
|
||||
// convert scalar Int64 to GeTensor
|
||||
MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
|
||||
|
|
|
@ -85,6 +85,12 @@ ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
|
|||
OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MirrorPadGrad, kNameMirrorPadGrad, ADPT_DESC(MirrorPadGrad))
|
||||
|
||||
// Expand
|
||||
INPUT_MAP(Expand) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
||||
ATTR_MAP(Expand) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Expand) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Expand, "Expand", ADPT_DESC(Expand))
|
||||
|
||||
// ExpandDims
|
||||
INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
|
||||
ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP;
|
||||
|
|
|
@ -42,6 +42,9 @@ DECLARE_OP_USE_OUTPUT(MirrorPad)
|
|||
DECLARE_OP_ADAPTER(MirrorPadGrad)
|
||||
DECLARE_OP_USE_OUTPUT(MirrorPadGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(Expand)
|
||||
DECLARE_OP_USE_OUTPUT(Expand)
|
||||
|
||||
DECLARE_OP_ADAPTER(ExpandDims)
|
||||
DECLARE_OP_USE_OUTPUT(ExpandDims)
|
||||
|
||||
|
|
|
@ -20,6 +20,11 @@
|
|||
#include <string>
|
||||
|
||||
namespace mindspore::transform {
|
||||
INPUT_MAP(ClipByValue) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(clip_value_min)}, {3, INPUT_DESC(clip_value_max)}};
|
||||
ATTR_MAP(ClipByValue) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(ClipByValue) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ClipByValue, "Clip", ADPT_DESC(ClipByValue))
|
||||
|
||||
// Assign
|
||||
INPUT_MAP(Assign) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}};
|
||||
ATTR_MAP(Assign) = EMPTY_ATTR_MAP;
|
||||
|
|
|
@ -22,6 +22,9 @@
|
|||
#include "ops/elewise_calculation_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(ClipByValue)
|
||||
DECLARE_OP_USE_OUTPUT(ClipByValue)
|
||||
|
||||
DECLARE_OP_ADAPTER(AccumulateNV2)
|
||||
DECLARE_OP_USE_DYN_INPUT(AccumulateNV2)
|
||||
DECLARE_OP_USE_OUTPUT(AccumulateNV2)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "transform/graph_ir/op_declare/functional_ops_declare.h"
|
||||
#include <string>
|
||||
|
||||
namespace mindspore::transform {
|
||||
// Case
|
||||
|
@ -31,4 +32,16 @@ ATTR_MAP(While) = {{"parallel_iterations", ATTR_DESC(parallel_iterations, AnyTra
|
|||
DYN_OUTPUT_MAP(While) = {{0, DYN_OUTPUT_DESC(output)}};
|
||||
SUBGRAPH_MAP(While) = {{0, SUBGRAPH_DESC(cond)}, {1, SUBGRAPH_DESC(body)}};
|
||||
REG_ADPT_DESC(While, kNameWhile, ADPT_DESC(While));
|
||||
|
||||
// PartitionedCall
|
||||
INPUT_MAP(PartitionedCall) = EMPTY_INPUT_MAP;
|
||||
DYN_INPUT_MAP(PartitionedCall) = {{1, DYN_INPUT_DESC(args)}};
|
||||
ATTR_MAP(PartitionedCall) = {
|
||||
{"config", ATTR_DESC(config, AnyTraits<std::string>())},
|
||||
{"config_proto", ATTR_DESC(config_proto, AnyTraits<std::string>())},
|
||||
{"executor_type", ATTR_DESC(executor_type, AnyTraits<std::string>())},
|
||||
};
|
||||
DYN_OUTPUT_MAP(PartitionedCall) = {{0, DYN_OUTPUT_DESC(output)}};
|
||||
SUBGRAPH_MAP(PartitionedCall) = {{0, SUBGRAPH_DESC(f)}};
|
||||
REG_ADPT_DESC(PartitionedCall, kNamePartitionedCall, ADPT_DESC(PartitionedCall))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -32,5 +32,10 @@ DECLARE_OP_ATTR(While)
|
|||
DECLARE_OP_USE_DYN_INPUT(While)
|
||||
DECLARE_OP_USE_SUBGRAPH(While)
|
||||
DECLARE_OP_USE_DYN_OUTPUT(While)
|
||||
|
||||
DECLARE_OP_ADAPTER(PartitionedCall)
|
||||
DECLARE_OP_USE_DYN_INPUT(PartitionedCall)
|
||||
DECLARE_OP_USE_SUBGRAPH(PartitionedCall)
|
||||
DECLARE_OP_USE_DYN_OUTPUT(PartitionedCall)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_FUNCTIONAL_OPS_DECLARE_H_
|
||||
|
|
|
@ -60,4 +60,12 @@ ATTR_MAP(CropAndResize) = {{"extrapolation_value", ATTR_DESC(extrapolation_value
|
|||
{"method", ATTR_DESC(method, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(CropAndResize) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(CropAndResize, kNameCropAndResize, ADPT_DESC(CropAndResize))
|
||||
|
||||
// DecodeImage
|
||||
INPUT_MAP(DecodeImage) = {{1, INPUT_DESC(contents)}};
|
||||
ATTR_MAP(DecodeImage) = {{"channels", ATTR_DESC(channels, AnyTraits<int64_t>())},
|
||||
{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
|
||||
{"expand_animations", ATTR_DESC(expand_animations, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(DecodeImage) = {{0, OUTPUT_DESC(image)}};
|
||||
REG_ADPT_DESC(DecodeImage, kNameDecodeImage, ADPT_DESC(DecodeImage))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -39,5 +39,8 @@ DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
|
|||
|
||||
DECLARE_OP_ADAPTER(CropAndResize)
|
||||
DECLARE_OP_USE_OUTPUT(CropAndResize)
|
||||
|
||||
DECLARE_OP_ADAPTER(DecodeImage)
|
||||
DECLARE_OP_USE_OUTPUT(DecodeImage)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_
|
||||
|
|
|
@ -96,6 +96,7 @@ ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())
|
|||
{"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MatMulV2, prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2))
|
||||
REG_ADPT_DESC(MatMulV2Duplicate, prim::kPrimMatMulV2->name(), ADPT_DESC(MatMulV2))
|
||||
|
||||
// MatrixDiag
|
||||
INPUT_MAP(MatrixDiag) = {{1, INPUT_DESC(x)}};
|
||||
|
@ -145,7 +146,8 @@ REG_ADPT_DESC(L2Loss, kNameL2Loss, ADPT_DESC(L2Loss))
|
|||
INPUT_MAP(ScatterElements) = {{1, INPUT_DESC(data)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||
ATTR_MAP(ScatterElements) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(ScatterElements) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ScatterElements, kNameTensorScatterElements, ADPT_DESC(ScatterElements))
|
||||
REG_ADPT_DESC(TensorScatterElements, kNameTensorScatterElements, ADPT_DESC(ScatterElements))
|
||||
REG_ADPT_DESC(ScatterElements, kNameScatterElements, ADPT_DESC(ScatterElements))
|
||||
|
||||
// FullyConnection
|
||||
INPUT_MAP(FullyConnection) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(offset_w)}};
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "transform/graph_ir/op_declare/ocr_ops_declare.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore::transform {
|
||||
INPUT_MAP(OCRDetectionPreHandle) = {{1, INPUT_DESC(img)}};
|
||||
ATTR_MAP(OCRDetectionPreHandle) = {{"format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(OCRDetectionPreHandle) = {
|
||||
{0, OUTPUT_DESC(resized_img)}, {1, OUTPUT_DESC(h_scale)}, {2, OUTPUT_DESC(w_scale)}};
|
||||
REG_ADPT_DESC(OCRDetectionPreHandle, kNameOCRDetectionPreHandle, ADPT_DESC(OCRDetectionPreHandle))
|
||||
|
||||
INPUT_MAP(OCRFindContours) = {{1, INPUT_DESC(img)}};
|
||||
ATTR_MAP(OCRFindContours) = {{"value_mode", ATTR_DESC(value_mode, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(OCRFindContours) = {
|
||||
{0, OUTPUT_DESC(polys_data)}, {1, OUTPUT_DESC(polys_offset)}, {2, OUTPUT_DESC(polys_size)}};
|
||||
REG_ADPT_DESC(OCRFindContours, kNameOCRFindContours, ADPT_DESC(OCRFindContours))
|
||||
|
||||
INPUT_MAP(BatchDilatePolys) = {{1, INPUT_DESC(polys_data)}, {2, INPUT_DESC(polys_offset)},
|
||||
{3, INPUT_DESC(polys_size)}, {4, INPUT_DESC(score)},
|
||||
{5, INPUT_DESC(min_border)}, {6, INPUT_DESC(min_area_thr)},
|
||||
{7, INPUT_DESC(score_thr)}, {8, INPUT_DESC(expands_cale)}};
|
||||
ATTR_MAP(BatchDilatePolys) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(BatchDilatePolys) = {
|
||||
{0, OUTPUT_DESC(dilated_polys_data)}, {1, OUTPUT_DESC(dilated_polys_offset)}, {2, OUTPUT_DESC(dilated_polys_size)}};
|
||||
REG_ADPT_DESC(BatchDilatePolys, kNameBatchDilatePolys, ADPT_DESC(BatchDilatePolys))
|
||||
|
||||
INPUT_MAP(ResizeAndClipPolys) = {
|
||||
{1, INPUT_DESC(polys_data)}, {2, INPUT_DESC(polys_offset)}, {3, INPUT_DESC(polys_size)}, {4, INPUT_DESC(h_scale)},
|
||||
{5, INPUT_DESC(w_scale)}, {6, INPUT_DESC(img_h)}, {7, INPUT_DESC(img_w)}};
|
||||
ATTR_MAP(ResizeAndClipPolys) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(ResizeAndClipPolys) = {{0, OUTPUT_DESC(clipped_polys_data)},
|
||||
{1, OUTPUT_DESC(clipped_polys_offset)},
|
||||
{2, OUTPUT_DESC(clipped_polys_size)},
|
||||
{3, OUTPUT_DESC(clipped_polys_num)}};
|
||||
REG_ADPT_DESC(ResizeAndClipPolys, kNameResizeAndClipPolys, ADPT_DESC(ResizeAndClipPolys))
|
||||
|
||||
INPUT_MAP(OCRDetectionPostHandle) = {
|
||||
{1, INPUT_DESC(img)}, {2, INPUT_DESC(polys_data)}, {3, INPUT_DESC(polys_offset)}, {4, INPUT_DESC(polys_size)}};
|
||||
ATTR_MAP(OCRDetectionPostHandle) = {{"format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(OCRDetectionPostHandle) = {{0, OUTPUT_DESC(imgs_data)},
|
||||
{1, OUTPUT_DESC(imgs_offset)},
|
||||
{2, OUTPUT_DESC(imgs_size)},
|
||||
{3, OUTPUT_DESC(rect_points)}};
|
||||
REG_ADPT_DESC(OCRDetectionPostHandle, kNameOCRDetectionPostHandle, ADPT_DESC(OCRDetectionPostHandle))
|
||||
|
||||
INPUT_MAP(OCRIdentifyPreHandle) = {
|
||||
{1, INPUT_DESC(imgs_data)}, {2, INPUT_DESC(imgs_offset)}, {3, INPUT_DESC(imgs_size)}};
|
||||
ATTR_MAP(OCRIdentifyPreHandle) = {{"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(OCRIdentifyPreHandle) = {{0, OUTPUT_DESC(resized_imgs)}};
|
||||
REG_ADPT_DESC(OCRIdentifyPreHandle, kNameOCRIdentifyPreHandle, ADPT_DESC(OCRIdentifyPreHandle))
|
||||
|
||||
INPUT_MAP(BatchEnqueue) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(queue_id)}};
|
||||
ATTR_MAP(BatchEnqueue) = {{"batch_size", ATTR_DESC(batch_size, AnyTraits<int64_t>())},
|
||||
{"queue_name", ATTR_DESC(queue_name, AnyTraits<std::string>())},
|
||||
{"queue_depth", ATTR_DESC(queue_depth, AnyTraits<int64_t>())},
|
||||
{"pad_mode", ATTR_DESC(pad_mode, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(BatchEnqueue) = {{0, OUTPUT_DESC(enqueue_count)}};
|
||||
REG_ADPT_DESC(BatchEnqueue, kNameBatchEnqueue, ADPT_DESC(BatchEnqueue))
|
||||
|
||||
INPUT_MAP(Dequeue) = {{1, INPUT_DESC(queue_id)}};
|
||||
ATTR_MAP(Dequeue) = {{"output_type", ATTR_DESC(output_type, AnyTraits<GEType>())},
|
||||
{"output_shape", ATTR_DESC(output_shape, AnyTraits<std::vector<int64_t>>())},
|
||||
{"queue_name", ATTR_DESC(queue_name, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(Dequeue) = {{0, OUTPUT_DESC(data)}};
|
||||
REG_ADPT_DESC(Dequeue, kNameDequeue, ADPT_DESC(Dequeue))
|
||||
|
||||
INPUT_MAP(OCRRecognitionPreHandle) = {{1, INPUT_DESC(imgs_data)},
|
||||
{2, INPUT_DESC(imgs_offset)},
|
||||
{3, INPUT_DESC(imgs_size)},
|
||||
{4, INPUT_DESC(langs)},
|
||||
{5, INPUT_DESC(langs_score)}};
|
||||
ATTR_MAP(OCRRecognitionPreHandle) = {{"batch_size", ATTR_DESC(batch_size, AnyTraits<int64_t>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
{"pad_mode", ATTR_DESC(pad_mode, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(OCRRecognitionPreHandle) = {{0, OUTPUT_DESC(imgs)},
|
||||
{1, OUTPUT_DESC(imgs_relation)},
|
||||
{2, OUTPUT_DESC(imgs_lang)},
|
||||
{3, OUTPUT_DESC(imgs_piece_fillers)}};
|
||||
REG_ADPT_DESC(OCRRecognitionPreHandle, kNameOCRRecognitionPreHandle, ADPT_DESC(OCRRecognitionPreHandle))
|
||||
} // namespace mindspore::transform
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_OCR_OPS_DECLARE_H_
|
||||
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_OCR_OPS_DECLARE_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
#include "ops/ocr_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(OCRDetectionPreHandle)
|
||||
DECLARE_OP_USE_OUTPUT(OCRDetectionPreHandle)
|
||||
|
||||
DECLARE_OP_ADAPTER(OCRFindContours)
|
||||
DECLARE_OP_USE_OUTPUT(OCRFindContours)
|
||||
|
||||
DECLARE_OP_ADAPTER(BatchDilatePolys)
|
||||
DECLARE_OP_USE_OUTPUT(BatchDilatePolys)
|
||||
|
||||
DECLARE_OP_ADAPTER(ResizeAndClipPolys)
|
||||
DECLARE_OP_USE_OUTPUT(ResizeAndClipPolys)
|
||||
|
||||
DECLARE_OP_ADAPTER(OCRDetectionPostHandle)
|
||||
DECLARE_OP_USE_OUTPUT(OCRDetectionPostHandle)
|
||||
|
||||
DECLARE_OP_ADAPTER(OCRIdentifyPreHandle)
|
||||
DECLARE_OP_USE_OUTPUT(OCRIdentifyPreHandle)
|
||||
|
||||
DECLARE_OP_ADAPTER(BatchEnqueue)
|
||||
DECLARE_OP_USE_OUTPUT(BatchEnqueue)
|
||||
|
||||
DECLARE_OP_ADAPTER(Dequeue)
|
||||
DECLARE_OP_USE_OUTPUT(Dequeue)
|
||||
|
||||
DECLARE_OP_ADAPTER(OCRRecognitionPreHandle)
|
||||
DECLARE_OP_USE_OUTPUT(OCRRecognitionPreHandle)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_OCR_OPS_DECLARE_H_
|
|
@ -89,6 +89,7 @@ INPUT_MAP(ReduceMean) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
|||
ATTR_MAP(ReduceMean) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMean) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMean, prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMean))
|
||||
REG_ADPT_DESC(ReduceMeanV1, "ReduceMeanV1", ADPT_DESC(ReduceMean))
|
||||
|
||||
// ReduceMinD
|
||||
INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}};
|
||||
|
@ -105,4 +106,9 @@ INPUT_ATTR_MAP(ReduceMaxD) = {
|
|||
ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMaxD, prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD))
|
||||
|
||||
INPUT_MAP(ReduceMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
||||
ATTR_MAP(ReduceMax) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMax) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMax, "ReduceMaxV1", ADPT_DESC(ReduceMax))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -33,6 +33,9 @@ DECLARE_OP_ADAPTER(ReduceMaxD)
|
|||
DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMaxD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceMax)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMax)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceAllD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ReduceAllD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceAllD)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "transform/graph_ir/op_declare/rnn_declare.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore::transform {
|
||||
// BasicLSTMCell
|
||||
|
@ -139,4 +140,29 @@ ATTR_MAP(DynamicGRUV2Grad) = {{"direction", ATTR_DESC(direction, AnyTraits<std::
|
|||
OUTPUT_MAP(DynamicGRUV2Grad) = {{0, OUTPUT_DESC(dw_input)}, {1, OUTPUT_DESC(dw_hidden)}, {2, OUTPUT_DESC(db_input)},
|
||||
{3, OUTPUT_DESC(db_hidden)}, {4, OUTPUT_DESC(dx)}, {5, OUTPUT_DESC(dh_prev)}};
|
||||
REG_ADPT_DESC(DynamicGRUV2Grad, kNameDynamicGRUV2Grad, ADPT_DESC(DynamicGRUV2Grad))
|
||||
|
||||
// CommonLSTM
|
||||
INPUT_MAP(CommonLSTM) = {{1, INPUT_DESC(x)},
|
||||
{2, INPUT_DESC(w)},
|
||||
{3, INPUT_DESC(r)},
|
||||
{4, INPUT_DESC(b)},
|
||||
{5, INPUT_DESC(sequence_lens)},
|
||||
{6, INPUT_DESC(initial_h)},
|
||||
{7, INPUT_DESC(initial_c)},
|
||||
{8, INPUT_DESC(p)}};
|
||||
ATTR_MAP(CommonLSTM) = {
|
||||
{"activation_alpha", ATTR_DESC(activation_alpha, AnyTraits<std::vector<float>>())},
|
||||
{"activation_beta", ATTR_DESC(activation_beta, AnyTraits<std::vector<float>>())},
|
||||
{"activations", ATTR_DESC(activations, AnyTraits<std::vector<std::string>>())},
|
||||
{"clip", ATTR_DESC(clip, AnyTraits<float>())},
|
||||
{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"hidden_size", ATTR_DESC(hidden_size, AnyTraits<int64_t>())},
|
||||
{"input_forget", ATTR_DESC(input_forget, AnyTraits<int64_t>())},
|
||||
};
|
||||
OUTPUT_MAP(CommonLSTM) = {
|
||||
{0, OUTPUT_DESC(y)},
|
||||
{1, OUTPUT_DESC(y_h)},
|
||||
{2, OUTPUT_DESC(y_c)},
|
||||
};
|
||||
REG_ADPT_DESC(CommonLSTM, "CommonLSTM", ADPT_DESC(CommonLSTM))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -48,5 +48,8 @@ DECLARE_OP_USE_OUTPUT(DynamicGRUV2)
|
|||
|
||||
DECLARE_OP_ADAPTER(DynamicGRUV2Grad)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicGRUV2Grad)
|
||||
|
||||
DECLARE_OP_ADAPTER(CommonLSTM)
|
||||
DECLARE_OP_USE_OUTPUT(CommonLSTM)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_
|
||||
|
|
|
@ -30,6 +30,7 @@ REG_ADPT_DESC(CumsumD, kNameCumSum, ADPT_DESC(CumsumD))
|
|||
INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(axis)}};
|
||||
ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(GatherV2, prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2))
|
||||
|
||||
// CumprodD
|
||||
INPUT_MAP(CumprodD) = {{1, INPUT_DESC(x)}};
|
||||
|
@ -39,6 +40,11 @@ ATTR_MAP(CumprodD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits<bool>())},
|
|||
OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(CumprodD, kNameCumProd, ADPT_DESC(CumprodD))
|
||||
|
||||
INPUT_MAP(Tile) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(multiples)}};
|
||||
ATTR_MAP(Tile) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Tile) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Tile, "TileV1", ADPT_DESC(Tile))
|
||||
|
||||
INPUT_MAP(Slice) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(offsets)}, {3, INPUT_DESC(size)}};
|
||||
ATTR_MAP(Slice) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Slice) = {{0, OUTPUT_DESC(y)}};
|
||||
|
@ -109,6 +115,12 @@ ATTR_MAP(RangeD) = {{"start", ATTR_DESC(start, AnyTraits<float>())},
|
|||
OUTPUT_MAP(RangeD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(RangeD, kNameRange, ADPT_DESC(RangeD))
|
||||
|
||||
// RangeV2
|
||||
INPUT_MAP(Range) = {{1, INPUT_DESC(start)}, {2, INPUT_DESC(limit)}, {3, INPUT_DESC(delta)}};
|
||||
ATTR_MAP(Range) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Range) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(RangeV2, kNameRangeV2, ADPT_DESC(Range))
|
||||
|
||||
// InplaceAddD
|
||||
INPUT_MAP(InplaceAddD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(v)}};
|
||||
ATTR_MAP(InplaceAddD) = {{"indices", ATTR_DESC(indices, AnyTraits<std::vector<int64_t>>())}};
|
||||
|
|
|
@ -73,6 +73,9 @@ DECLARE_OP_ADAPTER(CumprodD)
|
|||
DECLARE_OP_USE_INPUT_ATTR(CumprodD)
|
||||
DECLARE_OP_USE_OUTPUT(CumprodD)
|
||||
|
||||
DECLARE_OP_ADAPTER(Tile)
|
||||
DECLARE_OP_USE_OUTPUT(Tile)
|
||||
|
||||
DECLARE_OP_ADAPTER(TileD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(TileD)
|
||||
DECLARE_OP_USE_OUTPUT(TileD)
|
||||
|
@ -87,6 +90,9 @@ DECLARE_OP_USE_OUTPUT(GatherV2D)
|
|||
DECLARE_OP_ADAPTER(RangeD)
|
||||
DECLARE_OP_USE_OUTPUT(RangeD)
|
||||
|
||||
DECLARE_OP_ADAPTER(Range)
|
||||
DECLARE_OP_USE_OUTPUT(Range)
|
||||
|
||||
DECLARE_OP_ADAPTER(InplaceAddD)
|
||||
DECLARE_OP_USE_OUTPUT(InplaceAddD)
|
||||
|
||||
|
|
|
@ -16,11 +16,20 @@
|
|||
|
||||
#include "transform/graph_ir/op_declare/string_ops_declare.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore::transform {
|
||||
INPUT_MAP(StringUpper) = {{1, INPUT_DESC(input)}};
|
||||
ATTR_MAP(StringUpper) = {{"encoding", ATTR_DESC(encoding, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(StringUpper) = {{0, OUTPUT_DESC(output)}};
|
||||
REG_ADPT_DESC(StringUpper, kNameStringUpper, ADPT_DESC(StringUpper))
|
||||
|
||||
INPUT_MAP(StringLength) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(StringLength) = {{"unit", ATTR_DESC(unit, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(StringLength) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(StringLength, kNameStringLength, ADPT_DESC(StringLength))
|
||||
|
||||
INPUT_MAP(DecodeBase64) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(DecodeBase64) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(DecodeBase64) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(DecodeBase64, kNameDecodeBase64, ADPT_DESC(DecodeBase64))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -17,11 +17,19 @@
|
|||
#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_STRING_OPS_DECLARE_H_
|
||||
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_STRING_OPS_DECLARE_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
#include "ops/string_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(StringUpper)
|
||||
DECLARE_OP_USE_OUTPUT(StringUpper)
|
||||
|
||||
DECLARE_OP_ADAPTER(StringLength)
|
||||
DECLARE_OP_USE_OUTPUT(StringLength)
|
||||
|
||||
DECLARE_OP_ADAPTER(DecodeBase64)
|
||||
DECLARE_OP_USE_OUTPUT(DecodeBase64)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_STRING_OPS_DECLARE_H_
|
||||
|
|
|
@ -48,6 +48,12 @@ ATTR_MAP(TransposeD) = EMPTY_ATTR_MAP;
|
|||
// Do not set Transpose operator output descriptor
|
||||
REG_ADPT_DESC(TransposeD, prim::kPrimTranspose->name(), ADPT_DESC(TransposeD))
|
||||
|
||||
// Transpose
|
||||
INPUT_MAP(Transpose) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(perm)}};
|
||||
ATTR_MAP(Transpose) = EMPTY_ATTR_MAP;
|
||||
// Do not set Transpose operator output descriptor
|
||||
REG_ADPT_DESC(Transpose, "TransposeV1", ADPT_DESC(Transpose))
|
||||
|
||||
// SpaceToDepth
|
||||
INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(SpaceToDepth) = {{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())}};
|
||||
|
|
|
@ -31,6 +31,8 @@ DECLARE_OP_USE_DYN_OUTPUT(Unpack)
|
|||
DECLARE_OP_ADAPTER(TransposeD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(TransposeD)
|
||||
|
||||
DECLARE_OP_ADAPTER(Transpose)
|
||||
|
||||
DECLARE_OP_ADAPTER(Flatten)
|
||||
DECLARE_OP_USE_OUTPUT(Flatten)
|
||||
|
||||
|
|
|
@ -568,6 +568,7 @@ GVAR_DEF(PrimitivePtr, kPrimScatterNdMin, std::make_shared<Primitive>("ScatterNd
|
|||
GVAR_DEF(PrimitivePtr, kPrimScatterNdMul, std::make_shared<Primitive>("ScatterNdMul"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterNdDiv, std::make_shared<Primitive>("ScatterNdDiv"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterUpdate, std::make_shared<Primitive>("ScatterUpdate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterElements, std::make_shared<Primitive>("ScatterElements"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScatterAddWithAxis, std::make_shared<Primitive>(kScatterAddWithAxis));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterElements, std::make_shared<Primitive>("TensorScatterElements"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorScatterUpdate, std::make_shared<Primitive>("TensorScatterUpdate"));
|
||||
|
@ -1443,11 +1444,15 @@ GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("Tenso
|
|||
GVAR_DEF(PrimitivePtr, kPrimTensorArray, std::make_shared<Primitive>("TensorArray"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorArrayWrite, std::make_shared<Primitive>("TensorArrayWrite"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorArrayGather, std::make_shared<Primitive>("TensorArrayGather"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPartitionedCall, std::make_shared<Primitive>("PartitionedCall"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDecodeImage, std::make_shared<Primitive>("DecodeImage"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimStridedSliceV2, std::make_shared<Primitive>("StridedSliceV2"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimKMeansCentroids, std::make_shared<Primitive>("KMeansCentroids"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferCreate, std::make_shared<Primitive>("ReservoirReplayBufferCreate"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferPush, std::make_shared<Primitive>("ReservoirReplayBufferPush"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared<Primitive>("ReservoirReplayBufferSample"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferDestroy, std::make_shared<Primitive>("ReservoirReplayBufferDestroy"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimOCRDetectionPreHandle, std::make_shared<Primitive>("OCRDetectionPreHandle"));
|
||||
|
||||
// AdamApplyOne
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/scatter_elements.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace {
|
||||
constexpr size_t kScatterElementsArgSize = 3;
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
int64_t ScatterElements::get_axis() const { return axis_; }
|
||||
|
||||
void ScatterElements::set_axis(const int64_t axis) { axis_ = axis; }
|
||||
|
||||
AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const abstract::AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckRequiredArgsSize(op_name, args_spec_list, kScatterElementsArgSize);
|
||||
auto x = abstract::CheckArg<abstract::AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(x->shape());
|
||||
ShapeVector shape = x->shape()->shape();
|
||||
ShapeVector min_shape = x->shape()->min_shape();
|
||||
ShapeVector max_shape = x->shape()->max_shape();
|
||||
abstract::CheckMinMaxShape(shape, &min_shape, &max_shape);
|
||||
return std::make_shared<abstract::AbstractTensor>(x->element(),
|
||||
std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterElements, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterElements, prim::kPrimScatterElements, ScatterElementsInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterElements = "ScatterElements";
|
||||
class MIND_API ScatterElements : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterElements);
|
||||
/// \brief Constructor.
|
||||
ScatterElements() : BaseOperator(kNameScatterElements) { InitIOName({"indices", "update", "shape"}, {"output"}); }
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
|
||||
void set_axis(const int64_t axis);
|
||||
|
||||
int64_t get_axis() const;
|
||||
|
||||
private:
|
||||
int64_t axis_ = 0;
|
||||
};
|
||||
abstract::AbstractBasePtr ScatterElementsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_ELEMENTS_H_
|
|
@ -25,14 +25,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterElements = "TensorScatterElements";
|
||||
constexpr auto kNameTensorScatterElements = "TensorScatterElements";
|
||||
/// \brief Updates tensor values by using input indices and value.
|
||||
/// Refer to Python API @ref mindspore.ops.TensorScatterElements for more details.
|
||||
class MIND_API TensorScatterElements : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(TensorScatterElements);
|
||||
/// \brief Constructor.
|
||||
TensorScatterElements() : BaseOperator(kNameScatterElements) {
|
||||
TensorScatterElements() : BaseOperator(kNameTensorScatterElements) {
|
||||
InitIOName({"input_x", "indices", "update"}, {"output"});
|
||||
}
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.TensorScatterElements for the inputs.
|
||||
|
|
|
@ -1521,6 +1521,34 @@ class NonZeroWithValueShape(Primitive):
|
|||
self.init_prim_io_names(inputs=['value', 'index', 'count'], outputs=['out_value', 'out_index'])
|
||||
|
||||
|
||||
class DecodeImage(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns image data that parse from string Tensor.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor), a Tensor of type string. 0-D. The jPEG, GIF, PNG, BMP-encoded image.
|
||||
|
||||
Outputs:
|
||||
A Tensor of type uint8, uint16, float.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
|
||||
_op_max_size=[8000000]):
|
||||
self.init_prim_io_names(inputs=["contents"], outputs=["image"])
|
||||
self.res_type = dtype
|
||||
|
||||
def infer_shape(self, x):
|
||||
return (-1, -1, 3)
|
||||
|
||||
def infer_dtype(self, x):
|
||||
return self.res_type
|
||||
|
||||
|
||||
class SliceGetItem(Primitive):
|
||||
"""
|
||||
using SliceGetItem to get slice's attribute of 'start' 'stop' 'step'
|
||||
|
@ -1727,6 +1755,33 @@ class ParallelResizeBilinear(PrimitiveWithInfer):
|
|||
'value': None}
|
||||
|
||||
|
||||
class PartitionedCall(PrimitiveWithInfer):
|
||||
"""
|
||||
Pass the input tensors to the subgraph and return the output tensors.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tuple), the input tensors, which will be passed to subgraph.
|
||||
|
||||
Outputs:
|
||||
- outputs(Tuple), the output tensor returned by subgraph.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, graph, executor_type=""):
|
||||
self.add_prim_attr("executor_type", executor_type)
|
||||
self.graph = graph
|
||||
|
||||
def infer_shape(self, *inputs):
|
||||
return NotImplementedError
|
||||
|
||||
def infer_dtype(self, *inputs):
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
class CellBackwardHook(PrimitiveWithInfer):
|
||||
r"""
|
||||
This operator is used to hook input gradient and output gradient of Cell object.
|
||||
|
|
|
@ -35,7 +35,7 @@ class GetShape(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=["x"], outputs=["y"])
|
||||
|
||||
def infer_shape(self, x):
|
||||
return (x[0],)
|
||||
return (len(x[0]),)
|
||||
|
||||
def infer_dtype(self, x):
|
||||
return mstype.int32
|
||||
|
|
|
@ -900,6 +900,13 @@ class TensorShape(Primitive):
|
|||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
|
||||
class Unsqueeze(PrimitiveWithCheck):
|
||||
@prim_attr_register
|
||||
def __init__(self, axis):
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
self.axis = axis
|
||||
|
||||
|
||||
class DynamicShape(Primitive):
|
||||
"""
|
||||
Same as operator TensorShape. DynamicShape will be deprecated in the future.
|
||||
|
@ -911,7 +918,7 @@ class DynamicShape(Primitive):
|
|||
|
||||
@deprecated("1.7", "TensorShape", True)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
def __init__(self, dtype=9):
|
||||
"""init Shape"""
|
||||
self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
|
||||
self.add_prim_attr('is_dynamic_shape', True)
|
||||
|
|
Loading…
Reference in New Issue