Add regop and adapters for custom aicpu

This commit is contained in:
panzhihui 2023-07-14 16:04:45 +08:00
parent 43991f351f
commit be2abc0952
50 changed files with 4831 additions and 2032 deletions

View File

@ -175,6 +175,7 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.cc" "build/include_order"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.h" "runtime/explicit"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.h" "whitespace/indent"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/meshgrid_kernels.h" "runtime/explicit"
# custom AICPU op_protos
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "whitespace/ending_newline"

View File

@ -424,3 +424,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/im2col_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/selection_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC

View File

@ -52,6 +52,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
${CMAKE_CURRENT_SOURCE_DIR}/candidate_sampler_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/cast_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/concat_offset_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/meshgrid_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/drop_out_gen_mask_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/expand_dims_kernels.cc
${CMAKE_CURRENT_SOURCE_DIR}/flatten_kernels.cc

View File

@ -149,6 +149,10 @@ uint32_t GatherNdCpuKernel::GatherNdComputeRealKernel(CpuKernelContext &ctx) {
for (int64_t i = 0; i < n_slices; ++i) {
int64_t from_pos = 0;
for (int64_t j = 0; j < indices_nd; ++j) {
if (indices_data[i * indices_nd + j] < 0) {
KERNEL_LOG_ERROR("For 'GatherNd', indices can't contain negative value.");
return KERNEL_STATUS_INNER_ERROR;
}
from_pos += indices_data[i * indices_nd + j] * dims_to_count[j];
}
auto offset = i * slice_size;

View File

@ -1,106 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/adaptive_avg_pool_3d_grad_op.h"
#include "register/op_impl_registry.h"
#include "external/graph/operator_reg.h"
#include "utils/util.h"
namespace ge {
// --------- AdaptiveAvgPool3dGrad ---------------
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify) {
auto input_grad_desc = op.GetInputDescByName("input_grad");
auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
ge::AscendString op_name;
(void)op.GetName(op_name);
auto orig_input_shape_dim = orig_input_shape_desc.GetShape().GetDimNum();
if (orig_input_shape_dim != 1) {
OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input_shape is invalid");
return GRAPH_PARAM_INVALID;
}
auto orig_input_dim_num = orig_input_shape_desc.GetShape().GetShapeSize();
auto input_grad_dim_num = input_grad_desc.GetShape().GetDimNum();
if (orig_input_dim_num != static_cast<int64_t>(input_grad_dim_num)) {
OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input and input_grad should be the same");
return GRAPH_PARAM_INVALID;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dGradInferShape) {
map<int, std::string> format2str = {
{ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
{ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};
auto input_desc = op.GetInputDescByName("input_grad");
auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
TensorDesc out_desc = op.GetOutputDescByName("output_grad");
ge::AscendString op_name;
(void)op.GetName(op_name);
// update format
Format input_format = input_desc.GetFormat();
std::string format_str = format2str[input_format];
if (input_format != FORMAT_NCHW) {
OP_LOGE("AdaptiveAvgPool3dGrad",
"Input format only support NCHW"
", input format is [%s]",
format_str.c_str());
return GRAPH_FAILED;
}
out_desc.SetFormat(input_format);
// update data type
DataType input_type = input_desc.GetDataType();
out_desc.SetDataType(input_type);
// infer shape
Tensor orig_input_size_tensor;
if (op.GetInputConstData("orig_input_shape", orig_input_size_tensor) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3dGrad", "failed to get tensor from output_size");
return GRAPH_FAILED;
}
int32_t *orig_input_size_data = reinterpret_cast<int32_t *>(orig_input_size_tensor.GetData());
if (orig_input_size_data == nullptr) {
OP_LOGE("AdaptiveAvgPool3dGrad", "output_size data is invalid");
return GRAPH_PARAM_INVALID;
}
auto input_size_dim_num = input_desc.GetShape().GetDimNum();
std::vector<int64_t> output_shape(input_size_dim_num);
for (uint64_t i = 0; i < input_size_dim_num; ++i) {
output_shape[i] = orig_input_size_data[i];
}
out_desc.SetShape(Shape(output_shape));
if (op.UpdateOutputDesc("output_grad", out_desc) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3dGrad", "failed to update output desc");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify);
// --------- AdaptiveAvgPool3dGrad end---------------
} // namespace ge

View File

@ -1,90 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/adaptive_avg_pool_3d_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// --------- AdaptiveAvgPool3d ---------------
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
map<int, std::string> format2str = {
{ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
{ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};
// verify the dim of output_size
std::vector<int64_t> output_size;
if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
return GRAPH_PARAM_INVALID;
}
ge::AscendString op_name;
(void)op.GetName(op_name);
auto input_desc = op.GetInputDescByName("x");
TensorDesc out_desc = op.GetOutputDescByName("y");
// update data type
DataType input_type = input_desc.GetDataType();
out_desc.SetDataType(input_type);
// update format
Format input_format = input_desc.GetFormat();
std::string format_str = format2str[input_format];
if (input_format != FORMAT_NCHW) {
OP_LOGE("AdaptiveAvgPool3d",
"Input format only support NCHW"
", input format is [%s]",
format_str.c_str());
return GRAPH_FAILED;
}
out_desc.SetFormat(input_format);
std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
auto input_size_dim_num = input_size_shape.size();
std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
auto output_size_num = output_size.size();
if (output_size_num == 1) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
if (output_size[0] < 0) {
continue;
}
output_shape[i] = output_size[0];
}
} else if (output_size_num == 3) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
auto data = output_size[i - input_size_dim_num + 3];
if (data < 0) {
continue;
}
output_shape[i] = data;
}
} else {
OP_LOGE("AdaptiveAvgPool3d", "Shape of output_size is invalid");
return GRAPH_FAILED;
}
out_desc.SetShape(Shape(output_shape));
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3d", "failed to update output desc");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3d, AdaptiveAvgPool3dInferShape);
// --------- AdaptiveAvgPool3d end---------------
} // namespace ge

View File

@ -1,39 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/adaptive_max_pool3_d_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape) {
TensorDesc output_grad = op.GetOutputDescByName("output_grad");
TensorDesc input = op.GetInputDescByName("x");
DataType input_dtype = input.GetDataType();
Shape input_shape = input.GetShape();
output_grad.SetShape(input_shape);
output_grad.SetDataType(input_dtype);
if (op.UpdateOutputDesc("output_grad", output_grad) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify);
} // namespace ge

View File

@ -1,58 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/adaptive_max_pool3d_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape) {
TensorDesc input = op.GetInputDesc(0);
TensorDesc output_size = op.GetInputDesc(1);
TensorDesc output = op.GetOutputDesc(0);
TensorDesc argmax = op.GetOutputDesc(1);
const size_t input_num_dims = input.GetShape().GetDimNum();
const std::vector<int64_t> output_size_shape = output_size.GetShape().GetDims();
if ((input_num_dims == 4 || input_num_dims == 5) == false) {
OP_LOGE(TbeGetName(op), "Input dimensions must be equal to 4 or 5.");
return GRAPH_FAILED;
}
if (output_size_shape.size() != 1) {
OP_LOGE(TbeGetName(op), "output_size dim should be equal to 1.");
return GRAPH_FAILED;
}
if (output_size_shape[0] != 3) {
OP_LOGE(TbeGetName(op), "output_size shape[0] should be equal to 3.");
return GRAPH_FAILED;
}
DataType input_dtype = input.GetDataType();
Shape output_shape(UNKNOWN_SHAPE);
output.SetDataType(input_dtype);
output.SetShape(output_shape);
argmax.SetDataType(DT_INT32);
argmax.SetShape(output_shape);
(void)op.UpdateOutputDesc("y", output);
(void)op.UpdateOutputDesc("argmax", argmax);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify);
} // namespace ge

View File

@ -1,39 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/adaptive_max_pool_2d_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape) {
TensorDesc input_grad = op.GetOutputDescByName("x_grad");
TensorDesc input = op.GetInputDescByName("x");
DataType input_dtype = input.GetDataType();
Shape input_shape = input.GetShape();
input_grad.SetShape(input_shape);
input_grad.SetDataType(input_dtype);
if (op.UpdateOutputDesc("x_grad", input_grad) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify);
} // namespace ge

View File

@ -160,6 +160,11 @@ IMPLEMT_INFERFUNC(Expand, ExpandInferShape) {
OP_LOGE(op_name, "Data shape are not compatible!");
return GRAPH_FAILED;
}
} else if (data_type == DT_INT16) {
if (!ExpandCalDim<int16_t>(data, vec_dim, x_dims, range_vector)) {
OP_LOGE(op_name, "Data shape are not compatible!");
return GRAPH_FAILED;
}
} else {
OP_LOGE(op_name, "Data type not supported!");
return GRAPH_PARAM_INVALID;
@ -412,12 +417,12 @@ static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
AscendString opName;
op.GetName(opName);
if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input start size must be 1.");
if (op.GetInputDescByName("start").GetShape().GetDims().size() > 1) {
OP_LOGE(opName.GetString(), "Input start size must be <= 1.");
return GRAPH_FAILED;
}
if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input end size must be 1.");
if (op.GetInputDescByName("end").GetShape().GetDims().size() > 1) {
OP_LOGE(opName.GetString(), "Input end size must be <= 1.");
return GRAPH_FAILED;
}
DataType input_type_start = op.GetInputDescByName("start").GetDataType();
@ -462,4 +467,108 @@ CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------
// ----------------UniqueConsecutive Begin-------------------
IMPLEMT_INFERFUNC(UniqueConsecutive, UniqueConsecutiveInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc_ptr = op_desc->MutableInputDesc(0);
auto y_desc_ptr = op_desc->MutableOutputDesc(0);
y_desc_ptr->SetDataType(x_desc_ptr->GetDataType());
auto idx_desc_ptr = op_desc->MutableOutputDesc(1);
auto count_desc_ptr = op_desc->MutableOutputDesc(2);
auto &y_shape = y_desc_ptr->MutableShape();
auto &idx_shape = idx_desc_ptr->MutableShape();
auto &count_shape = count_desc_ptr->MutableShape();
bool return_idx = false;
bool return_counts = false;
int64_t axis = 1000;
op.GetAttr("axis", axis);
op.GetAttr("return_idx", return_idx);
op.GetAttr("return_counts", return_counts);
count_shape.SetIsUnknownDimNum();
count_desc_ptr->SetDataType(DT_INT64);
idx_shape.SetIsUnknownDimNum();
idx_desc_ptr->SetDataType(DT_INT64);
y_shape.SetIsUnknownDimNum();
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(UniqueConsecutive, UniqueConsecutiveInfer);
// ----------------UniqueConsecutive End-----------------------
// ----------------UpperBound-----------------------
IMPLEMT_INFERFUNC(UpperBound, UpperBoundInfer) {
Shape unused_shape;
if (WithRank(op.GetInputDesc(0), 2, unused_shape, op) != GRAPH_SUCCESS) {
string err_msg = ConcatString("failed to call WithRank function, input[sorted_x] rank must be 2D, got rank[",
op.GetInputDesc(0).GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(1), 2, unused_shape, op) != GRAPH_SUCCESS) {
string err_msg = ConcatString("failed to call WithRank function, input[values] rank must be 2D, got rank[",
op.GetInputDesc(1).GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
DataType type;
if (op.GetAttr("out_type", type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[out_type] failed"));
return GRAPH_FAILED;
}
TensorDesc out_desc = op.GetOutputDescByName("y");
out_desc.SetShape(op.GetInputDesc(1).GetShape());
out_desc.SetDataType(type);
return op.UpdateOutputDesc("y", out_desc);
}
INFER_FUNC_REG(UpperBound, UpperBoundInfer);
// ----------------UpperBound END-----------------------
// ----------------UnravelIndex-----------------------
IMPLEMT_INFERFUNC(UnravelIndex, UnravelIndexInfer) {
auto indices_desc = op.GetInputDesc(0);
auto dims_desc = op.GetInputDesc(1);
Shape dims_shape;
if (WithRank(dims_desc, 1, dims_shape, op) != GRAPH_SUCCESS) {
string err_msg = ConcatString("input[dims] must be 1D, real rank is ", dims_shape.GetDimNum());
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
Shape indices_shape;
if (WithRankAtMost(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
string err_msg = ConcatString("input[indices] must be less than 1D, real rank is ", dims_shape.GetDimNum());
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
std::vector<int64_t> out_dims({-1, -1});
std::vector<int64_t> dims_shape_vec = dims_shape.GetDims();
std::vector<int64_t> indices_shape_vec = indices_shape.GetDims();
if (indices_shape.GetDimNum() == 0) {
out_dims.pop_back();
} else {
if (indices_shape_vec != ge::UNKNOWN_RANK && indices_shape_vec != ge::UNKNOWN_SHAPE) {
out_dims[1] = indices_shape_vec[0];
}
}
if (dims_shape_vec != ge::UNKNOWN_RANK && dims_shape_vec != ge::UNKNOWN_SHAPE) {
out_dims[0] = dims_shape_vec[0];
}
TensorDesc out_desc = op.GetOutputDescByName("y");
out_desc.SetShape(Shape(out_dims));
out_desc.SetDataType(indices_desc.GetDataType());
return op.UpdateOutputDesc("y", out_desc);
}
INFER_FUNC_REG(UnravelIndex, UnravelIndexInfer);
// ----------------UnravelIndex END-----------------------
} // namespace ge

View File

@ -1,45 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(),
"Op CholeskyGrad first input x tensor make batch square matrix "
"failed.");
return GRAPH_FAILED;
}
DataType type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
} // namespace ge

View File

@ -59,6 +59,10 @@ COMMON_INFER_FUNC_REG(Cos, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Expm1, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log1p, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Tanh, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Sin, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Reciprocal, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Sign, OneInOneOutCommonInferShape);
// ----------------------------------OneInOneOutCommonInfer END-----------------------------
// ----------------------------------TowInOneOutCommonInfer-----------------------------
@ -68,6 +72,9 @@ CUST_COMMON_INFER_FUNC_REG(Gcd, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Heaviside, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Hypot, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Lcm, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Pow, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Xlogy, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Xdivy, TwoInOneOutCommonInferShape);
// ----------------------------------TowInOneOutCommonInfer END-----------------------------
// --------------AcosGrad----------------
@ -399,4 +406,33 @@ IMPLEMT_VERIFIER(FloorDiv, FloorDivVerify) {
COMMON_INFER_FUNC_REG(FloorDiv, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(FloorDiv, FloorDivVerify);
// ----------------FloorDiv END------------------------
// ----------------SqrtGrad Op Begin-----------------
IMPLEMT_VERIFIER(SqrtGrad, SqrtGradVerify) {
if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(SqrtGradInferShape) {
Shape shape_x = op.GetInputDescByName("y").GetShape();
DataType input_dtype = op.GetInputDescByName("y").GetDataType();
TensorDesc tensordesc_output = op.GetOutputDescByName("z");
std::vector<std::pair<int64_t, int64_t>> shape_range_x;
op.GetInputDescByName("y").GetShapeRange(shape_range_x);
tensordesc_output.SetShape(shape_x);
tensordesc_output.SetDataType(input_dtype);
tensordesc_output.SetShapeRange(shape_range_x);
if (op.UpdateOutputDesc("z", tensordesc_output) != GRAPH_SUCCESS) {
std::string err_msg = UpdateParamErrMsg("z");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(SqrtGrad, SqrtGradInferShape);
VERIFY_FUNC_REG(SqrtGrad, SqrtGradVerify);
// ----------------SqrtGrad Op End-------------------
} // namespace ge

View File

@ -14,9 +14,11 @@
* limitations under the License.
*/
#include <graph/utils/type_utils.h>
#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/image_ops_shape_fns.h"
namespace ge {
// ----------------AdjustHue Start-------------------
@ -138,5 +140,425 @@ IMPLEMT_INFERFUNC(ExtractGlimpse, ExtractGlimpseInfer) {
}
INFER_FUNC_REG(ExtractGlimpse, ExtractGlimpseInfer);
// ----------------ExtractGlimpse-------------------
// ----------------ExtractGlimpse END-------------------
// ----------------ResizeArea-------------------
IMPLEMT_INFERFUNC(ResizeArea, ResizeAreaInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
desc.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return ResizeShapeFn(op, "images", "size", "y");
}
INFER_FUNC_REG(ResizeArea, ResizeAreaInfer);
// ----------------ResizeArea END-------------------
// ----------------ResizeBicubic-------------------
IMPLEMT_INFERFUNC(ResizeBicubic, ResizeBicubicInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
DataType data_type = DT_FLOAT;
// Update DataType when Attr "dtype" is set
if (op.GetAttr("dtype", data_type) == GRAPH_SUCCESS) {
CHECK(((data_type != DT_FLOAT) && (data_type != DT_UINT8)),
OP_LOGE(TbeGetName(op), "Attr dtype should only be DT_FLOAT or DT_UNIT8"), return GRAPH_FAILED);
OP_LOGI(TbeGetName(op), "Update Bicubic DataType from attr, which is %d", data_type);
}
desc.SetDataType(data_type);
if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return ResizeShapeFn(op, "images", "size", "y");
}
INFER_FUNC_REG(ResizeBicubic, ResizeBicubicInfer);
// ----------------ResizeBicubic END-------------------
bool ResizeConstInferShape(const Operator &op, const std::pair<uint32_t, std::string> image_info,
const std::pair<uint32_t, std::string> size_info,
const std::pair<uint32_t, std::string> output_info) {
static const size_t output_len = 4;
static const size_t size_len = 2;
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK(op_desc == nullptr, OP_LOGE(TbeGetName(op), "op desc is null."), return false);
auto input_desc_x = op_desc->MutableInputDesc(image_info.first);
CHECK(input_desc_x == nullptr,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("input x is null.")), return false);
auto output_desc_y = op_desc->MutableOutputDesc(output_info.first);
CHECK(output_desc_y == nullptr,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("output y is null.")), return false);
// infer dtype start
output_desc_y->SetDataType(input_desc_x->GetDataType());
// infer dtype end
// infer shape start
const GeShape &x_shape = input_desc_x->MutableShape();
auto input_format = input_desc_x->GetFormat();
OP_LOGD(TbeGetName(op), "get the format is %s", ge::TypeUtils::FormatToSerialString(input_format).c_str());
CHECK(input_format != FORMAT_NHWC && input_format != FORMAT_NCHW,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("The input format is valid")),
return false);
const int64_t image_n_idx = 0;
// format is NHWC, c_idx = 3, format is NCHW, c_idx = 1,
const int64_t image_c_idx = input_format == FORMAT_NHWC ? 3 : 1;
const int64_t image_h_idx = input_format == FORMAT_NHWC ? 1 : 2;
const int64_t image_w_idx = input_format == FORMAT_NHWC ? 2 : 3;
// get const value
bool is_size_const = true;
vector<int64_t> size_out;
if (!ops::GetConstIntData(op, size_info.first, size_out)) {
OP_LOGW(TbeGetName(op).c_str(), "get const value of input size failed, set out hw = -1, -1");
size_out = {-1, -1};
is_size_const = false;
}
// the size num must be 2, mean output h, output w
OP_LOGD(TbeGetName(op), "the size num must be 2. get the num is %zu", size_out.size());
CHECK(size_out.size() != size_len,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("the input size num must be 2.")),
return false);
// get y shape
GeShape &y_shape = output_desc_y->MutableShape();
y_shape.SetDimNum(output_len);
if (!x_shape.IsUnknownDimNum()) {
OP_LOGD(TbeGetName(op), "the input shape size must be 4. get shape size is %zu", x_shape.GetDimNum());
CHECK(x_shape.GetDimNum() != output_len,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("The dim of input x is not 4")),
return false);
y_shape.SetDim(image_n_idx, x_shape.GetDim(image_n_idx));
y_shape.SetDim(image_c_idx, x_shape.GetDim(image_c_idx));
} else {
OP_LOGW(TbeGetName(op).c_str(), "the input is unknown rank, will set the out nc = -1, -1");
y_shape.SetDim(image_n_idx, -1);
y_shape.SetDim(image_c_idx, -1);
}
y_shape.SetDim(image_h_idx, size_out[0]);
y_shape.SetDim(image_w_idx, size_out[1]);
// infer shape end
// charge whether is dynamic, when output is static shape, return true
CHECK(!y_shape.IsUnknownShape(), OP_LOGD(TbeGetName(op), "the output is static shape. infer succ"), return true);
OP_LOGD(TbeGetName(op), "the output is dynamic shape. will infer range");
// infer shape_range start
std::vector<std::pair<int64_t, int64_t>> x_range;
vector<int64_t> image_shape{-1, -1, -1, -1};
// check whether is -2 case
if (!x_shape.IsUnknownDimNum()) {
image_shape = x_shape.GetDims();
(void)input_desc_x->GetShapeRange(x_range);
}
MakeUpShapeRange(image_shape, x_range);
OP_LOGD(TbeGetName(op), "the input range size must be 4. get size is %zu", x_range.size());
CHECK(x_range.size() != output_len,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("the x range size is not equal 4")),
return false);
if (!is_size_const) {
std::vector<std::pair<int64_t, int64_t>> size_value_range;
auto input_size_x = op_desc->MutableInputDesc(size_info.first);
CHECK(input_size_x == nullptr,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("input size is null.")),
return false);
// means no const value, will get the value range
(void)input_size_x->GetValueRange(size_value_range);
// the size num must be 2, so the value range num must be 2
if (size_value_range.size() != size_len) {
x_range[image_h_idx] = std::pair<int64_t, int64_t>(0, -1);
x_range[image_w_idx] = std::pair<int64_t, int64_t>(0, -1);
} else {
x_range[image_h_idx] = size_value_range[0];
x_range[image_w_idx] = size_value_range[1];
}
} else {
x_range[image_h_idx] = std::pair<int64_t, int64_t>(size_out[0], size_out[0]);
x_range[image_w_idx] = std::pair<int64_t, int64_t>(size_out[1], size_out[1]);
}
output_desc_y->SetShapeRange(x_range);
// infer shape_range end
return true;
}
// ---------------ResizeNearestNeighborV2 Op Start-------------------
IMPLEMT_COMMON_INFERFUNC(ResizeNearestNeighborV2InferShape) {
static const std::pair<uint32_t, std::string> input_x{0, "x"};
static const std::pair<uint32_t, std::string> input_size{1, "size"};
static const std::pair<uint32_t, std::string> output_y{0, "y"};
const vector<string> depends{input_size.second};
PREPARE_DYNAMIC_SHAPE(depends);
if (!ResizeConstInferShape(op, input_x, input_size, output_y)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ResizeNearestNeighborV2, ResizeNearestNeighborV2InferShape);
INFER_VALUE_RANGE_DEFAULT_REG(ResizeNearestNeighborV2);
// ---------------ResizeNearestNeighborV2 Op End-------------------
// ----------------ResizeNearestNeighborV2Grad-------------------
IMPLEMT_INFERFUNC(ResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto y_desc = op_desc->MutableOutputDesc(0);
auto size_desc = op_desc->MutableInputDesc(1);
auto grads_desc = op_desc->MutableInputDesc(0);
if (op.GetInputDesc(0).GetShape().GetDims() == UNKNOWN_RANK ||
op.GetInputDesc(1).GetShape().GetDims() == UNKNOWN_RANK) {
y_desc->SetShape(GeShape(UNKNOWN_RANK));
y_desc->SetDataType(grads_desc->GetDataType());
return GRAPH_SUCCESS;
}
// unknown shape support
std::vector<std::string> input_infer_depends = {"size"};
op_desc->SetOpInferDepends(input_infer_depends);
GeShape grads_shape;
if (WithRank(grads_desc, 4, grads_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(op_desc->GetName().c_str(), "Input grads must be 4-D, real rank is [%lu]",
grads_desc->GetShape().GetDimNum());
return GRAPH_PARAM_INVALID;
}
GeShape size_shape;
if (WithRank(size_desc, 1, size_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(op_desc->GetName().c_str(), "Input size must be 1-D, real rank is [%lu]",
size_desc->GetShape().GetDimNum());
return GRAPH_PARAM_INVALID;
}
auto size_dims = size_shape.GetDims();
if (size_dims[0] != 2 && size_dims[0] != UNKNOWN_DIM) {
OP_LOGE(op_desc->GetName().c_str(), "Input size must be 1-D of 2 elements, real dim size is [%ld]", size_dims[0]);
return GRAPH_PARAM_INVALID;
}
int64_t size_height = UNKNOWN_DIM;
int64_t size_width = UNKNOWN_DIM;
Tensor size_tensor;
if (op.GetInputConstData("size", size_tensor) == GRAPH_SUCCESS) {
auto size_data = reinterpret_cast<const int32_t *>(size_tensor.GetData());
if (size_data == nullptr) {
OP_LOGE(op_desc->GetName().c_str(), "Get size data failed");
return GRAPH_PARAM_INVALID;
}
size_height = static_cast<int64_t>(size_data[0]);
size_width = static_cast<int64_t>(size_data[1]);
}
std::vector<int64_t> output_dims;
auto grads_dims = grads_shape.GetDims();
auto input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(grads_desc->GetFormat()));
if (input_format == FORMAT_NCHW) {
output_dims.push_back(grads_dims[0]);
output_dims.push_back(grads_dims[1]);
output_dims.push_back(size_height);
output_dims.push_back(size_width);
} else if (input_format == FORMAT_NHWC) {
output_dims.push_back(grads_dims[0]);
output_dims.push_back(size_height);
output_dims.push_back(size_width);
output_dims.push_back(grads_dims[3]);
} else {
OP_LOGE(op_desc->GetName().c_str(), "Not supported this format: [%d]", input_format);
return GRAPH_PARAM_INVALID;
}
GeShape output_shape(output_dims);
if (ShapeFullyDefined(output_shape) == false) {
std::vector<std::pair<int64_t, int64_t>> output_range;
for (const int64_t &output_dim : output_dims) {
output_range.push_back(output_dim == UNKNOWN_DIM ? std::pair<int64_t, int64_t>{1, -1}
: std::pair<int64_t, int64_t>{output_dim, output_dim});
}
y_desc->SetShapeRange(output_range);
}
y_desc->SetDataType(grads_desc->GetDataType());
y_desc->SetShape(output_shape);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(ResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradInfer);
// ----------------ResizeNearestNeighborV2Grad END-------------------
// ----------------RGBToHSV-------------------
IMPLEMT_INFERFUNC(RGBToHSV, RGBToHSVInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
desc.SetDataType(op.GetInputDesc(0).GetDataType());
if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed"));
return GRAPH_FAILED;
}
return ColorspaceShapeFn(op, "y");
}
INFER_FUNC_REG(RGBToHSV, RGBToHSVInfer);
// ----------------RGBToHSV END-------------------
// ----------------NonMaxSuppressionWithOverlaps-------------------
IMPLEMT_INFERFUNC(NonMaxSuppressionWithOverlaps, NonMaxSuppressionWithOverlapsInfer) {
Shape overlaps_shape = op.GetInputDescByName("overlaps").GetShape();
Shape scores_shape = op.GetInputDescByName("scores").GetShape();
Shape max_output_size_shape = op.GetInputDescByName("max_output_size").GetShape();
Shape overlap_threshold_shape = op.GetInputDescByName("overlap_threshold").GetShape();
Shape score_threshold_shape = op.GetInputDescByName("score_threshold").GetShape();
if (WithRank(op.GetInputDescByName("overlaps"), 2, overlaps_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRank function, ", "input[overlaps] rank must be 2, but got rank[",
op.GetInputDescByName("overlaps").GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDescByName("scores"), 1, scores_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRank function, ", "input[scores] rank must be 1, but got rank[",
op.GetInputDescByName("scores").GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDescByName("max_output_size"), 0, max_output_size_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRank function, ", "input[max_output_size] rank must be 0, but got rank[",
op.GetInputDescByName("max_output_size").GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDescByName("overlap_threshold"), 0, overlap_threshold_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRank function, ", "input[overlap_threshold] rank must be 0, but got rank[",
op.GetInputDescByName("overlap_threshold").GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDescByName("score_threshold"), 0, score_threshold_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRank function, ", "input[score_threshold] rank must be 0, but got rank[",
op.GetInputDescByName("score_threshold").GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
int64_t unused_dim = 0;
if (Merge(overlaps_shape.GetDim(0), scores_shape.GetDim(0), unused_dim) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call Merge function to merge the input[overlaps] 0th dim", "[", overlaps_shape.GetDim(0),
"] and the input[scores]'s 0th dim [", scores_shape.GetDim(0), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (Merge(overlaps_shape.GetDim(0), overlaps_shape.GetDim(1), unused_dim) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call Merge function to merge the input[overlaps] 0th dim", "[", overlaps_shape.GetDim(0),
"] and the input[overlaps]'s 1th dim [", overlaps_shape.GetDim(1), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
TensorDesc selected_indices_desc = op.GetOutputDescByName("selected_indices");
Shape selecte_indices_shape;
Vector(ge::UNKNOWN_DIM, selecte_indices_shape);
selected_indices_desc.SetDataType(DT_INT32);
selected_indices_desc.SetShape(selecte_indices_shape);
if (op.UpdateOutputDesc("selected_indices", selected_indices_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[selected_indices] desc failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(NonMaxSuppressionWithOverlaps, NonMaxSuppressionWithOverlapsInfer);
// ----------------NonMaxSuppressionWithOverlaps END-------------------
// ----------------ScaleAndTranslate-------------------
IMPLEMT_INFERFUNC(ScaleAndTranslate, ScaleAndTranslateInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
desc.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update description for output[y] failed."));
return GRAPH_FAILED;
}
return ResizeShapeFn(op, "images", "size", "y");
}
INFER_FUNC_REG(ScaleAndTranslate, ScaleAndTranslateInfer);
// ----------------ScaleAndTranslate END-------------------
// ----------------ScaleAndTranslateGrad-------------------
IMPLEMT_INFERFUNC(ScaleAndTranslateGrad, ScaleAndTranslateGradInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
Format input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDesc(0).GetFormat()));
vector<int64_t> grads_shape = op.GetInputDesc(0).GetShape().GetDims();
vector<int64_t> org_images_shape = op.GetInputDesc(1).GetShape().GetDims();
vector<int64_t> y_shape;
if (input_format == FORMAT_NHWC && grads_shape.size() > 3 && org_images_shape.size() > 2) {
y_shape.push_back(grads_shape[0]);
y_shape.push_back(org_images_shape[1]);
y_shape.push_back(org_images_shape[2]);
y_shape.push_back(grads_shape[3]);
} else if (input_format == FORMAT_NCHW && grads_shape.size() > 1 && org_images_shape.size() > 3) {
y_shape.push_back(grads_shape[0]);
y_shape.push_back(grads_shape[1]);
y_shape.push_back(org_images_shape[2]);
y_shape.push_back(org_images_shape[3]);
} else {
if (grads_shape.size() < 4) {
std::string err_msg =
ConcatString("the 0th input[grads]'s rank should not be less than 4, ", "current rank is ", grads_shape.size());
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (org_images_shape.size() < 2) {
std::string err_msg = ConcatString("the 1th input[original_images]'s rank should not be less than 2, ",
"current rank is ", org_images_shape.size());
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
y_shape.push_back(grads_shape[0]);
y_shape.push_back(org_images_shape[1]);
y_shape.push_back(org_images_shape[2]);
y_shape.push_back(grads_shape[3]);
OP_LOGI(TbeGetName(op).c_str(), "Real format is %d", input_format);
}
desc.SetShape(ge::Shape(y_shape));
desc.SetDataType(DT_FLOAT);
return op.UpdateOutputDesc("y", desc);
}
INFER_FUNC_REG(ScaleAndTranslateGrad, ScaleAndTranslateGradInfer);
// ----------------ScaleAndTranslateGrad END-------------------
// ----------------ResizeBicubicGrad-------------------
IMPLEMT_INFERFUNC(ResizeBicubicGrad, ResizeBicubicGradInfer) {
TensorDesc desc = op.GetOutputDescByName("y");
Format input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDesc(0).GetFormat()));
vector<int64_t> grads_shape = op.GetInputDesc(0).GetShape().GetDims();
vector<int64_t> org_images_shape = op.GetInputDesc(1).GetShape().GetDims();
vector<int64_t> y_shape;
if (input_format == FORMAT_NHWC && grads_shape.size() > 3 && org_images_shape.size() > 2) {
y_shape.push_back(grads_shape[0]);
y_shape.push_back(org_images_shape[1]);
y_shape.push_back(org_images_shape[2]);
y_shape.push_back(grads_shape[3]);
} else if (input_format == FORMAT_NCHW && grads_shape.size() > 1 && org_images_shape.size() > 3) {
y_shape.push_back(grads_shape[0]);
y_shape.push_back(grads_shape[1]);
y_shape.push_back(org_images_shape[2]);
y_shape.push_back(org_images_shape[3]);
} else {
std::string str_input_format = ge::TypeUtils::FormatToSerialString(input_format);
std::string err_msg = ConcatString("only supporting NCHW and NHWC, current format is [", str_input_format, "]");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
}
desc.SetShape(ge::Shape(y_shape));
auto type = op.GetInputDesc(1).GetDataType();
desc.SetDataType(type);
return op.UpdateOutputDesc("y", desc);
}
INFER_FUNC_REG(ResizeBicubicGrad, ResizeBicubicGradInfer);
// ----------------ResizeBicubicGrad END-------------------
} // namespace ge

View File

@ -1,44 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUSTOMIZE_OP_PROTO_INC_TRACE_GRAD_OP_H
#define CUSTOMIZE_OP_PROTO_INC_TRACE_GRAD_OP_H
#include "op_proto_macro.h"
namespace ge {
/**
* @brief Computes the grad of x1 in trace. \n
* @par Inputs:
* Four inputs, including:
* @li y_grad: A tensor. \n
* @li x_shape: A tensor. Must be one of the following types:
* int32, int64. \n
* @par Outputs:
* x_grad: A Tensor with the same type and shape of y_grad's. \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Trace Backward. \n
*/
REG_CUST_OP(TraceGrad)
.INPUT(y_grad, TensorType::BasicType())
.INPUT(x_shape, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(x_grad, TensorType::BasicType())
.CUST_OP_END_FACTORY_REG(TraceGrad)
} // namespace ge
#endif

View File

@ -404,4 +404,92 @@ CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------
// -----------------------Qr---------------------------------
IMPLEMT_INFERFUNC(Qr, QrInfer) {
auto tensor = op.get_input_desc_x();
Shape input;
if (WithRankAtLeast(tensor, 2, input, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
Shape batch_shape;
if (SubShape(input, 0, -2, 1, batch_shape, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
int dim_num = input.GetDimNum();
int m = input.GetDim(dim_num - 2);
int n = input.GetDim(dim_num - 1);
Shape q_shape;
Shape r_shape;
auto full_matrices = op.get_attr_full_matrices();
if (full_matrices) {
// [...,M,M]; [...,M,N], if full_matrices is true
Shape m_m_shape;
Shape m_n_shape;
Matrix(m, m, m_m_shape);
Matrix(m, n, m_n_shape);
Concatenate(batch_shape, m_m_shape, q_shape);
Concatenate(batch_shape, m_n_shape, r_shape);
} else {
// [...,M,P]; [...,P,N], if full_matrices is false
int p = m > n ? n : m;
Shape m_p_shape;
Shape p_n_shape;
Matrix(m, p, m_p_shape);
Matrix(p, n, p_n_shape);
Concatenate(batch_shape, m_p_shape, q_shape);
Concatenate(batch_shape, p_n_shape, r_shape);
}
DataType type = op.GetInputDescByName("x").GetDataType();
TensorDesc q_desc = op.GetOutputDescByName("q");
q_desc.SetShape(Shape(q_shape));
q_desc.SetDataType(type);
if (op.UpdateOutputDesc("q", q_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update q desc failed.");
return GRAPH_FAILED;
}
TensorDesc r_desc = op.GetOutputDescByName("r");
r_desc.SetShape(Shape(r_shape));
r_desc.SetDataType(type);
if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Qr, QrInfer);
// -----------------------Qr END---------------------------------
// -----------------------CholeskyGrad---------------------------------
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(),
"Op CholeskyGrad first input x tensor make batch square matrix "
"failed.");
return GRAPH_FAILED;
}
DataType type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
// -----------------------CholeskyGrad END---------------------------------
} // namespace ge

View File

@ -5,9 +5,11 @@
*/
#include "inc/ops/math_ops.h"
#include "inc/ops/ragged_math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/reduce_infer_util.h"
namespace ge {
// ----------------ComplexAbs-------------------
@ -212,4 +214,113 @@ IMPLEMT_INFERFUNC(IsInf, IsInfInfer) {
INFER_FUNC_REG(IsInf, IsInfInfer);
// ----------------IsInf END------------------------
// ----------------ReduceOp-------------------
static bool InferReduceShapeProcess(const Operator &op, const int64_t input_x_idx, const int64_t output_y_idx,
const int64_t input_axes_idx) {
bool keep_dims = false;
op.GetAttr("keep_dims", keep_dims);
reduce_ops::CommonReduceInferWithInputAxes(op, input_x_idx, output_y_idx, input_axes_idx, keep_dims);
return true;
}
IMPLEMT_COMMON_INFERFUNC(TypicalReduceInferShape) {
OP_LOGD(TbeGetName(op), "Enter %s InferShape", TbeGetOpType(op).c_str());
const int64_t input_x_idx = 0;
const int64_t output_y_idx = 0;
const int64_t input_axes_idx = 1;
if (InferReduceShapeProcess(op, input_x_idx, output_y_idx, input_axes_idx)) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(ReduceSum, TypicalReduceInferShape);
// ----------------ReduceOp END-------------------
// ----------------RaggedRange-------------------
IMPLEMT_INFERFUNC(RaggedRange, RaggedRangeInfer) {
Shape starts;
Shape limits;
Shape deltas;
if (WithRankAtMost(op.GetInputDesc(0), 1, starts, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRankAtMost function, ", "input[starts] rank must be at most 1D, got rank[",
op.GetInputDesc(0).GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRankAtMost(op.GetInputDesc(1), 1, limits, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRankAtMost function, ", "input[limits] rank must be at most 1D, got rank[",
op.GetInputDesc(1).GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRankAtMost(op.GetInputDesc(2), 1, deltas, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRankAtMost function, input[deltas] ", "rank must be at most 1D, got rank[",
op.GetInputDesc(2).GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
int64_t dim = ge::UNKNOWN_DIM;
int64_t starts_dim = starts.GetDim(0);
int64_t limits_dim = limits.GetDim(0);
int64_t deltas_dim = deltas.GetDim(0);
if (op.GetInputDesc(0).GetShape().GetDimNum() == 1) {
if (Merge(starts_dim, dim, dim) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", starts_dim,
"] of input[starts] not equal UNKNOWN_DIM");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
}
if (op.GetInputDesc(1).GetShape().GetDimNum() == 1) {
if (Merge(limits_dim, dim, dim) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", limits_dim,
"] of input[limits] not equal UNKNOWN_DIM");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
}
if (op.GetInputDesc(2).GetShape().GetDimNum() == 1) {
if (Merge(deltas_dim, dim, dim) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", deltas_dim,
"] of input[deltas] not equal UNKNOWN_DIM");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
}
int64_t rt_nested_splits_dim = ge::UNKNOWN_DIM;
if (dim != ge::UNKNOWN_DIM) {
rt_nested_splits_dim = dim + 1;
} else if (op.GetInputDesc(0).GetShape().GetDimNum() == 0 && op.GetInputDesc(1).GetShape().GetDimNum() == 0 &&
op.GetInputDesc(2).GetShape().GetDimNum() == 0) {
rt_nested_splits_dim = 2;
}
DataType Tsplits_type;
if (op.GetAttr("Tsplits", Tsplits_type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[Tsplits] failed"));
return GRAPH_FAILED;
}
TensorDesc rt_nested_desc = op.GetOutputDescByName("rt_nested_splits");
rt_nested_desc.SetShape(Shape({rt_nested_splits_dim}));
rt_nested_desc.SetDataType(Tsplits_type);
(void)op.UpdateOutputDesc("rt_nested_splits", rt_nested_desc);
DataType T_type = op.GetInputDescByName("starts").GetDataType();
std::vector<int64_t> unknow_dim_vec(1, UNKNOWN_DIM);
TensorDesc dense_desc = op.GetOutputDescByName("rt_dense_values");
dense_desc.SetShape(Shape(unknow_dim_vec));
dense_desc.SetDataType(T_type);
(void)op.UpdateOutputDesc("rt_dense_values", dense_desc);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(RaggedRange, RaggedRangeInfer);
// ----------------RaggedRange END-------------------
} // namespace ge

View File

@ -159,4 +159,162 @@ CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixExp-------------------
CUST_COMMON_INFER_FUNC_REG(MatirxExp, OneInOneOutCommonInferShape);
// ----------------MatrixExp END-------------------
// ----------------TraceGrad Begin------------------------
IMPLEMT_COMMON_INFERFUNC(TraceGradInferShape) {
Shape shape = op.GetInputDescByName("y_grad").GetShape();
DataType input_dtype = op.GetInputDescByName("y_grad").GetDataType();
std::vector<std::string> input_infer_depends = {"x_shape"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor tensor_input;
Shape output_shape;
if (op.GetInputConstData("x_shape", tensor_input) == GRAPH_SUCCESS) {
MakeShapeFromShapeTensor(tensor_input, output_shape, op);
} else {
output_shape = Shape({UNKNOWN_RANK});
}
TensorDesc td = op.GetOutputDescByName("x_grad");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(FORMAT_ND);
(void)op.UpdateOutputDesc("x_grad", td);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(TraceGrad, TraceGradVerify) {
DataType x_shape_dtype = op.GetInputDescByName("x_shape").GetDataType();
if ((x_shape_dtype != DT_INT32) && (x_shape_dtype != DT_INT64)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(TraceGrad, TraceGradInferShape);
CUST_VERIFY_FUNC_REG(TraceGrad, TraceGradVerify);
// ---------------TraceGrad END-------------------------------
// ---------------Tril--------------
IMPLEMT_COMMON_INFERFUNC(TrilInferShape) {
Shape input_shape = op.GetInputDesc(0).GetShape();
DataType input_dtype = op.GetInputDesc(0).GetDataType();
TensorDesc td = op.GetOutputDesc(0);
td.SetShape(ge::Shape(input_shape));
td.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
IMPLEMT_VERIFIER(Tril, TrilVerify) { return GRAPH_SUCCESS; }
INFER_FUNC_REG(Tril, TrilInferShape);
VERIFY_FUNC_REG(Tril, TrilVerify);
// ----------------Tril END----------------
// -----------------ScatterNdUpdate-----------------
IMPLEMT_VERIFIER(ScatterNdUpdate, ScatterNdUpdateVerify) {
if (!CheckTwoInputDtypeSame(op, "var", "updates")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(ScatterNdUpdateInferShape) {
// main part of shape infer
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
ge::GeShape var_shape = op_desc->MutableInputDesc("var")->GetShape();
std::vector<std::pair<int64_t, int64_t>> var_shape_range;
op_desc->MutableInputDesc("var")->GetShapeRange(var_shape_range);
DataType input_dtype = op_desc->MutableInputDesc("var")->GetDataType();
GeTensorDescPtr td = op_desc->MutableOutputDesc("var");
td->SetShape(var_shape);
td->SetDataType(input_dtype);
td->SetShapeRange(var_shape_range);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ScatterNdUpdate, ScatterNdUpdateInferShape);
VERIFY_FUNC_REG(ScatterNdUpdate, ScatterNdUpdateVerify);
// -------------------ScatterNdUpdate END----------------
// -----------------TensorScatterUpdate-----------------
IMPLEMT_VERIFIER(TensorScatterUpdate, TensorScatterUpdateVerify) {
if (!CheckTwoInputDtypeSame(op, "x", "updates")) {
return GRAPH_FAILED;
}
DataType indices_dtype = op.GetInputDescByName("indices").GetDataType();
if ((indices_dtype != DT_INT32) && (indices_dtype != DT_INT64)) {
OP_LOGE("tensor_scatter_update", "The indices type is not int32 or int64, please check!");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_INFERFUNC(TensorScatterUpdate, TensorScatterUpdateInferShape) {
Shape var_shape = op.GetInputDescByName("x").GetShape();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(ge::Shape(var_shape));
td.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(TensorScatterUpdate, TensorScatterUpdateInferShape);
VERIFY_FUNC_REG(TensorScatterUpdate, TensorScatterUpdateVerify);
// -------------------TensorScatterUpdate END----------------
// -------------------Orgqr----------------
CUST_COMMON_INFER_FUNC_REG(Orgqr, OneInOneOutCommonInferShape);
// -------------------Orgqr END----------------
// -----------------------Trace-----------------------
static bool InferShapeAndTypeTrace(Operator &op, const std::string &inputName, const std::string outputName) {
TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());
DataType inputDtype = op.GetInputDescByName(inputName.c_str()).GetDataType();
ge::Shape inputShape = op.GetInputDescByName(inputName.c_str()).GetShape();
// set output tensor dim
std::vector<int64_t> dimVec;
ge::Shape outputShape = ge::Shape(dimVec);
vOutputDesc.SetShape(outputShape);
const std::vector<DataType> unchange_dtype = {DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT,
DT_FLOAT16, DT_INT64, DT_UINT64};
if (CheckInputDataType(op, "x", unchange_dtype) == true) {
vOutputDesc.SetDataType(inputDtype);
} else {
vOutputDesc.SetDataType(ge::DT_INT64);
}
op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);
return true;
}
IMPLEMT_VERIFIER(Trace, TraceVerify) {
AscendString op_name;
CHECK(op.GetName(op_name) != GRAPH_SUCCESS, OP_LOGE("", "GetName failed."), return GRAPH_FAILED);
ge::Shape shapeX = op.GetInputDescByName("x").GetShape();
DataType dtypeX = op.GetInputDescByName("x").GetDataType();
constexpr int64_t shapeDimsLimit = 2;
if (shapeX.GetDims() != UNKNOWN_RANK && shapeX.GetDimNum() != shapeDimsLimit) {
OP_LOGE(op_name.GetString(), "the input shape must be 2-D matrix.\n");
return GRAPH_FAILED;
}
const std::vector<DataType> support_list = {DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16,
DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32,
DT_UINT32, DT_INT64, DT_UINT64};
if (CheckInputDataType(op, "x", support_list) == false) {
OP_LOGE(TbeGetName(op).c_str(), "dataType [%s] is not supported in Trace.", DTypeStr(dtypeX).c_str());
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(TraceInferShape) {
if (InferShapeAndTypeTrace(op, "x", "y")) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(Trace, TraceInferShape);
VERIFY_FUNC_REG(Trace, TraceVerify);
// ---------------------Trace END----------------------
} // namespace ge

View File

@ -0,0 +1,58 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ---------------Meshgrid-------------------
IMPLEMT_COMMON_INFERFUNC(MeshgridInfer) {
auto input_size = op.GetInputsSize();
if (input_size == 0) {
return GRAPH_SUCCESS;
}
std::vector<int64_t> out_dims = std::vector<int64_t>();
auto out_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
for (size_t i = 0; i < op.GetInputsSize(); ++i) {
auto in_shape = op.GetInputDesc(i).GetShape();
if (in_shape.GetDims() == UNKNOWN_RANK) {
out_dims = ge::UNKNOWN_RANK;
break;
}
out_dims.push_back(in_shape.GetDim(0));
}
std::string indexing;
if (op.GetAttr("indexing", indexing) == GRAPH_SUCCESS) {
if (indexing == "xy") {
std::swap(out_dims[0], out_dims[1]);
}
} else {
OP_LOGW(TbeGetName(op).c_str(), "get attr indexing failed.");
}
Shape out_shape = Shape(out_dims);
for (size_t i = 0; i < op.GetInputsSize(); ++i) {
TensorDesc output_desc = op.GetDynamicOutputDesc("y", i);
output_desc.SetShape(out_shape);
output_desc.SetDataType(out_dtype);
op.UpdateDynamicOutputDesc("y", i, output_desc);
}
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(Meshgrid, MeshgridInfer);
// ---------------Meshgrid End---------------
} // namespace ge

View File

@ -11,6 +11,309 @@
#include "utils/common_shape_fns.h"
namespace ge {
// ---------------AdaptiveAvgPool2D-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape) {
OP_LOGI(TbeGetName(op).c_str(), " AdaptiveAvgPool2d inferShape begin!");
const size_t DIM_SIZE2 = 2;
auto input_tensor_desc = op.GetInputDescByName("x");
auto shape = input_tensor_desc.GetShape();
// get output_size
std::vector<int64_t> ouput_size_list;
if (GRAPH_SUCCESS != op.GetAttr("output_size", ouput_size_list)) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ouput_size_list failed!");
return GRAPH_FAILED;
}
// check output size
if (ouput_size_list.size() != DIM_SIZE2) {
OP_LOGE(TbeGetName(op).c_str(), "length of output_size must be 2");
return GRAPH_FAILED;
}
std::vector<int64_t> dims_input = shape.GetDims();
// set output shape
std::vector<int64_t> dim_vector;
for (size_t i = 0; i < dims_input.size(); i++) {
int64_t dims = dims_input[i];
dim_vector.push_back(dims);
}
size_t index0 = dims_input.size() - 2;
size_t index1 = dims_input.size() - 1;
if (ouput_size_list[0] > 0) {
dim_vector[index0] = ouput_size_list[0];
}
if (ouput_size_list[1] > 0) {
dim_vector[index1] = ouput_size_list[1];
}
TensorDesc td = op.GetOutputDescByName("y");
DataType input_dtype = input_tensor_desc.GetDataType();
Shape output_shape(dim_vector);
td.SetShape(output_shape);
td.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify);
// ---------------AdaptiveAvgPool2D End---------------
// ---------------AdaptiveAvgPool2DGrad-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape) {
std::vector<std::string> input_infer_depends = {"orig_input_shape"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
DataType input_dtype = op.GetInputDescByName("input_grad").GetDataType();
Shape output_shape;
Tensor orig_input_shape_tensor;
if (op.GetInputConstData("orig_input_shape", orig_input_shape_tensor) != GRAPH_SUCCESS) {
auto output_desc = op.GetOutputDescByName("output_grad");
output_desc.SetDataType(input_dtype);
output_desc.SetShape(Shape(ge::UNKNOWN_RANK));
return op.UpdateOutputDesc("output_grad", output_desc);
}
MakeShapeFromShapeTensor(orig_input_shape_tensor, output_shape, op);
TensorDesc output_grad = op.GetOutputDescByName("output_grad");
output_grad.SetShape(output_shape);
output_grad.SetDataType(input_dtype);
return op.UpdateOutputDesc("output_grad", output_grad);
}
CUST_INFER_FUNC_REG(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape);
// ---------------AdaptiveAvgPool2DGrad END-------------------
// --------- AdaptiveAvgPool3d ---------------
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
map<int, std::string> format2str = {
{ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
{ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};
// verify the dim of output_size
std::vector<int64_t> output_size;
if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
return GRAPH_PARAM_INVALID;
}
ge::AscendString op_name;
(void)op.GetName(op_name);
auto input_desc = op.GetInputDescByName("x");
TensorDesc out_desc = op.GetOutputDescByName("y");
// update data type
DataType input_type = input_desc.GetDataType();
out_desc.SetDataType(input_type);
// update format
Format input_format = input_desc.GetFormat();
std::string format_str = format2str[input_format];
if (input_format != FORMAT_NCHW) {
OP_LOGE("AdaptiveAvgPool3d",
"Input format only support NCHW"
", input format is [%s]",
format_str.c_str());
return GRAPH_FAILED;
}
out_desc.SetFormat(input_format);
std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
auto input_size_dim_num = input_size_shape.size();
std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
auto output_size_num = output_size.size();
if (output_size_num == 1) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
if (output_size[0] < 0) {
continue;
}
output_shape[i] = output_size[0];
}
} else if (output_size_num == 3) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
auto data = output_size[i - input_size_dim_num + 3];
if (data < 0) {
continue;
}
output_shape[i] = data;
}
} else {
OP_LOGE("AdaptiveAvgPool3d", "Shape of output_size is invalid");
return GRAPH_FAILED;
}
out_desc.SetShape(Shape(output_shape));
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3d", "failed to update output desc");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3d, AdaptiveAvgPool3dInferShape);
// --------- AdaptiveAvgPool3d end---------------
// --------- AdaptiveAvgPool3dGrad ---------------
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify) {
auto input_grad_desc = op.GetInputDescByName("input_grad");
auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
ge::AscendString op_name;
(void)op.GetName(op_name);
auto orig_input_shape_dim = orig_input_shape_desc.GetShape().GetDimNum();
if (orig_input_shape_dim != 1) {
OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input_shape is invalid");
return GRAPH_PARAM_INVALID;
}
auto orig_input_dim_num = orig_input_shape_desc.GetShape().GetShapeSize();
auto input_grad_dim_num = input_grad_desc.GetShape().GetDimNum();
if (orig_input_dim_num != static_cast<int64_t>(input_grad_dim_num)) {
OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input and input_grad should be the same");
return GRAPH_PARAM_INVALID;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dGradInferShape) {
map<int, std::string> format2str = {
{ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
{ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};
auto input_desc = op.GetInputDescByName("input_grad");
auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
TensorDesc out_desc = op.GetOutputDescByName("output_grad");
ge::AscendString op_name;
(void)op.GetName(op_name);
// update format
Format input_format = input_desc.GetFormat();
std::string format_str = format2str[input_format];
if (input_format != FORMAT_NCHW) {
OP_LOGE("AdaptiveAvgPool3dGrad",
"Input format only support NCHW"
", input format is [%s]",
format_str.c_str());
return GRAPH_FAILED;
}
out_desc.SetFormat(input_format);
// update data type
DataType input_type = input_desc.GetDataType();
out_desc.SetDataType(input_type);
// infer shape
Tensor orig_input_size_tensor;
if (op.GetInputConstData("orig_input_shape", orig_input_size_tensor) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3dGrad", "failed to get tensor from output_size");
return GRAPH_FAILED;
}
int32_t *orig_input_size_data = reinterpret_cast<int32_t *>(orig_input_size_tensor.GetData());
if (orig_input_size_data == nullptr) {
OP_LOGE("AdaptiveAvgPool3dGrad", "output_size data is invalid");
return GRAPH_PARAM_INVALID;
}
auto input_size_dim_num = input_desc.GetShape().GetDimNum();
std::vector<int64_t> output_shape(input_size_dim_num);
for (uint64_t i = 0; i < input_size_dim_num; ++i) {
output_shape[i] = orig_input_size_data[i];
}
out_desc.SetShape(Shape(output_shape));
if (op.UpdateOutputDesc("output_grad", out_desc) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3dGrad", "failed to update output desc");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify);
// --------- AdaptiveAvgPool3dGrad end---------------
// --------- AdaptiveMaxPool3d---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape) {
TensorDesc input = op.GetInputDesc(0);
TensorDesc output_size = op.GetInputDesc(1);
TensorDesc output = op.GetOutputDesc(0);
TensorDesc argmax = op.GetOutputDesc(1);
const size_t input_num_dims = input.GetShape().GetDimNum();
const std::vector<int64_t> output_size_shape = output_size.GetShape().GetDims();
if ((input_num_dims == 4 || input_num_dims == 5) == false) {
OP_LOGE(TbeGetName(op), "Input dimensions must be equal to 4 or 5.");
return GRAPH_FAILED;
}
if (output_size_shape.size() != 1) {
OP_LOGE(TbeGetName(op), "output_size dim should be equal to 1.");
return GRAPH_FAILED;
}
if (output_size_shape[0] != 3) {
OP_LOGE(TbeGetName(op), "output_size shape[0] should be equal to 3.");
return GRAPH_FAILED;
}
DataType input_dtype = input.GetDataType();
Shape output_shape(UNKNOWN_SHAPE);
output.SetDataType(input_dtype);
output.SetShape(output_shape);
argmax.SetDataType(DT_INT32);
argmax.SetShape(output_shape);
(void)op.UpdateOutputDesc("y", output);
(void)op.UpdateOutputDesc("argmax", argmax);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify);
// --------- AdaptiveMaxPool3d END---------------
// --------- AdaptiveMaxPool2dGrad---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape) {
TensorDesc input_grad = op.GetOutputDescByName("x_grad");
TensorDesc input = op.GetInputDescByName("x");
DataType input_dtype = input.GetDataType();
Shape input_shape = input.GetShape();
input_grad.SetShape(input_shape);
input_grad.SetDataType(input_dtype);
if (op.UpdateOutputDesc("x_grad", input_grad) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify);
// --------- AdaptiveMaxPool2dGrad END---------------
// --------- AdaptiveMaxPool3dGrad---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape) {
TensorDesc output_grad = op.GetOutputDescByName("output_grad");
TensorDesc input = op.GetInputDescByName("x");
DataType input_dtype = input.GetDataType();
Shape input_shape = input.GetShape();
output_grad.SetShape(input_shape);
output_grad.SetDataType(input_dtype);
if (op.UpdateOutputDesc("output_grad", output_grad) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify);
// --------- AdaptiveMaxPool3dGrad END---------------
// -------------------DataFormatVecPermute---------------------
IMPLEMT_INFERFUNC(DataFormatVecPermute, DataFormatVecPermuteInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
@ -284,4 +587,56 @@ CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
//-------------------MaxPool3DGradWithArgMax---------------------
//-------------------NthElement---------------------
IMPLEMT_INFERFUNC(NthElement, NthElementInfer) {
std::vector<std::string> input_infer_depends = {"n"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Shape x_shape;
auto x_tensor = op.get_input_desc_x();
if (WithRankAtLeast(x_tensor, 1, x_shape, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call WithRankAtLeast function, ", "input[x] rank must be at least 1D, but got rank[",
op.get_input_desc_x().GetShape().GetDimNum(), "]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
Tensor n_tensor;
int64_t n_dim = 0;
if (op.GetInputConstData("n", n_tensor) != GRAPH_SUCCESS) {
n_dim = ge::UNKNOWN_DIM;
} else {
if (MakeDimForScalarInput(n_tensor, n_dim, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), std::string("failed to call MakeDimForScalarInput function, "
"get input n shape failed"));
return GRAPH_FAILED;
}
}
int64_t existing = x_shape.GetDimNum();
int64_t last_input_dim = x_shape.GetDim(existing - 1);
if ((last_input_dim != ge::UNKNOWN_DIM) && (n_dim != ge::UNKNOWN_DIM) && (last_input_dim <= n_dim)) {
std::string err_msg =
ConcatString("input[x] last dim value[", last_input_dim, "] must be greater than [", n_dim, "]");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
Shape output_shape;
if (SubShape(x_shape, 0, -1, 1, output_shape, op) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call SubShape function, input[x] shape[",
DebugString(x_shape.GetDims()), "], start[0], end[-1], stride[1]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
TensorDesc y_tensor = op.GetOutputDescByName("y");
y_tensor.SetDataType(x_tensor.GetDataType());
y_tensor.SetShape(output_shape);
return op.UpdateOutputDesc("y", y_tensor);
}
INFER_FUNC_REG(NthElement, NthElementInfer);
//-------------------NthElement END---------------------
} // namespace ge

View File

@ -0,0 +1,316 @@
/**
* Copyright (c) 2022-202 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <vector>
#include <algorithm>
#include <numeric>
#include "inc/ops/pad_ops.h"
#include "graph/utils/node_utils.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/error_util.h"
#include "utils/op_log.h"
#include "utils/op_const.h"
namespace ge {
// ----------------Pad Op Begin-------------------
static graphStatus PadInferShapeAndType(ge::Operator &op, std::vector<int64_t> &paddings) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
static const int64_t input_x_idx = 0;
auto input_desc = op_info->MutableInputDesc(input_x_idx);
const ge::GeShape &input_shape = input_desc->MutableShape();
auto input_dtype = input_desc->GetDataType();
static const int64_t output_y_idx = 0;
auto output_desc = op_info->MutableOutputDesc(output_y_idx);
ge::GeShape &output_shape = output_desc->MutableShape();
output_desc->SetDataType(input_dtype);
// input shape is -2, output is -2
if (input_shape.IsUnknownDimNum()) {
output_desc->SetShape(input_shape);
return GRAPH_SUCCESS;
}
size_t dim_num = input_shape.GetDimNum();
if (!input_shape.IsUnknownShape()) {
// not dynamic shape, will output shape and dtype
if (dim_num == 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("input shape cannot empty"));
return GRAPH_FAILED;
}
if (dim_num * 2 != paddings.size()) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
OtherErrMsg("the num of paddings must be double the input dim size"));
return GRAPH_FAILED;
}
// calce the output shape
output_shape.SetDimNum(dim_num);
for (size_t dim = 0; dim < dim_num; dim++) {
output_shape.SetDim(dim, input_shape.GetDim(dim) + paddings[dim * 2] + paddings[dim * 2 + 1]);
}
return GRAPH_SUCCESS;
}
// input shape is -1, will get the shape and range
// calcu the output shape
output_shape.SetDimNum(dim_num);
for (size_t dim = 0; dim < dim_num; dim++) {
if (input_shape.GetDim(dim) == -1) {
output_shape.SetDim(dim, input_shape.GetDim(dim));
} else {
output_shape.SetDim(dim, input_shape.GetDim(dim) + paddings[dim * 2] + paddings[dim * 2 + 1]);
}
}
// calcu the output range
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_shape, input_range);
std::vector<std::pair<int64_t, int64_t>> output_range;
for (size_t dim = 0; dim < dim_num; dim++) {
auto range_min = input_range[dim].first + paddings[dim * 2] + paddings[dim * 2 + 1];
auto range_max =
input_range[dim].second == -1 ? -1 : input_range[dim].second + paddings[dim * 2] + paddings[dim * 2 + 1];
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
}
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(PadInferShape) {
OP_LOGD(TbeGetName(op), "InferShape Begin.");
const vector<string> depend_names = {"paddings"};
PREPARE_DYNAMIC_SHAPE(depend_names);
// first get the padding const
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
// get const paddings data
std::vector<int64_t> paddings;
static const int64_t input_paddings_idx = 1;
if (!ops::GetConstIntData(op, input_paddings_idx, paddings)) {
OP_LOGW(TbeGetName(op), "the node paddings is not const node, will set the output dynamic");
auto input_desc = op_info->MutableInputDesc("x");
const ge::GeShape &input_shape = input_desc->MutableShape();
DataType input_dtype = input_desc->GetDataType();
auto output_desc = op_info->MutableOutputDesc("y");
// shape_x is UNKNOWN_RANK
if (IsUnknownRankShape(input_shape)) {
OP_LOGW(TbeGetName(op), "shape_x is UNKNOWN_RANK. Set output UNKNOWN_RANK");
output_desc->SetShape(input_shape);
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
size_t dim_num = input_shape.GetDimNum();
// shape_x is UNKNOWN_DIM
if (dim_num == 0) {
dim_num = 1;
}
vector<int64_t> out_shape(dim_num, -1);
std::vector<std::pair<int64_t, int64_t>> output_range;
input_desc->GetShapeRange(output_range);
MakeUpShapeRange(out_shape, output_range);
for (size_t i = 0; i < dim_num; i++) {
output_range[i].second = -1;
}
output_desc->SetShape(GeShape(out_shape));
output_desc->SetDataType(input_dtype);
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
return PadInferShapeAndType(op, paddings);
}
COMMON_INFER_FUNC_REG(Pad, PadInferShape);
// ----------------Pad Op End-------------------
// ----------------PadV3 Op Begin-------------------
IMPLEMT_COMMON_INFERFUNC(PadV3InferShape) {
const vector<string> depend_names = {"paddings"};
PREPARE_DYNAMIC_SHAPE(depend_names);
Tensor paddings_tensor;
if (ge::GRAPH_SUCCESS != op.GetInputConstData("paddings", paddings_tensor)) {
OP_LOGW(TbeGetName(op).c_str(), "Get Const Value [paddings] failed, Setting shape to UNKNOWN_DIM");
Shape shape_x = op.GetInputDescByName("x").GetShape();
vector<int64_t> shape;
for (size_t dim = 0; dim < shape_x.GetDimNum(); dim++) {
shape.push_back(UNKNOWN_DIM);
}
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
Shape out_shape(shape);
tensordesc_output.SetShape(out_shape);
tensordesc_output.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
DataType dtype = op.GetInputDescByName("paddings").GetDataType();
std::vector<int64_t> paddings;
if (!GetConstValue(op, paddings_tensor, dtype, paddings)) {
OP_LOGE(TbeGetName(op).c_str(), "Get Const Value [paddings] failed ");
return GRAPH_FAILED;
}
bool paddings_contiguous = true;
if (op.GetAttr("paddings_contiguous", paddings_contiguous) == GRAPH_FAILED) {
OP_LOGI(TbeGetName(op).c_str(), "Get attr [paddings_contiguous] failed");
}
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_shape = input_desc->MutableShape().GetDims();
auto input_shape_max = input_shape.size();
auto paddings_shape = paddings.size();
// expand paddings by 0
auto expand_num = input_shape_max * 2 - paddings_shape;
for (size_t dim = 0; dim < expand_num; dim++) {
paddings.push_back(0);
}
if (expand_num > 0) {
std::vector<int64_t> pad_vec;
for (int i = input_shape_max; i > 0; i--) {
pad_vec.push_back(paddings[i * 2 - 2]);
pad_vec.push_back(paddings[i * 2 - 1]);
}
paddings = pad_vec;
}
if (!paddings_contiguous) {
std::vector<int64_t> pads;
int64_t rank = paddings.size() / 2;
for (int i = 0; i < rank; i++) {
pads.push_back(paddings[i]);
pads.push_back(paddings[i + rank]);
}
paddings = pads;
OP_LOGI(TbeGetName(op).c_str(), "Get attr paddings_contiguous = false");
} else {
OP_LOGI(TbeGetName(op).c_str(), "Get attr paddings_contiguous = true[default]");
}
return PadInferShapeAndType(op, paddings);
}
COMMON_INFER_FUNC_REG(PadV3, PadV3InferShape);
// ----------------PadV3 Op End-------------------
// ----------------PadV3Grad Op Begin-------------------
static graphStatus PadV3GradInferShapeAndType(ge::Operator &op, std::vector<int64_t> &paddings) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_shape = input_desc->MutableShape().GetDims();
auto input_dtype = input_desc->GetDataType();
auto output_desc = op_info->MutableOutputDesc("y");
output_desc->SetDataType(input_dtype);
if (!IsUnknown(input_shape)) {
// calce the output shape
vector<int64_t> output_shape;
for (size_t dim = 0; dim < input_shape.size(); dim++) {
output_shape.push_back(input_shape[dim] - paddings[dim * 2] - paddings[dim * 2 + 1]);
}
output_desc->SetShape(GeShape(output_shape));
return GRAPH_SUCCESS;
}
// input shape is -2, output is -2
if (IsUnknownRankShape(input_shape)) {
output_desc->SetShape(GeShape(input_shape));
return GRAPH_SUCCESS;
}
// input shape is -1, will get the shape and range
// calcu the output shape
vector<int64_t> output_shape;
for (size_t dim = 0; dim < input_shape.size(); dim++) {
if (input_shape[dim] == -1) {
output_shape.push_back(input_shape[dim]);
} else {
output_shape.push_back(input_shape[dim] - paddings[dim * 2] - paddings[dim * 2 + 1]);
}
}
output_desc->SetShape(GeShape(output_shape));
// calcu the output range
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_shape, input_range);
std::vector<std::pair<int64_t, int64_t>> output_range;
for (size_t dim = 0; dim < input_shape.size(); dim++) {
auto range_min = input_range[dim].first - paddings[dim * 2] - paddings[dim * 2 + 1];
if (range_min < 1) {
range_min = 1;
}
auto range_max =
input_range[dim].second == -1 ? -1 : input_range[dim].second - paddings[dim * 2] - paddings[dim * 2 + 1];
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
}
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(PadV3GradInferShape) {
const vector<string> depend_names = {"paddings"};
PREPARE_DYNAMIC_SHAPE(depend_names);
Tensor paddings_tensor;
if (ge::GRAPH_SUCCESS != op.GetInputConstData("paddings", paddings_tensor)) {
OP_LOGW(TbeGetName(op).c_str(), "Get Const Value [paddings] failed, Setting shape to UNKNOWN_DIM");
Shape shape_x = op.GetInputDescByName("x").GetShape();
vector<int64_t> shape;
for (size_t dim = 0; dim < shape_x.GetDimNum(); dim++) {
shape.push_back(UNKNOWN_DIM);
}
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
Shape out_shape(shape);
tensordesc_output.SetShape(out_shape);
tensordesc_output.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
DataType dtype = op.GetInputDescByName("paddings").GetDataType();
std::vector<int64_t> paddings;
if (!GetConstValue(op, paddings_tensor, dtype, paddings)) {
OP_LOGE(TbeGetName(op).c_str(), "Get Const Value [paddings] failed ");
return GRAPH_FAILED;
}
bool paddings_contiguous = true;
if (op.GetAttr("paddings_contiguous", paddings_contiguous) == GRAPH_FAILED) {
OP_LOGI(TbeGetName(op).c_str(), "Get attr [paddings_contiguous] failed");
}
return PadV3GradInferShapeAndType(op, paddings);
}
COMMON_INFER_FUNC_REG(PadV3Grad, PadV3GradInferShape);
// ----------------PadV3Grad Op End-------------------
} // namespace ge

View File

@ -8,7 +8,9 @@
#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"
#include "utils/vector_proto_profiling.h"
namespace ge {
// ----------------CumulativeLogsumexp-------------------
@ -162,4 +164,519 @@ IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
// Registered inferfunction
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------
// ----------------SegmentSum-------------------
static bool SegmentSumShapeVerify(const Operator &op, const std::string &input_name,
const std::string &segment_ids_name) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_shape_dims = op_info->MutableInputDesc("x")->MutableShape().GetDims();
auto segment_ids_shape_dims = op_info->MutableInputDesc("segment_ids")->MutableShape().GetDims();
return true;
}
IMPLEMT_VERIFIER(SegmentSum, SegmentSumInferShapeVerifier) {
if (!SegmentSumShapeVerify(op, "x", "segment_ids")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(SegmentSumInferShape) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_x_desc = op_info->MutableInputDesc("x");
auto output_desc = op_info->MutableOutputDesc("y");
auto shape_x = input_x_desc->MutableShape().GetDims();
auto output_shape_dims = input_x_desc->MutableShape().GetDims();
if (output_shape_dims.empty()) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("the input[x]'s shape should not be empty."));
return GRAPH_FAILED;
}
const vector<string> depend_name = {"segment_ids"};
PREPARE_DYNAMIC_SHAPE(depend_name);
const std::string segment_ids_name = "segment_ids";
Tensor segment_ids;
int64_t first_axis_dims;
int64_t out_range_first_dims;
if (GRAPH_SUCCESS != op.GetInputConstData(segment_ids_name.c_str(), segment_ids)) {
OP_LOGI("segment_max", "GetInputConstData %s failed.", segment_ids_name.c_str());
first_axis_dims = -1;
out_range_first_dims = 0;
} else {
auto data_type = op.GetInputDescByName(segment_ids_name.c_str()).GetDataType();
std::vector<int64_t> const_data;
if (!GetConstIntData(segment_ids, data_type, const_data)) {
std::string err_msg =
ConcatString("failed to call GetConstIntData function ",
"due to invalid data type of input[segment_ids]. data_type is ", DTypeStr(data_type));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
first_axis_dims = (*std::max_element(const_data.begin(), const_data.end())) + 1;
out_range_first_dims = first_axis_dims;
}
if (IsUnknownRankShape(shape_x)) {
output_desc->SetShape(GeShape(shape_x));
} else {
output_shape_dims[0] = first_axis_dims;
GeShape output_shape(output_shape_dims);
output_desc->SetShape(GeShape(output_shape_dims));
if (output_shape.IsUnknownShape()) {
std::vector<std::pair<int64_t, int64_t>> shape_range_x;
std::vector<std::pair<int64_t, int64_t>> output_shape_range;
output_shape_range.push_back(std::pair<int64_t, int64_t>(out_range_first_dims, first_axis_dims));
input_x_desc->GetShapeRange(shape_range_x);
MakeUpShapeRange(output_shape_dims, shape_range_x);
for (size_t i = 1; i < output_shape_dims.size(); i++) {
output_shape_range.push_back(shape_range_x[i]);
}
output_desc->SetShapeRange(output_shape_range);
}
}
DataType input_dtype = input_x_desc->GetDataType();
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(SegmentSum, SegmentSumInferShape);
VERIFY_FUNC_REG(SegmentSum, SegmentSumInferShapeVerifier);
// ----------------SegmentSum END-------------------
// ----------------Select----------------------
IMPLEMT_VERIFIER(Select, SelectVerify) {
if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
string("call function CheckTwoInputDtypeSame failed, data type of input[x1] is not same as input[x2]"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(SelectInferShape) {
if (!TwoInOneOutDynamicInferNoBroadcast(op, "x1", "x2", {"y"})) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), string("call function TwoInOneOutDynamicInferNoBroadcast failed, update output[y] desc failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Select, SelectInferShape);
VERIFY_FUNC_REG(Select, SelectVerify);
// ---------------Select END-----------------------
// ----------------ReverseV2 Op Begin-----------------
IMPLEMT_COMMON_INFERFUNC(ReverseV2InferShape) {
const vector<string> depend_names = {"axis"};
PREPARE_DYNAMIC_SHAPE(depend_names);
if (OneInOneOutDynamicInfer(op, "x", {"y"})) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(ReverseV2, ReverseV2InferShape);
// ----------------ReverseV2 Op End-------------------
// ----------------ScatterNd-------------------
IMPLEMT_COMMON_INFERFUNC(ScatterNdInferShape) {
vector<string> input_infer_depends = {"shape"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
auto output_desc = op_desc->MutableOutputDesc("y");
auto shape_desc = op_desc->MutableInputDesc("shape");
std::vector<int64_t> shape_shape = shape_desc->MutableShape().GetDims();
std::vector<std::pair<int64_t, int64_t>> out_range;
Tensor shape;
std::vector<int64_t> const_data;
if (GRAPH_SUCCESS != op.GetInputConstData("shape", shape)) {
const_data = {-2};
} else {
auto data_type = shape_desc->GetDataType();
if (!GetConstIntData(shape, data_type, const_data)) {
USER_GE_LOGE("Invalid data type of shape, data_type is %d.", (int)data_type);
return GRAPH_FAILED;
}
}
vector<int64_t> shape_dims;
if (shape_shape.size() == 1 && shape_shape[0] > 0 && IsUnknownRankShape(const_data)) {
for (int64_t i = 0; i < shape_shape[0]; i++) {
shape_dims.push_back(-1);
}
} else {
for (size_t i = 0; i < (uint32_t)const_data.size(); ++i) {
shape_dims.push_back(const_data[i]);
}
}
if (IsUnknownRankShape(shape_dims)) {
out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
} else if (IsUnknownVec(shape_dims)) {
for (size_t i = 0; i < shape_dims.size(); i++) {
if (shape_dims[i] == -1) {
out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
} else {
out_range.push_back(std::pair<int64_t, int64_t>(shape_dims[i], shape_dims[i]));
}
}
}
GeShape output_shape(shape_dims);
output_desc->SetShape(output_shape);
output_desc->SetShapeRange(out_range);
output_desc->SetDataType(op_desc->MutableInputDesc("x")->GetDataType());
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ScatterNd, ScatterNdInferShape);
// ----------------ScatterNd End-------------------
// ----------------OneHot---------------------------
IMPLEMT_COMMON_INFERFUNC(OneHotInferShape) {
const vector<string> depend_names = {"depth"};
PREPARE_DYNAMIC_SHAPE(depend_names);
// get attr axis
int32_t axis = -1;
if (ge::GRAPH_SUCCESS != op.GetAttr("axis", axis)) {
std::string err_msg = GetInputInvalidErrMsg("Get const axis failed from op of 'OneHot'!\n");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (axis < -1) {
string correct_size = ConcatString("attr axis(", axis, ") must be >= -1");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), correct_size);
return GRAPH_FAILED;
}
// get all Desc info
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
static const int64_t input_x_idx = 0;
auto input_desc = op_info->MutableInputDesc(input_x_idx);
const ge::GeShape &input_shape = input_desc->MutableShape();
static const int64_t input_on_value_idx = 2;
auto value_desc = op_info->MutableInputDesc(input_on_value_idx);
DataType value_dtype = value_desc->GetDataType();
// output desc and set dtype
static const int64_t output_y_idx = 0;
auto output_desc = op_info->MutableOutputDesc(output_y_idx);
output_desc->SetDataType(value_dtype);
if (input_shape.IsUnknownDimNum()) {
// input is UnknownRank, set output UnknownRank
OP_LOGW("OneHot", "input shape is UnknownRank, set output UnknownRank");
output_desc->SetShape(input_shape);
return GRAPH_SUCCESS;
}
// update axis to positive number
int32_t dimnum = input_shape.GetDimNum();
if (axis > dimnum) {
string correct_size = ConcatString("attr axis(", axis, ") must be < ", input_shape.GetDimNum());
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), correct_size);
return GRAPH_FAILED;
}
// get depth const value, depth index is 1
int64_t depth_value = -1;
static const int64_t input_depth_idx = 1;
if (!ops::GetConstInt(op, input_depth_idx, depth_value)) {
OP_LOGW("OneHot", "Get depth const tensor failed, set depth -1");
}
// update output shape
ge::GeShape &output_shape = output_desc->MutableShape();
output_shape.SetDimNum(dimnum + 1);
if (-1 == axis) {
for (int32_t i = 0; i < dimnum; i++) {
output_shape.SetDim(i, input_shape.GetDim(i));
}
output_shape.SetDim(dimnum, depth_value);
} else {
while (dimnum > axis) {
output_shape.SetDim(dimnum, input_shape.GetDim(dimnum - 1));
dimnum--;
}
output_shape.SetDim(axis, depth_value);
for (int32_t i = 0; i < axis; i++) {
output_shape.SetDim(i, input_shape.GetDim(i));
}
}
// if output shape is dynamic update output range
if (output_shape.IsUnknownShape()) {
output_desc->SetOriginShape(output_shape);
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_shape, input_range);
std::pair<int64_t, int64_t> depth_range =
depth_value == -1 ? std::pair<int64_t, int64_t>(1, -1) : std::pair<int64_t, int64_t>(depth_value, depth_value);
if (-1 == axis) {
input_range.insert(input_range.end(), depth_range);
} else {
input_range.insert(input_range.begin() + axis, depth_range);
}
output_desc->SetShapeRange(input_range);
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(OneHot, OneHotInferShape);
// ----------------OneHot END----------------------
// ----------------UnsortedSegmentSum-------------------
static void GetUnsortedSegmentSumConstValue(const Tensor &const_tensor, const DataType &dtype, int64_t &const_data) {
if (dtype == ge::DT_INT32) {
int32_t *const_data_ptr = (int32_t *)const_tensor.GetData();
const_data = *const_data_ptr;
} else {
int64_t *const_data_ptr = (int64_t *)const_tensor.GetData();
const_data = *const_data_ptr;
}
}
static void GetRealRange(ge::GeShape shape, std::vector<std::pair<int64_t, int64_t>> &range) {
if (shape.IsUnknownDimNum()) {
return;
}
if (range.empty()) {
for (size_t i = 0; i < shape.GetDimNum(); i++) {
int64_t dim = shape.GetDim(i);
if (dim == -1) {
range.push_back(std::pair<int64_t, int64_t>(0, -1));
} else {
range.push_back(std::pair<int64_t, int64_t>(dim, dim));
}
}
}
}
IMPLEMT_COMMON_INFERFUNC(UnsortedSegmentSumInferShape) {
PROFILING_PROTO_INIT(TbeGetName(op).c_str());
vector<string> input_infer_depends = {"num_segments"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor input_num_segments_tensor;
int64_t input_num_segments;
DataType input_num_segments_dtype = op_desc->GetInputDescPtr(2)->GetDataType();
std::vector<std::pair<int64_t, int64_t>> shape_range_x;
op_desc->GetInputDescPtr(0)->GetShapeRange(shape_range_x);
std::vector<std::pair<int64_t, int64_t>> shape_range_seg_id;
op_desc->GetInputDescPtr(1)->GetShapeRange(shape_range_seg_id);
std::vector<std::pair<int64_t, int64_t>> out_range;
if (GRAPH_SUCCESS != op.GetInputConstData("num_segments", input_num_segments_tensor)) {
input_num_segments = -1;
out_range.push_back(std::pair<int64_t, int64_t>(0, -1));
} else {
GetUnsortedSegmentSumConstValue(input_num_segments_tensor, input_num_segments_dtype, input_num_segments);
out_range.push_back(std::pair<int64_t, int64_t>(input_num_segments, input_num_segments));
}
ge::GeShape shape = op_desc->GetInputDescPtr(0)->GetShape();
ge::GeShape shape_id = op_desc->GetInputDescPtr(1)->GetShape();
auto output_desc = op_desc->MutableOutputDesc(0);
ge::GeShape output_shape = output_desc->MutableShape();
GetRealRange(shape, shape_range_x);
GetRealRange(shape_id, shape_range_seg_id);
int64_t dim_idsize_input = shape_id.GetDimNum();
int64_t dim_size_input = shape.GetDimNum();
DataType input_dtype = op_desc->GetInputDescPtr(0)->GetDataType();
PROFILING_PROTO_AFTER_GET_SHAPE_REG();
if (shape.IsUnknownDimNum() || shape_id.IsUnknownDimNum()) {
if (shape.IsUnknownDimNum()) {
output_desc->SetShape(shape);
output_desc->SetDataType(input_dtype);
} else {
output_desc->SetShape(shape_id);
output_desc->SetDataType(input_dtype);
}
return GRAPH_SUCCESS;
} else if (dim_idsize_input > 1) {
size_t rank = dim_size_input - dim_idsize_input + 1;
size_t idx = 1;
output_shape.SetDimNum(rank);
output_shape.SetDim(0, input_num_segments);
for (int64_t i = dim_idsize_input; i < dim_size_input; i++) {
int64_t x_dim = shape.GetDim(i);
output_shape.SetDim(idx, x_dim);
if ((size_t)i < shape_range_x.size()) {
out_range.push_back(shape_range_x[i]);
}
idx++;
}
} else {
size_t rank = shape.GetDimNum();
output_shape.SetDimNum(rank);
output_shape.SetDim(0, input_num_segments);
for (size_t i = 1; i < rank; i++) {
int64_t x_dim = shape.GetDim(i);
output_shape.SetDim(i, x_dim);
if ((size_t)i < shape_range_x.size()) {
out_range.push_back(shape_range_x[i]);
}
}
}
PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
output_desc->SetShape(output_shape);
output_desc->SetDataType(input_dtype);
output_desc->SetShapeRange(out_range);
PROFILING_PROTO_END();
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(UnsortedSegmentSum, UnsortedSegmentSumInferShape);
// ----------------UnsortedSegmentSum END----------------
// ----------------Slice Op Begin ----------------------
static void GetSliceConstValue(const Tensor &const_tensor, const DataType &dtype, std::vector<int64_t> &const_data) {
size_t size = 0;
if (dtype == ge::DT_INT32) {
int32_t *const_data_ptr = (int32_t *)const_tensor.GetData();
size = const_tensor.GetSize() / sizeof(int32_t);
for (size_t i = 0; i < size; ++i) {
const_data.push_back((int32_t)((*(const_data_ptr + i))));
}
} else {
int64_t *const_data_ptr = (int64_t *)const_tensor.GetData();
size = const_tensor.GetSize() / sizeof(int64_t);
for (size_t i = 0; i < size; ++i) {
const_data.push_back(((int64_t)(*(const_data_ptr + i))));
}
}
}
IMPLEMT_COMMON_INFERFUNC(SliceInferShape) {
const vector<string> depend_names = {"offsets", "size"};
PREPARE_DYNAMIC_SHAPE(depend_names);
Tensor input_begin_tensor;
Tensor input_size_tensor;
auto input_desc = op.GetInputDescByName("x");
const Shape shape = input_desc.GetShape();
DataType input_dtype = input_desc.GetDataType();
std::vector<int64_t> input_begin;
std::vector<int64_t> input_size;
bool has_offsets = true;
if (op.GetInputConstData("offsets", input_begin_tensor) != GRAPH_SUCCESS) {
OP_LOGI(TbeGetName(op).c_str(), "Get offsets failed.");
has_offsets = false;
} else {
DataType input_begin_dtype = op.GetInputDescByName("offsets").GetDataType();
GetSliceConstValue(input_begin_tensor, input_begin_dtype, input_begin);
}
bool has_size = true;
if (op.GetInputConstData("size", input_size_tensor) != GRAPH_SUCCESS) {
OP_LOGI(TbeGetName(op).c_str(), "Get size failed.");
has_size = false;
} else {
DataType input_size_dtype = op.GetInputDescByName("size").GetDataType();
GetSliceConstValue(input_size_tensor, input_size_dtype, input_size);
}
bool is_unknown_rank = !has_size && !has_offsets && shape.GetDims() == UNKNOWN_RANK;
if (is_unknown_rank) {
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetDataType(input_dtype);
Shape outputShape(UNKNOWN_RANK);
output_desc.SetShape(outputShape);
OP_LOGD(TbeGetName(op).c_str(), "output_shape:%s", to_string(output_desc.GetShape()).c_str());
(void)op.UpdateOutputDesc("y", output_desc);
return GRAPH_SUCCESS;
}
auto shape_dims = shape.GetDims();
if (shape.GetDims() == UNKNOWN_RANK) {
shape_dims.assign(std::max(input_begin.size(), input_size.size()), -1);
}
size_t dimNum = shape_dims.size();
std::vector<int64_t> outputList;
vector<pair<int64_t, int64_t>> ranges;
input_desc.GetShapeRange(ranges);
if (ranges.empty()) {
MakeUpShapeRange(shape_dims, ranges);
}
if (ranges.size() < dimNum) {
OP_LOGE(TbeGetName(op).c_str(), "ranges.size is:%ld, smaller than dimNum, dimNum is %ld.", ranges.size(), dimNum);
return GRAPH_FAILED;
}
if (!has_size && !has_offsets) {
for (size_t i = 0; i < dimNum; ++i) {
outputList.push_back(-1);
ranges[i].first = 0;
}
} else if (!has_offsets && has_size) {
for (size_t i = 0; i < dimNum; ++i) {
if (input_size[i] == -1) {
outputList.push_back(-1);
ranges[i].first = 0;
} else {
outputList.push_back(input_size[i]);
ranges[i].first = input_size[i];
ranges[i].second = input_size[i];
}
}
} else if (has_offsets && !has_size) {
for (size_t i = 0; i < dimNum; ++i) {
outputList.push_back(-1);
ranges[i].first = 0;
if (ranges[i].second != -1) {
if (shape_dims[i] != -1) {
ranges[i].second = std::min(ranges[i].second, shape_dims[i]);
}
ranges[i].second -= input_begin[i];
}
}
} else {
for (size_t i = 0; i < dimNum; ++i) {
if (input_size[i] == -1) {
if (shape_dims[i] == -1) {
outputList.push_back(-1);
} else {
outputList.push_back(shape_dims[i] - input_begin[i]);
}
ranges[i].first = 0;
} else {
outputList.push_back(input_size[i]);
ranges[i].first = input_size[i];
ranges[i].second = input_size[i];
}
}
}
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
tensordesc_output.SetDataType(input_dtype);
if (IsUnKnownShape(outputList)) {
tensordesc_output.SetShapeRange(ranges);
OP_LOGD(TbeGetName(op).c_str(), "output_ranges:%s", to_string(ranges).c_str());
}
Shape outputShape(outputList);
tensordesc_output.SetShape(outputShape);
OP_LOGD(TbeGetName(op).c_str(), "output_ranges:%s", to_string(ranges).c_str());
OP_LOGD(TbeGetName(op).c_str(), "offset:%s", to_string(input_begin).c_str());
OP_LOGD(TbeGetName(op).c_str(), "size:%s", to_string(input_size).c_str());
OP_LOGD(TbeGetName(op).c_str(), "output_shape:%s", to_string(tensordesc_output.GetShape()).c_str());
(void)op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Slice, SliceInferShape);
// ----------------Slice Op END ----------------------
} // namespace ge

View File

@ -1,53 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/trace_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ----------------TraceGrad Begin------------------------
IMPLEMT_COMMON_INFERFUNC(TraceGradInferShape) {
Shape shape = op.GetInputDescByName("y_grad").GetShape();
DataType input_dtype = op.GetInputDescByName("y_grad").GetDataType();
Tensor tensor_input;
std::vector<int64_t> dim_vector;
if (op.GetInputConstData("x_shape", tensor_input) == GRAPH_SUCCESS) {
uint8_t *input_shape = tensor_input.GetData();
for (int64_t i = 0; i < 2; i++) {
dim_vector.push_back(shape.GetDim(*(input_shape + i)));
}
}
Shape output_shape(dim_vector);
TensorDesc td = op.GetOutputDescByName("x_grad");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(FORMAT_ND);
(void)op.UpdateOutputDesc("x_grad", td);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(TraceGrad, TraceGradVerify) {
DataType x_shape_dtype = op.GetInputDescByName("x_shape").GetDataType();
if ((x_shape_dtype != DT_INT32) || (x_shape_dtype != DT_INT64)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(TraceGrad, TraceGradInferShape);
CUST_VERIFY_FUNC_REG(TraceGrad, TraceGradVerify);
// ---------------TraceGrad END-------------------------------
} // namespace ge

View File

@ -7,7 +7,9 @@
#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"
#include "utils/vector_proto_profiling.h"
namespace ge {
// ------------------DepthToSpace------------------
@ -185,4 +187,149 @@ IMPLEMT_COMMON_INFERFUNC(DepthToSpaceInfer) {
COMMON_INFER_FUNC_REG(DepthToSpace, DepthToSpaceInfer);
VERIFY_FUNC_REG(DepthToSpace, DepthToSpaceVerify);
// -------------------DepthToSpace END-----------------
// -------------------Transpose-----------------
static graphStatus TransposeCommonInferShape(const std::vector<int64_t> &perm_list, Operator &op) {
PROFILING_PROTO_INIT(TbeGetName(op).c_str());
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
const int64_t input_x_idx = 0;
auto input_desc = op_info->MutableInputDesc(input_x_idx);
const int64_t output_y_idx = 0;
auto output_desc = op_info->MutableOutputDesc(output_y_idx);
auto input_dtype = input_desc->GetDataType();
const GeShape &input_ge_shape = input_desc->MutableShape();
int64_t input_shape_len = input_ge_shape.GetDimNum();
PROFILING_PROTO_AFTER_GET_SHAPE_REG();
if (IsUnknownRankShape(input_ge_shape)) {
// UnknownRankShape, set shape is -1, -1, -1....
std::vector<int64_t> out_vec(perm_list.size(), -1);
output_desc->SetShape(GeShape(out_vec));
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
// infer the shape
GeShape &output_ge_shape = output_desc->MutableShape();
output_ge_shape.SetDimNum(input_shape_len);
for (size_t i = 0; i < perm_list.size(); ++i) {
// verify perm_list begin
int64_t perm_value = perm_list[i] < 0 ? perm_list[i] + input_shape_len : perm_list[i];
if (perm_value >= input_shape_len) {
std::string err_msg = GetAttrValueErrMsg("perm", ConcatString(perm_value),
ConcatString("less than input shape size[", input_shape_len, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify perm_list end
// set the output shape
output_ge_shape.SetDim(i, input_ge_shape.GetDim(perm_value));
}
PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
// set output dtype as the same with input x
output_desc->SetDataType(input_dtype);
// infer the range, when need
if (output_ge_shape.IsUnknownShape()) {
output_desc->SetOriginShape(output_ge_shape);
std::vector<std::pair<int64_t, int64_t>> input_range;
std::vector<std::pair<int64_t, int64_t>> output_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_ge_shape, input_range);
for (size_t i = 0; i < perm_list.size(); ++i) {
output_range.push_back(input_range[perm_list[i]]);
}
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
PROFILING_PROTO_END();
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(TransposeInferShape) {
const vector<string> depend_names = {"perm"};
PREPARE_DYNAMIC_SHAPE(depend_names);
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
bool perm_done = true;
std::vector<int64_t> perm_list;
static const int64_t perm_input_index = 1;
if (!(ops::GetConstIntData(op, perm_input_index, perm_list))) {
perm_done = false;
OP_LOGW(TbeGetName(op), "Get Const perm value failed ");
}
// perm is const node , will do infer use function TransposeCommonInferShape
if (perm_done) {
if (GRAPH_SUCCESS != TransposeCommonInferShape(perm_list, op)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// perm is not const node, infer for aicpu
static const int64_t x_input_index = 0;
static const int64_t y_output_index = 0;
auto input_desc = op_desc->MutableInputDesc(x_input_index);
auto input_shape = input_desc->MutableShape().GetDims();
auto input_dtype = input_desc->GetDataType();
auto output_desc = op_desc->MutableOutputDesc(y_output_index);
// set output dtype as the same with input x
output_desc->SetDataType(input_dtype);
if (IsUnknownRankShape(input_shape)) {
auto perm_desc = op_desc->MutableInputDesc("perm");
auto perm_shape = perm_desc->MutableShape().GetDims();
if (IsUnknown(perm_shape)) {
// set output is -2 UnknownRank
OP_LOGW(TbeGetName(op), "the output will be set to -2");
output_desc->SetShape(GeShape(input_shape));
output_desc->SetOriginShape(GeShape(input_shape));
return GRAPH_SUCCESS;
}
// pert is not dynamic shape, will update the input shape
if (perm_shape.empty()) {
perm_shape.push_back(1);
}
input_shape.clear();
for (auto i = 0; i < perm_shape[0]; ++i) {
input_shape.push_back(-1);
}
}
// begin to infer shape and range
std::vector<std::pair<int64_t, int64_t>> input_range;
std::vector<std::pair<int64_t, int64_t>> output_range;
vector<int64_t> out_vec;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_shape, input_range);
int64_t range_first = input_range[0].first;
int64_t range_second = input_range[0].second;
for (size_t i = 0; i < input_range.size(); ++i) {
// all range is the same and get the shape range
range_first = std::min(range_first, input_range[i].first);
range_second =
(range_second == -1 || input_range[i].second == -1) ? -1 : std::max(range_second, input_range[i].second);
}
for (size_t i = 0; i < input_range.size(); ++i) {
out_vec.push_back(-1);
output_range.push_back(std::pair<int64_t, int64_t>(range_first, range_second));
}
output_desc->SetShape(GeShape(out_vec));
output_desc->SetOriginShape(GeShape(out_vec));
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Transpose, TransposeInferShape);
// -------------------Transpose END-----------------
} // namespace ge

View File

@ -77,9 +77,11 @@
namespace ge {
// enum type and string type mapping
const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
{ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"},
{ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"},
{ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}, {ge::DT_INT4, "int4"},
{ge::DT_DOUBLE, "double"}, {ge::DT_COMPLEX64, "complex64"}, {ge::DT_COMPLEX128, "complex128"},
{ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"},
{ge::DT_INT16, "int16"}, {ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"},
{ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"}, {ge::DT_UINT32, "uint32"},
{ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}, {ge::DT_INT4, "int4"},
{ge::DT_BF16, "bfloat16"}};
// define the input num of shape

View File

@ -95,6 +95,7 @@ cust_op_lists = [
"leftshift",
"lessequal",
"listdiff",
"lgamma",
"log",
"log1p",
"lognormalreverse",
@ -102,7 +103,6 @@ cust_op_lists = [
"lowerbound",
"lusolve",
"luunpackgrad",
"maskedselect",
"maskedselectgrad",
"matrixdeterminant",
"matrixexp",
@ -111,6 +111,7 @@ cust_op_lists = [
"matrixtriangularsolve",
"maxpool3dgradwithargmax",
"maxpool3dwithargmax",
"meshgrid",
"mul",
"mulnonan",
"multimarginloss",
@ -123,7 +124,73 @@ cust_op_lists = [
"gatherdgradv2",
"isnan",
"maskedselectgrad",
"slicegrad"
"slicegrad",
"orgqr",
"tracegrad",
"nonmaxsuppressionwithoverlaps",
"nthelement",
"onehot",
"orgqr",
"padv3",
"padv3grad",
"parameterizedtruncatednormal",
"pow",
"qr",
"raggedrange",
"randompoisson",
"reciprocal",
"reciprocalgrad",
"reducemean",
"reduceprod",
"reducesum",
"resizearea",
"resizebicubic",
"resizebicubicgrad",
"resizenearestneighborv2",
"resizenearestneighborv2grad",
"reversev2",
"rgbtohsv",
"rightshift",
"rsqrtgrad",
"sampledistortedboundingboxv2",
"scaleandtranslate",
"scaleandtranslategrad",
"scatternd",
"scatterndupdate",
"segmentsum",
"select",
"sign",
"sin",
"sinh",
"slice",
"smoothl1loss",
"smoothl1lossgrad",
"split",
"sqrt",
"sqrtgrad",
"stack",
"tanh",
"tensorscatterupdate",
"tile",
"trace",
"tracegrad",
"transpose",
"tril",
"truncatednormal",
"uniqueconsecutive",
"unravelindex",
"unsortedsegmentsum",
"unstack",
"upperbound",
"xdivy",
"xlogy",
"zeroslike",
"flatten",
"maxpoolv1",
"norepeatngram",
"randint",
"reversesequence",
"standardlaplace"
]

View File

@ -0,0 +1,123 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/meshgrid_kernels.h"
#include <vector>
#include <string>
#include <map>
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#include "common/kernel_log.h"
#include "common/kernel_errcode.h"
#include "common/tensor.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "proto/aicpu_tensor.pb.h"
namespace aicpu {
template <typename T>
uint32_t MeshgridTask(const std::vector<uintptr_t> &io_addrs_, const std::string &indexing, size_t ndim,
const std::vector<int> &bcast) {
auto shards = [&](const int64_t begin, const int64_t end) {
for (int i = begin; i < end; ++i) { // 0~ndim
auto new_i = i;
auto s = bcast;
if (indexing == "xy" && i < 2) {
new_i = 1 - i;
auto tmp = s[0];
s[0] = s[1];
s[1] = tmp;
}
size_t row_ = 1;
size_t col_ = 1;
for (int j = 0; j <= new_i; j++) {
row_ *= s[j];
}
for (int j = new_i + 1; j < static_cast<int>(s.size()); j++) {
col_ *= s[j];
}
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>> input_map(reinterpret_cast<T *>(io_addrs_[i]), bcast[i],
1);
const auto &input = Eigen::Tensor<T, 2, Eigen::RowMajor>(input_map);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>> output(reinterpret_cast<T *>(io_addrs_[ndim + i]), row_,
col_);
Eigen::Tensor<T, 2, Eigen::RowMajor> origin(bcast[i], row_ * col_ / bcast[i]);
for (int c = 0; c < bcast[i]; ++c) {
for (int r = 0; r < static_cast<int>(row_ * col_ / bcast[i]); ++r) {
origin(c, r) = input(c, 0);
}
}
for (size_t j = 0; j < row_ * col_ / bcast[i] / col_; ++j) {
Eigen::array<int64_t, 2> offsets_in = {0, static_cast<int64_t>(col_ * j)};
Eigen::array<int64_t, 2> offsets_out = {static_cast<int64_t>(bcast[i] * j), 0};
Eigen::array<int64_t, 2> extents = {static_cast<int64_t>(bcast[i]), static_cast<int64_t>(col_)};
output.slice(offsets_out, extents) = origin.slice(offsets_in, extents);
}
}
};
const int64_t perUnitSize = 1; // shard unit size
ParallelFor(ndim, perUnitSize, shards);
return kAicpuKernelStateSucess;
}
uint32_t MeshgridKernel::DoCompute() {
std::map<int, std::function<uint32_t(std::vector<uintptr_t> &, std::string &, size_t &, std::vector<int> &)>> calls;
calls[aicpuops::DataType::MS_INT8] = MeshgridTask<int8_t>;
calls[aicpuops::DataType::MS_INT16] = MeshgridTask<int16_t>;
calls[aicpuops::DataType::MS_INT32] = MeshgridTask<int32_t>;
calls[aicpuops::DataType::MS_INT64] = MeshgridTask<int64_t>;
calls[aicpuops::DataType::MS_FLOAT16] = MeshgridTask<Eigen::half>;
calls[aicpuops::DataType::MS_FLOAT32] = MeshgridTask<float>;
calls[aicpuops::DataType::MS_FLOAT64] = MeshgridTask<double>;
calls[aicpuops::DataType::MS_UINT8] = MeshgridTask<uint8_t>;
calls[aicpuops::DataType::MS_UINT16] = MeshgridTask<uint16_t>;
calls[aicpuops::DataType::MS_UINT32] = MeshgridTask<uint32_t>;
calls[aicpuops::DataType::MS_UINT64] = MeshgridTask<uint64_t>;
calls[aicpuops::DataType::MS_BOOL] = MeshgridTask<bool>;
return calls[input_type_](io_addrs_, indexing_, ndim_, bcast_);
}
uint32_t MeshgridKernel::ParseKernelParam() {
::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> nodedef_map = node_def_.attrs();
indexing_ = nodedef_map["indexing"].s();
input_type_ = static_cast<aicpuops::DataType>(node_def_.inputs(0).tensor_type());
ndim_ = node_def_.inputs_size();
bcast_.resize(ndim_);
for (int n = 0; n < node_def_.inputs_size(); ++n) {
aicpuops::Tensor input_tensor = node_def_.inputs(n);
aicpuops::TensorShape input_shape = input_tensor.tensor_shape();
if (input_shape.dim().size() != 1) {
AICPU_LOGE("input tensor should be 1-D.");
}
bcast_[n] = input_shape.dim(0).size();
}
return kAicpuKernelStateSucess;
}
} // namespace aicpu
extern "C" {
__attribute__((visibility("default"))) uint32_t Meshgrid(void *param) {
aicpu::MeshgridKernel meshgridKernel;
return meshgridKernel.Compute(param);
}
}

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef _AICPU_MESHGRID_KERNELS_H_
#define _AICPU_MESHGRID_KERNELS_H_
#include <vector>
#include <string>
#include "common/kernel_base.h"
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
namespace aicpu {
class MeshgridKernel : public KernelBase {
public:
explicit MeshgridKernel() : KernelBase("Meshgrid") {}
~MeshgridKernel() = default;
aicpuops::DataType input_type_;
std::string indexing_;
size_t ndim_ = 0;
std::vector<int> bcast_;
protected:
uint32_t DoCompute() override;
uint32_t ParseKernelParam() override;
};
} // namespace aicpu
#endif

View File

@ -22,13 +22,13 @@
/* clang-format off */
namespace ge {
REG_OP(Meshgrid)
REG_CUST_OP(Meshgrid)
.DYNAMIC_INPUT(x, TensorType({DT_INT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
.DYNAMIC_OUTPUT(y, TensorType({DT_INT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
.ATTR(indexing, String, "")
.OP_END_FACTORY_REG(Meshgrid)
.CUST_OP_END_FACTORY_REG(Meshgrid)
REG_CUST_OP(SliceGrad)
.INPUT(dy, TensorType({DT_BOOL, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
@ -104,5 +104,11 @@ REG_CUST_OP(LogSpace)
.REQUIRED_ATTR(steps, Int)
.REQUIRED_ATTR(base, Int)
.CUST_OP_END_FACTORY_REG(LogSpace)
REG_CUST_OP(Expand)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8, DT_BOOL}))
.INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8, DT_BOOL}))
.CUST_OP_END_FACTORY_REG(Expand)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ARRAY_OPS_H_

View File

@ -87,6 +87,25 @@ REG_CUST_OP(Gcd)
.INPUT(x2, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.CUST_OP_END_FACTORY_REG(Gcd)
REG_CUST_OP(Orgqr)
.INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
.INPUT(tau, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
.CUST_OP_END_FACTORY_REG(Orgqr)
REG_CUST_OP(TraceGrad)
.INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(x_shape, TensorType({DT_INT64}))
.OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.CUST_OP_END_FACTORY_REG(TraceGrad)
REG_CUST_OP(Lgamma)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT32}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(Lgamma)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_

View File

@ -55,6 +55,20 @@ REG_CUST_OP(AdaptiveAvgPool2DGrad)
.OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2DGrad)
REG_CUST_OP(AdaptiveMaxPool3d)
.INPUT(x, TensorType::RealNumberType())
.INPUT(output_size, TensorType({DT_INT32}))
.OUTPUT(y, TensorType::RealNumberType())
.OUTPUT(argmax, TensorType({DT_INT32}))
.CUST_OP_END_FACTORY_REG(AdaptiveMaxPool3d)
REG_CUST_OP(AdaptiveMaxPool3dGrad)
.INPUT(input_grad, TensorType::RealNumberType())
.INPUT(x, TensorType::RealNumberType())
.INPUT(argmax, TensorType({DT_INT32}))
.OUTPUT(output_grad, TensorType::RealNumberType())
.CUST_OP_END_FACTORY_REG(AdaptiveMaxPool3dGrad)
REG_CUST_OP(MultiMarginLossGrad)
.INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))

View File

@ -19,6 +19,7 @@
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
@ -29,5 +30,12 @@ REG_OP(KVCacheMgr)
.INPUT(index, TensorType({DT_INT32}))
.OUTPUT(past, TensorType({DT_FLOAT16}))
.OP_END_FACTORY_REG(KVCacheMgr)
REG_CUST_OP(NoRepeatNGram)
.INPUT(state_seq, TensorType({DT_INT32}))
.INPUT(log_probs, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.OUTPUT(out, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(ngram_size, Int)
.CUST_OP_END_FACTORY_REG(NoRepeatNGram)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_OTHER_OPS_H_

View File

@ -39,5 +39,14 @@ REG_CUST_OP(Dropout2D)
.OUTPUT(mask, TensorType({DT_BOOL}))
.REQUIRED_ATTR(keep_prob, Float)
.CUST_OP_END_FACTORY_REG(Dropout2D)
REG_CUST_OP(StandardLaplace)
.INPUT(shape, TensorType({DT_INT32, DT_INT64}))
.INPUT(seed, TensorType({DT_INT64}))
.INPUT(seed2, TensorType({DT_INT64}))
.OUTPUT(output, TensorType({DT_FLOAT}))
.REQUIRED_ATTR(seed, Int)
.REQUIRED_ATTR(seed2, Int)
.CUST_OP_END_FACTORY_REG(StandardLaplace)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_

View File

@ -378,6 +378,7 @@ constexpr const char kNameDynamicShape[] = "DynamicShape";
constexpr const char kNameGather[] = "Gather";
constexpr const char kNameUnsqueeze[] = "Unsqueeze";
constexpr const char kNamePadV3[] = "PadV3";
constexpr const char kNamePadV3Grad[] = "PadV3Grad";
constexpr const char kNamePadV2[] = "PadV2";
constexpr const char kNameGlobalAvgPool[] = "GlobalAveragePool";
constexpr const char kNameAdaptiveMaxPool2D[] = "AdaptiveMaxPool2D";

View File

@ -211,11 +211,11 @@ OUTPUT_MAP(Size) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Size, kNameSize, ADPT_DESC(Size))
// Meshgrid
INPUT_MAP(Meshgrid) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(Meshgrid) = {{1, DYN_INPUT_DESC(x)}};
ATTR_MAP(Meshgrid) = {{"indexing", ATTR_DESC(indexing, AnyTraits<std::string>())}};
DYN_OUTPUT_MAP(Meshgrid) = {{0, DYN_OUTPUT_DESC(y)}};
REG_ADPT_DESC(Meshgrid, prim::kPrimMeshgrid->name(), ADPT_DESC(Meshgrid))
CUST_INPUT_MAP(Meshgrid) = EMPTY_INPUT_MAP;
CUST_DYN_INPUT_MAP(Meshgrid) = {{1, DYN_INPUT_DESC(x)}};
CUST_ATTR_MAP(Meshgrid) = {{"indexing", ATTR_DESC(indexing, AnyTraits<std::string>())}};
CUST_DYN_OUTPUT_MAP(Meshgrid) = {{0, DYN_OUTPUT_DESC(y)}};
REG_ADPT_DESC(Meshgrid, prim::kPrimMeshgrid->name(), CUST_ADPT_DESC(Meshgrid))
// SliceGrad
CUST_INPUT_MAP(SliceGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(begin)}, {4, INPUT_DESC(size)}};
@ -306,4 +306,30 @@ CUST_ATTR_MAP(LogSpace) = {{"steps", ATTR_DESC(steps, AnyTraits<int64_t>())},
{"base", ATTR_DESC(base, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(LogSpace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LogSpace, prim::kPrimLogSpace->name(), CUST_ADPT_DESC(LogSpace));
// UniqueConsecutive
INPUT_MAP(UniqueConsecutive) = {{1, INPUT_DESC(x)}};
ATTR_MAP(UniqueConsecutive) = {{"return_idx", ATTR_DESC(return_idx, AnyTraits<bool>())},
{"return_counts", ATTR_DESC(return_counts, AnyTraits<bool>())},
{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}};
OUTPUT_MAP(UniqueConsecutive) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(idx)}, {2, OUTPUT_DESC(count)}};
REG_ADPT_DESC(UniqueConsecutive, prim::kPrimUniqueConsecutive->name(), ADPT_DESC(UniqueConsecutive));
// UpperBound
INPUT_MAP(UpperBound) = {{1, INPUT_DESC(sorted_x)}, {2, INPUT_DESC(values)}};
ATTR_MAP(UpperBound) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UpperBound) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UpperBound, prim::kPrimUpperBound->name(), ADPT_DESC(UpperBound));
// UnravelIndex
INPUT_MAP(UnravelIndex) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(dims)}};
ATTR_MAP(UnravelIndex) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnravelIndex) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UnravelIndex, prim::kPrimUnravelIndex->name(), ADPT_DESC(UnravelIndex));
// NoRepeatNGram
CUST_INPUT_MAP(NoRepeatNGram) = {{1, INPUT_DESC(state_seq)}, {2, INPUT_DESC(log_probs)}};
CUST_ATTR_MAP(NoRepeatNGram) = {{"ngram_size", ATTR_DESC(ngram_size, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(NoRepeatNGram) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(NoRepeatNGram, prim::kPrimNoRepeatNGram->name(), CUST_ADPT_DESC(NoRepeatNGram));
} // namespace mindspore::transform

View File

@ -20,6 +20,7 @@
#include "inc/ops/array_ops.h"
#include "inc/ops/selection_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_array_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_other_ops.h"
#include "inc/ops/transformation_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"
@ -107,9 +108,9 @@ DECLARE_OP_USE_OUTPUT(QueueData)
DECLARE_OP_ADAPTER(Size)
DECLARE_OP_USE_OUTPUT(Size)
DECLARE_OP_ADAPTER(Meshgrid)
DECLARE_OP_USE_DYN_INPUT(Meshgrid)
DECLARE_OP_USE_DYN_OUTPUT(Meshgrid)
DECLARE_CUST_OP_ADAPTER(Meshgrid)
DECLARE_CUST_OP_USE_DYN_INPUT(Meshgrid)
DECLARE_CUST_OP_USE_DYN_OUTPUT(Meshgrid)
DECLARE_CUST_OP_ADAPTER(SliceGrad)
DECLARE_CUST_OP_USE_OUTPUT(SliceGrad)
@ -149,4 +150,16 @@ DECLARE_CUST_OP_USE_OUTPUT(MvlgammaGrad)
DECLARE_CUST_OP_ADAPTER(LogSpace)
DECLARE_CUST_OP_USE_OUTPUT(LogSpace)
DECLARE_OP_ADAPTER(UniqueConsecutive)
DECLARE_OP_USE_OUTPUT(UniqueConsecutive)
DECLARE_OP_ADAPTER(UpperBound)
DECLARE_OP_USE_OUTPUT(UpperBound)
DECLARE_OP_ADAPTER(UnravelIndex)
DECLARE_OP_USE_OUTPUT(UnravelIndex)
DECLARE_CUST_OP_ADAPTER(NoRepeatNGram)
DECLARE_CUST_OP_USE_OUTPUT(NoRepeatNGram)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

View File

@ -189,4 +189,27 @@ ATTR_MAP(ExtractGlimpse) = {{"noise", ATTR_DESC(noise, AnyTraits<std::string>())
{"uniform_noise", ATTR_DESC(uniform_noise, AnyTraits<bool>())}};
OUTPUT_MAP(ExtractGlimpse) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ExtractGlimpse, prim::kPrimExtractGlimpse->name(), ADPT_DESC(ExtractGlimpse));
// ScaleAndTranslateGrad
INPUT_MAP(ScaleAndTranslateGrad) = {
{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}, {3, INPUT_DESC(scale)}, {4, INPUT_DESC(translation)}};
ATTR_MAP(ScaleAndTranslateGrad) = {{"kernel_type", ATTR_DESC(kernel_type, AnyTraits<std::string>())},
{"antialias", ATTR_DESC(antialias, AnyTraits<bool>())}};
OUTPUT_MAP(ScaleAndTranslateGrad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ScaleAndTranslateGrad, prim::kPrimScaleAndTranslateGrad->name(), ADPT_DESC(ScaleAndTranslateGrad));
// ResizeBicubicGrad
INPUT_MAP(ResizeBicubicGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
ATTR_MAP(ResizeBicubicGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())},
{"half_pixel_centers", ATTR_DESC(half_pixel_centers, AnyTraits<bool>())}};
OUTPUT_MAP(ResizeBicubicGrad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ResizeBicubicGrad, prim::kPrimResizeBicubicGrad->name(), ADPT_DESC(ResizeBicubicGrad));
// ScaleAndTranslate
INPUT_MAP(ScaleAndTranslate) = {
{1, INPUT_DESC(images)}, {2, INPUT_DESC(size)}, {3, INPUT_DESC(scale)}, {4, INPUT_DESC(translation)}};
ATTR_MAP(ScaleAndTranslate) = {{"kernel_type", ATTR_DESC(kernel_type, AnyTraits<std::string>())},
{"antialias", ATTR_DESC(antialias, AnyTraits<bool>())}};
OUTPUT_MAP(ScaleAndTranslate) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ScaleAndTranslate, prim::kPrimScaleAndTranslate->name(), ADPT_DESC(ScaleAndTranslate));
} // namespace mindspore::transform

View File

@ -77,4 +77,13 @@ DECLARE_OP_USE_OUTPUT(AdjustHue)
DECLARE_OP_ADAPTER(ExtractGlimpse)
DECLARE_OP_USE_OUTPUT(ExtractGlimpse)
DECLARE_OP_ADAPTER(ScaleAndTranslateGrad)
DECLARE_OP_USE_OUTPUT(ScaleAndTranslateGrad)
DECLARE_OP_ADAPTER(ResizeBicubicGrad)
DECLARE_OP_USE_OUTPUT(ResizeBicubicGrad)
DECLARE_OP_ADAPTER(ScaleAndTranslate)
DECLARE_OP_USE_OUTPUT(ScaleAndTranslate)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_

View File

@ -100,4 +100,10 @@ CUST_INPUT_MAP(LuSolve) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(lu_data)}, {3, INP
CUST_ATTR_MAP(LuSolve) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(LuSolve) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(LuSolve, prim::kPrimLuSolve->name(), CUST_ADPT_DESC(LuSolve));
// Qr
INPUT_MAP(Qr) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Qr) = {{"full_matrices", ATTR_DESC(full_matrices, AnyTraits<bool>())}};
OUTPUT_MAP(Qr) = {{0, OUTPUT_DESC(q)}, {1, OUTPUT_DESC(r)}};
REG_ADPT_DESC(Qr, prim::kPrimQr->name(), ADPT_DESC(Qr));
} // namespace mindspore::transform

View File

@ -60,4 +60,7 @@ DECLARE_CUST_OP_USE_OUTPUT(LuUnpackGrad)
DECLARE_CUST_OP_ADAPTER(LuSolve)
DECLARE_CUST_OP_USE_OUTPUT(LuSolve)
DECLARE_OP_ADAPTER(Qr)
DECLARE_OP_USE_OUTPUT(Qr)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LINALG_OPS_DECLARE_H_

View File

@ -309,4 +309,34 @@ CUST_INPUT_MAP(Gcd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
CUST_ATTR_MAP(Gcd) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Gcd) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Gcd, prim::kPrimGcd->name(), CUST_ADPT_DESC(Gcd));
// Orgqr
CUST_INPUT_MAP(Orgqr) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(tau)}};
CUST_ATTR_MAP(Orgqr) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Orgqr) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Orgqr, prim::kPrimOrgqr->name(), CUST_ADPT_DESC(Orgqr));
// RaggedRange
INPUT_MAP(RaggedRange) = {{1, INPUT_DESC(starts)}, {2, INPUT_DESC(limits)}, {3, INPUT_DESC(deltas)}};
ATTR_MAP(RaggedRange) = {{"Tsplits", ATTR_DESC(Tsplits, AnyTraits<GEType>())}};
OUTPUT_MAP(RaggedRange) = {{0, OUTPUT_DESC(rt_nested_splits)}, {1, OUTPUT_DESC(rt_dense_values)}};
REG_ADPT_DESC(RaggedRange, prim::kPrimRaggedRange->name(), ADPT_DESC(RaggedRange));
// Imag
INPUT_MAP(Imag) = {{1, INPUT_DESC(input)}};
ATTR_MAP(Imag) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Imag) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(Imag, prim::kPrimImag->name(), ADPT_DESC(Imag));
// Lgamma
CUST_INPUT_MAP(Lgamma) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Lgamma) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Lgamma) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Lgamma, prim::kPrimLgamma->name(), CUST_ADPT_DESC(Lgamma));
// Real
INPUT_MAP(Real) = {{1, INPUT_DESC(input)}};
ATTR_MAP(Real) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Real) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(Real, prim::kPrimReal->name(), ADPT_DESC(Real));
} // namespace mindspore::transform

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
#include "inc/ops/math_ops.h"
#include "inc/ops/ragged_math_ops.h"
#include "inc/ops/spectral_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_math_ops.h"
#include "mindspore/ccsrc/include/common/utils/utils.h"
@ -146,4 +147,19 @@ DECLARE_CUST_OP_USE_OUTPUT(Heaviside)
DECLARE_CUST_OP_ADAPTER(Gcd)
DECLARE_CUST_OP_USE_OUTPUT(Gcd)
DECLARE_CUST_OP_ADAPTER(Orgqr)
DECLARE_CUST_OP_USE_OUTPUT(Orgqr)
DECLARE_OP_ADAPTER(RaggedRange)
DECLARE_OP_USE_OUTPUT(RaggedRange)
DECLARE_OP_ADAPTER(Imag)
DECLARE_OP_USE_OUTPUT(Imag)
DECLARE_CUST_OP_ADAPTER(Lgamma)
DECLARE_CUST_OP_USE_OUTPUT(Lgamma)
DECLARE_OP_ADAPTER(Real)
DECLARE_OP_USE_OUTPUT(Real)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_

View File

@ -251,4 +251,16 @@ ATTR_MAP(FillDiagonal) = {{"fill_value", ATTR_DESC(fill_value, AnyTraits<float>(
{"wrap", ATTR_DESC(wrap, AnyTraits<bool>())}};
OUTPUT_MAP(FillDiagonal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FillDiagonal, kNameFillDiagonal, ADPT_DESC(FillDiagonal));
// Trace
INPUT_MAP(Trace) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Trace) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Trace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Trace, prim::kPrimTrace->name(), ADPT_DESC(Trace));
// TraceGrad
CUST_INPUT_MAP(TraceGrad) = {{1, INPUT_DESC(y_grad)}, {2, INPUT_DESC(x_shape)}};
CUST_ATTR_MAP(TraceGrad) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(TraceGrad) = {{0, OUTPUT_DESC(x_grad)}};
REG_ADPT_DESC(TraceGrad, prim::kPrimTraceGrad->name(), CUST_ADPT_DESC(TraceGrad));
} // namespace mindspore::transform

View File

@ -19,6 +19,7 @@
#include "mindspore/ccsrc/include/common/utils/utils.h"
#include "inc/ops/matrix_calculation_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_math_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"
@ -126,4 +127,10 @@ DECLARE_OP_USE_OUTPUT(Eye)
DECLARE_OP_ADAPTER(FillDiagonal)
DECLARE_OP_USE_OUTPUT(FillDiagonal)
DECLARE_OP_ADAPTER(Trace)
DECLARE_OP_USE_OUTPUT(Trace)
DECLARE_CUST_OP_ADAPTER(TraceGrad)
DECLARE_CUST_OP_USE_OUTPUT(TraceGrad)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATRIX_CALCULATION_OPS_DECLARE_H_

View File

@ -376,4 +376,10 @@ ATTR_MAP(FractionalAvgPool) = {{"pooling_ratio", ATTR_DESC(pooling_ratio, AnyTra
OUTPUT_MAP(FractionalAvgPool) = {
{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(row_pooling_sequence)}, {2, OUTPUT_DESC(col_pooling_sequence)}};
REG_ADPT_DESC(FractionalAvgPool, prim::kPrimFractionalAvgPool->name(), ADPT_DESC(FractionalAvgPool));
// NthElement
INPUT_MAP(NthElement) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(n)}};
ATTR_MAP(NthElement) = {{"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}};
OUTPUT_MAP(NthElement) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(NthElement, prim::kPrimNthElement->name(), ADPT_DESC(NthElement));
} // namespace mindspore::transform

View File

@ -130,4 +130,7 @@ DECLARE_OP_USE_OUTPUT(FractionalMaxPoolGrad)
DECLARE_OP_ADAPTER(FractionalAvgPool)
DECLARE_OP_USE_OUTPUT(FractionalAvgPool)
DECLARE_OP_ADAPTER(NthElement)
DECLARE_OP_USE_OUTPUT(NthElement)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_

View File

@ -64,4 +64,11 @@ INPUT_MAP(PadV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}, {3, INPUT_DES
ATTR_MAP(PadV2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PadV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadV2, kNamePadV2, ADPT_DESC(PadV2))
// PadV3Grad
INPUT_MAP(PadV3Grad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
ATTR_MAP(PadV3Grad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())},
{"paddings_contiguous", ATTR_DESC(paddings_contiguous, AnyTraits<bool>())}};
OUTPUT_MAP(PadV3Grad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadV3Grad, kNamePadV3Grad, ADPT_DESC(PadV3Grad));
} // namespace mindspore::transform

View File

@ -44,4 +44,7 @@ DECLARE_OP_USE_OUTPUT(PadV3)
DECLARE_OP_ADAPTER(PadV2)
DECLARE_OP_USE_OUTPUT(PadV2)
DECLARE_OP_ADAPTER(PadV3Grad)
DECLARE_OP_USE_OUTPUT(PadV3Grad)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_PAD_OPS_DECLARE_H_

View File

@ -117,4 +117,11 @@ CUST_INPUT_MAP(Dropout2D) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Dropout2D) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}};
CUST_OUTPUT_MAP(Dropout2D) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mask)}};
REG_ADPT_DESC(Dropout2D, kNameDropout2D, CUST_ADPT_DESC(Dropout2D))
// StandardLaplace
CUST_INPUT_MAP(StandardLaplace) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(seed)}, {3, INPUT_DESC(seed2)}};
CUST_ATTR_MAP(StandardLaplace) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(StandardLaplace) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(StandardLaplace, prim::kPrimStandardLaplace->name(), CUST_ADPT_DESC(StandardLaplace));
} // namespace mindspore::transform

View File

@ -64,4 +64,7 @@ DECLARE_OP_USE_OUTPUT(NonDeterministicInts)
DECLARE_CUST_OP_ADAPTER(Dropout2D)
DECLARE_CUST_OP_USE_OUTPUT(Dropout2D)
DECLARE_CUST_OP_ADAPTER(StandardLaplace)
DECLARE_CUST_OP_USE_OUTPUT(StandardLaplace)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_