Add regop and adapters for custom aicpu
parent 43991f351f
commit be2abc0952
@@ -175,6 +175,7 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.cc" "build/include_order"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.h" "runtime/explicit"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cast_kernels.h" "whitespace/indent"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/meshgrid_kernels.h" "runtime/explicit"

# custom AICPU op_protos
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "whitespace/ending_newline"
@@ -424,3 +424,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/im2col_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/selection_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
@@ -52,6 +52,7 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
        ${CMAKE_CURRENT_SOURCE_DIR}/candidate_sampler_kernels.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/cast_kernels.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/concat_offset_kernel.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/meshgrid_kernels.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/drop_out_gen_mask_kernels.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/expand_dims_kernels.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/flatten_kernels.cc
@@ -149,6 +149,10 @@ uint32_t GatherNdCpuKernel::GatherNdComputeRealKernel(CpuKernelContext &ctx) {
  for (int64_t i = 0; i < n_slices; ++i) {
    int64_t from_pos = 0;
    for (int64_t j = 0; j < indices_nd; ++j) {
      if (indices_data[i * indices_nd + j] < 0) {
        KERNEL_LOG_ERROR("For 'GatherNd', indices can't contain negative value.");
        return KERNEL_STATUS_INNER_ERROR;
      }
      from_pos += indices_data[i * indices_nd + j] * dims_to_count[j];
    }
    auto offset = i * slice_size;
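A worked illustration of the offset arithmetic above; the shapes are assumed for the example and are not part of the commit:

// Editor's sketch: for params of shape [4, 5, 6] and indices_nd == 2,
// dims_to_count == {30, 6} and slice_size == 6, so the index pair {2, 3}
// flattens to from_pos = 2 * 30 + 3 * 6 = 78, the start of params[2][3][...];
// negative index entries are rejected before the multiply-accumulate.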
File diff suppressed because it is too large
@@ -1,106 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/adaptive_avg_pool_3d_grad_op.h"
#include "register/op_impl_registry.h"
#include "external/graph/operator_reg.h"
#include "utils/util.h"

namespace ge {
// --------- AdaptiveAvgPool3dGrad ---------------
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify) {
  auto input_grad_desc = op.GetInputDescByName("input_grad");
  auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
  ge::AscendString op_name;
  (void)op.GetName(op_name);

  auto orig_input_shape_dim = orig_input_shape_desc.GetShape().GetDimNum();
  if (orig_input_shape_dim != 1) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input_shape is invalid");
    return GRAPH_PARAM_INVALID;
  }

  auto orig_input_dim_num = orig_input_shape_desc.GetShape().GetShapeSize();
  auto input_grad_dim_num = input_grad_desc.GetShape().GetDimNum();

  if (orig_input_dim_num != static_cast<int64_t>(input_grad_dim_num)) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input and input_grad should be the same");
    return GRAPH_PARAM_INVALID;
  }

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dGradInferShape) {
  map<int, std::string> format2str = {
    {ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
    {ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};

  auto input_desc = op.GetInputDescByName("input_grad");
  auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
  TensorDesc out_desc = op.GetOutputDescByName("output_grad");
  ge::AscendString op_name;
  (void)op.GetName(op_name);

  // update format
  Format input_format = input_desc.GetFormat();
  std::string format_str = format2str[input_format];
  if (input_format != FORMAT_NCHW) {
    OP_LOGE("AdaptiveAvgPool3dGrad",
            "Input format only support NCHW"
            ", input format is [%s]",
            format_str.c_str());
    return GRAPH_FAILED;
  }
  out_desc.SetFormat(input_format);

  // update data type
  DataType input_type = input_desc.GetDataType();
  out_desc.SetDataType(input_type);

  // infer shape
  Tensor orig_input_size_tensor;
  if (op.GetInputConstData("orig_input_shape", orig_input_size_tensor) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "failed to get tensor from output_size");
    return GRAPH_FAILED;
  }

  int32_t *orig_input_size_data = reinterpret_cast<int32_t *>(orig_input_size_tensor.GetData());
  if (orig_input_size_data == nullptr) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "output_size data is invalid");
    return GRAPH_PARAM_INVALID;
  }

  auto input_size_dim_num = input_desc.GetShape().GetDimNum();
  std::vector<int64_t> output_shape(input_size_dim_num);

  for (uint64_t i = 0; i < input_size_dim_num; ++i) {
    output_shape[i] = orig_input_size_data[i];
  }

  out_desc.SetShape(Shape(output_shape));
  if (op.UpdateOutputDesc("output_grad", out_desc) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "failed to update output desc");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify);
// --------- AdaptiveAvgPool3dGrad end---------------
}  // namespace ge
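A small illustration of the const-input shape inference above, with assumed values not taken from the commit:

// Editor's sketch: if the const orig_input_shape tensor holds {2, 3, 8, 8, 8},
// the loop copies those five values into output_shape, so output_grad is
// inferred as [2, 3, 8, 8, 8] with input_grad's dtype and NCHW format.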
@@ -1,90 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/adaptive_avg_pool_3d_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// --------- AdaptiveAvgPool3d ---------------
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
  map<int, std::string> format2str = {
    {ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
    {ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};

  // verify the dim of output_size
  std::vector<int64_t> output_size;
  if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
    return GRAPH_PARAM_INVALID;
  }
  ge::AscendString op_name;
  (void)op.GetName(op_name);
  auto input_desc = op.GetInputDescByName("x");
  TensorDesc out_desc = op.GetOutputDescByName("y");

  // update data type
  DataType input_type = input_desc.GetDataType();
  out_desc.SetDataType(input_type);

  // update format
  Format input_format = input_desc.GetFormat();
  std::string format_str = format2str[input_format];
  if (input_format != FORMAT_NCHW) {
    OP_LOGE("AdaptiveAvgPool3d",
            "Input format only support NCHW"
            ", input format is [%s]",
            format_str.c_str());
    return GRAPH_FAILED;
  }
  out_desc.SetFormat(input_format);

  std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
  auto input_size_dim_num = input_size_shape.size();
  std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
  auto output_size_num = output_size.size();
  if (output_size_num == 1) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      if (output_size[0] < 0) {
        continue;
      }
      output_shape[i] = output_size[0];
    }
  } else if (output_size_num == 3) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      auto data = output_size[i - input_size_dim_num + 3];
      if (data < 0) {
        continue;
      }
      output_shape[i] = data;
    }
  } else {
    OP_LOGE("AdaptiveAvgPool3d", "Shape of output_size is invalid");
    return GRAPH_FAILED;
  }

  out_desc.SetShape(Shape(output_shape));
  if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3d", "failed to update output desc");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3d, AdaptiveAvgPool3dInferShape);
// --------- AdaptiveAvgPool3d end---------------
}  // namespace ge
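How the output_size attribute drives the inferred shape, on assumed values:

// Editor's sketch: for x of shape [2, 3, 16, 16, 16] (the last three dims are
// D/H/W), output_size {4} infers y [2, 3, 4, 4, 4]; output_size {4, -1, 8}
// infers y [2, 3, 4, 16, 8], because negative entries keep the input dim.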
@@ -1,39 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/adaptive_max_pool3_d_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape) {
  TensorDesc output_grad = op.GetOutputDescByName("output_grad");
  TensorDesc input = op.GetInputDescByName("x");
  DataType input_dtype = input.GetDataType();
  Shape input_shape = input.GetShape();
  output_grad.SetShape(input_shape);
  output_grad.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("output_grad", output_grad) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify);
}  // namespace ge
@@ -1,58 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/adaptive_max_pool3d_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape) {
  TensorDesc input = op.GetInputDesc(0);
  TensorDesc output_size = op.GetInputDesc(1);
  TensorDesc output = op.GetOutputDesc(0);
  TensorDesc argmax = op.GetOutputDesc(1);

  const size_t input_num_dims = input.GetShape().GetDimNum();
  const std::vector<int64_t> output_size_shape = output_size.GetShape().GetDims();
  if ((input_num_dims == 4 || input_num_dims == 5) == false) {
    OP_LOGE(TbeGetName(op), "Input dimensions must be equal to 4 or 5.");
    return GRAPH_FAILED;
  }
  if (output_size_shape.size() != 1) {
    OP_LOGE(TbeGetName(op), "output_size dim should be equal to 1.");
    return GRAPH_FAILED;
  }
  if (output_size_shape[0] != 3) {
    OP_LOGE(TbeGetName(op), "output_size shape[0] should be equal to 3.");
    return GRAPH_FAILED;
  }

  DataType input_dtype = input.GetDataType();
  Shape output_shape(UNKNOWN_SHAPE);
  output.SetDataType(input_dtype);
  output.SetShape(output_shape);
  argmax.SetDataType(DT_INT32);
  argmax.SetShape(output_shape);
  (void)op.UpdateOutputDesc("y", output);
  (void)op.UpdateOutputDesc("argmax", argmax);
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify);
}  // namespace ge
@@ -1,39 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/adaptive_max_pool_2d_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape) {
  TensorDesc input_grad = op.GetOutputDescByName("x_grad");
  TensorDesc input = op.GetInputDescByName("x");
  DataType input_dtype = input.GetDataType();
  Shape input_shape = input.GetShape();
  input_grad.SetShape(input_shape);
  input_grad.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("x_grad", input_grad) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify);
}  // namespace ge
@@ -160,6 +160,11 @@ IMPLEMT_INFERFUNC(Expand, ExpandInferShape) {
      OP_LOGE(op_name, "Data shape are not compatible!");
      return GRAPH_FAILED;
    }
  } else if (data_type == DT_INT16) {
    if (!ExpandCalDim<int16_t>(data, vec_dim, x_dims, range_vector)) {
      OP_LOGE(op_name, "Data shape are not compatible!");
      return GRAPH_FAILED;
    }
  } else {
    OP_LOGE(op_name, "Data type not supported!");
    return GRAPH_PARAM_INVALID;
@@ -412,12 +417,12 @@ static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
 CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
   AscendString opName;
   op.GetName(opName);
-  if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
-    OP_LOGE(opName.GetString(), "Input start size must be 1.");
+  if (op.GetInputDescByName("start").GetShape().GetDims().size() > 1) {
+    OP_LOGE(opName.GetString(), "Input start size must be <= 1.");
     return GRAPH_FAILED;
   }
-  if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
-    OP_LOGE(opName.GetString(), "Input end size must be 1.");
+  if (op.GetInputDescByName("end").GetShape().GetDims().size() > 1) {
+    OP_LOGE(opName.GetString(), "Input end size must be <= 1.");
     return GRAPH_FAILED;
   }
   DataType input_type_start = op.GetInputDescByName("start").GetDataType();
@@ -462,4 +467,108 @@ CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------

// ----------------UniqueConsecutive Begin-------------------
IMPLEMT_INFERFUNC(UniqueConsecutive, UniqueConsecutiveInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x_desc_ptr = op_desc->MutableInputDesc(0);
  auto y_desc_ptr = op_desc->MutableOutputDesc(0);
  y_desc_ptr->SetDataType(x_desc_ptr->GetDataType());

  auto idx_desc_ptr = op_desc->MutableOutputDesc(1);
  auto count_desc_ptr = op_desc->MutableOutputDesc(2);

  auto &y_shape = y_desc_ptr->MutableShape();
  auto &idx_shape = idx_desc_ptr->MutableShape();
  auto &count_shape = count_desc_ptr->MutableShape();

  bool return_idx = false;
  bool return_counts = false;
  int64_t axis = 1000;

  op.GetAttr("axis", axis);
  op.GetAttr("return_idx", return_idx);
  op.GetAttr("return_counts", return_counts);
  count_shape.SetIsUnknownDimNum();
  count_desc_ptr->SetDataType(DT_INT64);
  idx_shape.SetIsUnknownDimNum();
  idx_desc_ptr->SetDataType(DT_INT64);
  y_shape.SetIsUnknownDimNum();

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(UniqueConsecutive, UniqueConsecutiveInfer);
// ----------------UniqueConsecutive End-----------------------

// ----------------UpperBound-----------------------
IMPLEMT_INFERFUNC(UpperBound, UpperBoundInfer) {
  Shape unused_shape;
  if (WithRank(op.GetInputDesc(0), 2, unused_shape, op) != GRAPH_SUCCESS) {
    string err_msg = ConcatString("failed to call WithRank function, input[sorted_x] rank must be 2D, got rank[",
                                  op.GetInputDesc(0).GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (WithRank(op.GetInputDesc(1), 2, unused_shape, op) != GRAPH_SUCCESS) {
    string err_msg = ConcatString("failed to call WithRank function, input[values] rank must be 2D, got rank[",
                                  op.GetInputDesc(1).GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  DataType type;
  if (op.GetAttr("out_type", type) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[out_type] failed"));
    return GRAPH_FAILED;
  }

  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetShape(op.GetInputDesc(1).GetShape());
  out_desc.SetDataType(type);
  return op.UpdateOutputDesc("y", out_desc);
}

INFER_FUNC_REG(UpperBound, UpperBoundInfer);
// ----------------UpperBound END-----------------------

// ----------------UnravelIndex-----------------------
IMPLEMT_INFERFUNC(UnravelIndex, UnravelIndexInfer) {
  auto indices_desc = op.GetInputDesc(0);
  auto dims_desc = op.GetInputDesc(1);

  Shape dims_shape;
  if (WithRank(dims_desc, 1, dims_shape, op) != GRAPH_SUCCESS) {
    string err_msg = ConcatString("input[dims] must be 1D, real rank is ", dims_shape.GetDimNum());
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  Shape indices_shape;
  if (WithRankAtMost(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
    string err_msg = ConcatString("input[indices] must be at most 1D, real rank is ", indices_shape.GetDimNum());
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  std::vector<int64_t> out_dims({-1, -1});
  std::vector<int64_t> dims_shape_vec = dims_shape.GetDims();
  std::vector<int64_t> indices_shape_vec = indices_shape.GetDims();
  if (indices_shape.GetDimNum() == 0) {
    out_dims.pop_back();
  } else {
    if (indices_shape_vec != ge::UNKNOWN_RANK && indices_shape_vec != ge::UNKNOWN_SHAPE) {
      out_dims[1] = indices_shape_vec[0];
    }
  }
  if (dims_shape_vec != ge::UNKNOWN_RANK && dims_shape_vec != ge::UNKNOWN_SHAPE) {
    out_dims[0] = dims_shape_vec[0];
  }

  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetShape(Shape(out_dims));
  out_desc.SetDataType(indices_desc.GetDataType());
  return op.UpdateOutputDesc("y", out_desc);
}

INFER_FUNC_REG(UnravelIndex, UnravelIndexInfer);
// ----------------UnravelIndex END-----------------------
}  // namespace ge
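How the two rank checks above translate into the output shape, on assumed inputs:

// Editor's sketch: dims of shape [3] and indices of shape [4] infer y [3, 4]
// (one column per flat index); a scalar indices input drops the second dim,
// so y becomes [3]. Unknown inputs leave the corresponding dim as -1.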
@@ -1,45 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x_desc = op_desc->MutableInputDesc(0);

  GeShape y_shape;
  if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(),
            "Op CholeskyGrad first input x tensor make batch square matrix "
            "failed.");
    return GRAPH_FAILED;
  }

  DataType type = x_desc->GetDataType();
  auto y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetShape(y_shape);
  y_desc->SetDataType(type);

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
}  // namespace ge
@@ -59,6 +59,10 @@ COMMON_INFER_FUNC_REG(Cos, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Expm1, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log1p, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Tanh, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Sin, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Reciprocal, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Sign, OneInOneOutCommonInferShape);
// ----------------------------------OneInOneOutCommonInfer END-----------------------------

// ----------------------------------TwoInOneOutCommonInfer-----------------------------
@@ -68,6 +72,9 @@ CUST_COMMON_INFER_FUNC_REG(Gcd, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Heaviside, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Hypot, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Lcm, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Pow, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Xlogy, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Xdivy, TwoInOneOutCommonInferShape);
// ----------------------------------TwoInOneOutCommonInfer END-----------------------------

// --------------AcosGrad----------------
@@ -399,4 +406,33 @@ IMPLEMT_VERIFIER(FloorDiv, FloorDivVerify) {
COMMON_INFER_FUNC_REG(FloorDiv, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(FloorDiv, FloorDivVerify);
// ----------------FloorDiv END------------------------

// ----------------SqrtGrad Op Begin-----------------
IMPLEMT_VERIFIER(SqrtGrad, SqrtGradVerify) {
  if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(SqrtGradInferShape) {
  Shape shape_x = op.GetInputDescByName("y").GetShape();
  DataType input_dtype = op.GetInputDescByName("y").GetDataType();
  TensorDesc tensordesc_output = op.GetOutputDescByName("z");
  std::vector<std::pair<int64_t, int64_t>> shape_range_x;
  op.GetInputDescByName("y").GetShapeRange(shape_range_x);
  tensordesc_output.SetShape(shape_x);
  tensordesc_output.SetDataType(input_dtype);
  tensordesc_output.SetShapeRange(shape_range_x);
  if (op.UpdateOutputDesc("z", tensordesc_output) != GRAPH_SUCCESS) {
    std::string err_msg = UpdateParamErrMsg("z");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(SqrtGrad, SqrtGradInferShape);
VERIFY_FUNC_REG(SqrtGrad, SqrtGradVerify);
// ----------------SqrtGrad Op End-------------------
}  // namespace ge
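For context, the shape pass-through above matches the usual SqrtGrad definition (editor's note, not stated in the commit):

// With y = sqrt(x), the backward rule is z = dy * 0.5 / y, an elementwise op,
// so z necessarily carries y's shape, dtype, and shape range, which is exactly
// what SqrtGradInferShape propagates.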
@@ -14,9 +14,11 @@
 * limitations under the License.
 */

#include <graph/utils/type_utils.h>
#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/image_ops_shape_fns.h"
namespace ge {
// ----------------AdjustHue Start-------------------
@@ -138,5 +140,425 @@ IMPLEMT_INFERFUNC(ExtractGlimpse, ExtractGlimpseInfer) {
}

INFER_FUNC_REG(ExtractGlimpse, ExtractGlimpseInfer);
// ----------------ExtractGlimpse-------------------
// ----------------ExtractGlimpse END-------------------

// ----------------ResizeArea-------------------
IMPLEMT_INFERFUNC(ResizeArea, ResizeAreaInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");
  desc.SetDataType(DT_FLOAT);
  if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return ResizeShapeFn(op, "images", "size", "y");
}

INFER_FUNC_REG(ResizeArea, ResizeAreaInfer);
// ----------------ResizeArea END-------------------

// ----------------ResizeBicubic-------------------
IMPLEMT_INFERFUNC(ResizeBicubic, ResizeBicubicInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");

  DataType data_type = DT_FLOAT;
  // Update DataType when Attr "dtype" is set
  if (op.GetAttr("dtype", data_type) == GRAPH_SUCCESS) {
    CHECK(((data_type != DT_FLOAT) && (data_type != DT_UINT8)),
          OP_LOGE(TbeGetName(op), "Attr dtype should only be DT_FLOAT or DT_UINT8"), return GRAPH_FAILED);
    OP_LOGI(TbeGetName(op), "Update Bicubic DataType from attr, which is %d", data_type);
  }

  desc.SetDataType(data_type);
  if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return ResizeShapeFn(op, "images", "size", "y");
}

INFER_FUNC_REG(ResizeBicubic, ResizeBicubicInfer);
// ----------------ResizeBicubic END-------------------

bool ResizeConstInferShape(const Operator &op, const std::pair<uint32_t, std::string> image_info,
                           const std::pair<uint32_t, std::string> size_info,
                           const std::pair<uint32_t, std::string> output_info) {
  static const size_t output_len = 4;
  static const size_t size_len = 2;
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK(op_desc == nullptr, OP_LOGE(TbeGetName(op), "op desc is null."), return false);

  auto input_desc_x = op_desc->MutableInputDesc(image_info.first);
  CHECK(input_desc_x == nullptr,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("input x is null.")), return false);
  auto output_desc_y = op_desc->MutableOutputDesc(output_info.first);
  CHECK(output_desc_y == nullptr,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("output y is null.")), return false);

  // infer dtype start
  output_desc_y->SetDataType(input_desc_x->GetDataType());
  // infer dtype end

  // infer shape start
  const GeShape &x_shape = input_desc_x->MutableShape();
  auto input_format = input_desc_x->GetFormat();
  OP_LOGD(TbeGetName(op), "get the format is %s", ge::TypeUtils::FormatToSerialString(input_format).c_str());
  CHECK(input_format != FORMAT_NHWC && input_format != FORMAT_NCHW,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("The input format is invalid")),
        return false);
  const int64_t image_n_idx = 0;
  // format is NHWC, c_idx = 3; format is NCHW, c_idx = 1
  const int64_t image_c_idx = input_format == FORMAT_NHWC ? 3 : 1;
  const int64_t image_h_idx = input_format == FORMAT_NHWC ? 1 : 2;
  const int64_t image_w_idx = input_format == FORMAT_NHWC ? 2 : 3;
  // get const value
  bool is_size_const = true;
  vector<int64_t> size_out;
  if (!ops::GetConstIntData(op, size_info.first, size_out)) {
    OP_LOGW(TbeGetName(op).c_str(), "get const value of input size failed, set out hw = -1, -1");
    size_out = {-1, -1};
    is_size_const = false;
  }

  // the size num must be 2, meaning output h and output w
  OP_LOGD(TbeGetName(op), "the size num must be 2. get the num is %zu", size_out.size());
  CHECK(size_out.size() != size_len,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("the input size num must be 2.")),
        return false);

  // get y shape
  GeShape &y_shape = output_desc_y->MutableShape();
  y_shape.SetDimNum(output_len);
  if (!x_shape.IsUnknownDimNum()) {
    OP_LOGD(TbeGetName(op), "the input shape size must be 4. get shape size is %zu", x_shape.GetDimNum());
    CHECK(x_shape.GetDimNum() != output_len,
          VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("The dim of input x is not 4")),
          return false);
    y_shape.SetDim(image_n_idx, x_shape.GetDim(image_n_idx));
    y_shape.SetDim(image_c_idx, x_shape.GetDim(image_c_idx));
  } else {
    OP_LOGW(TbeGetName(op).c_str(), "the input is unknown rank, will set the out nc = -1, -1");
    y_shape.SetDim(image_n_idx, -1);
    y_shape.SetDim(image_c_idx, -1);
  }
  y_shape.SetDim(image_h_idx, size_out[0]);
  y_shape.SetDim(image_w_idx, size_out[1]);
  // infer shape end

  // check whether the output is dynamic; when the output is a static shape, return true
  CHECK(!y_shape.IsUnknownShape(), OP_LOGD(TbeGetName(op), "the output is static shape. infer succ"), return true);

  OP_LOGD(TbeGetName(op), "the output is dynamic shape. will infer range");
  // infer shape_range start
  std::vector<std::pair<int64_t, int64_t>> x_range;
  vector<int64_t> image_shape{-1, -1, -1, -1};
  // check whether this is the -2 (unknown rank) case
  if (!x_shape.IsUnknownDimNum()) {
    image_shape = x_shape.GetDims();
    (void)input_desc_x->GetShapeRange(x_range);
  }
  MakeUpShapeRange(image_shape, x_range);
  OP_LOGD(TbeGetName(op), "the input range size must be 4. get size is %zu", x_range.size());
  CHECK(x_range.size() != output_len,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("the x range size is not equal 4")),
        return false);
  if (!is_size_const) {
    std::vector<std::pair<int64_t, int64_t>> size_value_range;
    auto input_size_x = op_desc->MutableInputDesc(size_info.first);
    CHECK(input_size_x == nullptr,
          VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), OtherErrMsg("input size is null.")),
          return false);
    // means no const value, will get the value range
    (void)input_size_x->GetValueRange(size_value_range);
    // the size num must be 2, so the value range num must be 2
    if (size_value_range.size() != size_len) {
      x_range[image_h_idx] = std::pair<int64_t, int64_t>(0, -1);
      x_range[image_w_idx] = std::pair<int64_t, int64_t>(0, -1);
    } else {
      x_range[image_h_idx] = size_value_range[0];
      x_range[image_w_idx] = size_value_range[1];
    }
  } else {
    x_range[image_h_idx] = std::pair<int64_t, int64_t>(size_out[0], size_out[0]);
    x_range[image_w_idx] = std::pair<int64_t, int64_t>(size_out[1], size_out[1]);
  }

  output_desc_y->SetShapeRange(x_range);
  // infer shape_range end
  return true;
}
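A worked example of ResizeConstInferShape, on assumed shapes (editor's sketch):

// x NHWC [2, 10, 10, 3] with const size {20, 30} infers y [2, 20, 30, 3] and
// a fully static range; if size is not const, H/W become -1 and the range
// falls back to the size input's value range (or [0, -1] when that is absent).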
// ---------------ResizeNearestNeighborV2 Op Start-------------------
IMPLEMT_COMMON_INFERFUNC(ResizeNearestNeighborV2InferShape) {
  static const std::pair<uint32_t, std::string> input_x{0, "x"};
  static const std::pair<uint32_t, std::string> input_size{1, "size"};
  static const std::pair<uint32_t, std::string> output_y{0, "y"};
  const vector<string> depends{input_size.second};
  PREPARE_DYNAMIC_SHAPE(depends);
  if (!ResizeConstInferShape(op, input_x, input_size, output_y)) {
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ResizeNearestNeighborV2, ResizeNearestNeighborV2InferShape);
INFER_VALUE_RANGE_DEFAULT_REG(ResizeNearestNeighborV2);
// ---------------ResizeNearestNeighborV2 Op End-------------------

// ----------------ResizeNearestNeighborV2Grad-------------------
IMPLEMT_INFERFUNC(ResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto y_desc = op_desc->MutableOutputDesc(0);
  auto size_desc = op_desc->MutableInputDesc(1);
  auto grads_desc = op_desc->MutableInputDesc(0);
  if (op.GetInputDesc(0).GetShape().GetDims() == UNKNOWN_RANK ||
      op.GetInputDesc(1).GetShape().GetDims() == UNKNOWN_RANK) {
    y_desc->SetShape(GeShape(UNKNOWN_RANK));
    y_desc->SetDataType(grads_desc->GetDataType());
    return GRAPH_SUCCESS;
  }
  // unknown shape support
  std::vector<std::string> input_infer_depends = {"size"};
  op_desc->SetOpInferDepends(input_infer_depends);

  GeShape grads_shape;
  if (WithRank(grads_desc, 4, grads_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(op_desc->GetName().c_str(), "Input grads must be 4-D, real rank is [%lu]",
            grads_desc->GetShape().GetDimNum());
    return GRAPH_PARAM_INVALID;
  }

  GeShape size_shape;
  if (WithRank(size_desc, 1, size_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(op_desc->GetName().c_str(), "Input size must be 1-D, real rank is [%lu]",
            size_desc->GetShape().GetDimNum());
    return GRAPH_PARAM_INVALID;
  }

  auto size_dims = size_shape.GetDims();
  if (size_dims[0] != 2 && size_dims[0] != UNKNOWN_DIM) {
    OP_LOGE(op_desc->GetName().c_str(), "Input size must be 1-D of 2 elements, real dim size is [%ld]", size_dims[0]);
    return GRAPH_PARAM_INVALID;
  }

  int64_t size_height = UNKNOWN_DIM;
  int64_t size_width = UNKNOWN_DIM;
  Tensor size_tensor;
  if (op.GetInputConstData("size", size_tensor) == GRAPH_SUCCESS) {
    auto size_data = reinterpret_cast<const int32_t *>(size_tensor.GetData());
    if (size_data == nullptr) {
      OP_LOGE(op_desc->GetName().c_str(), "Get size data failed");
      return GRAPH_PARAM_INVALID;
    }
    size_height = static_cast<int64_t>(size_data[0]);
    size_width = static_cast<int64_t>(size_data[1]);
  }

  std::vector<int64_t> output_dims;
  auto grads_dims = grads_shape.GetDims();
  auto input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(grads_desc->GetFormat()));
  if (input_format == FORMAT_NCHW) {
    output_dims.push_back(grads_dims[0]);
    output_dims.push_back(grads_dims[1]);
    output_dims.push_back(size_height);
    output_dims.push_back(size_width);
  } else if (input_format == FORMAT_NHWC) {
    output_dims.push_back(grads_dims[0]);
    output_dims.push_back(size_height);
    output_dims.push_back(size_width);
    output_dims.push_back(grads_dims[3]);
  } else {
    OP_LOGE(op_desc->GetName().c_str(), "Not supported this format: [%d]", input_format);
    return GRAPH_PARAM_INVALID;
  }
  GeShape output_shape(output_dims);
  if (ShapeFullyDefined(output_shape) == false) {
    std::vector<std::pair<int64_t, int64_t>> output_range;
    for (const int64_t &output_dim : output_dims) {
      output_range.push_back(output_dim == UNKNOWN_DIM ? std::pair<int64_t, int64_t>{1, -1}
                                                       : std::pair<int64_t, int64_t>{output_dim, output_dim});
    }
    y_desc->SetShapeRange(output_range);
  }
  y_desc->SetDataType(grads_desc->GetDataType());
  y_desc->SetShape(output_shape);

  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(ResizeNearestNeighborV2Grad, ResizeNearestNeighborV2GradInfer);
// ----------------ResizeNearestNeighborV2Grad END-------------------
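An assumed example of the grad path (editor's sketch):

// grads NCHW [2, 3, 10, 10] with const size {5, 5} infers y [2, 3, 5, 5];
// without const size, H/W stay UNKNOWN_DIM and receive the range {1, -1}
// from the ShapeFullyDefined branch above.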
// ----------------RGBToHSV-------------------
IMPLEMT_INFERFUNC(RGBToHSV, RGBToHSVInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");
  desc.SetDataType(op.GetInputDesc(0).GetDataType());
  if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed"));
    return GRAPH_FAILED;
  }
  return ColorspaceShapeFn(op, "y");
}

INFER_FUNC_REG(RGBToHSV, RGBToHSVInfer);
// ----------------RGBToHSV END-------------------

// ----------------NonMaxSuppressionWithOverlaps-------------------
IMPLEMT_INFERFUNC(NonMaxSuppressionWithOverlaps, NonMaxSuppressionWithOverlapsInfer) {
  Shape overlaps_shape = op.GetInputDescByName("overlaps").GetShape();
  Shape scores_shape = op.GetInputDescByName("scores").GetShape();
  Shape max_output_size_shape = op.GetInputDescByName("max_output_size").GetShape();
  Shape overlap_threshold_shape = op.GetInputDescByName("overlap_threshold").GetShape();
  Shape score_threshold_shape = op.GetInputDescByName("score_threshold").GetShape();
  if (WithRank(op.GetInputDescByName("overlaps"), 2, overlaps_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRank function, ", "input[overlaps] rank must be 2, but got rank[",
                   op.GetInputDescByName("overlaps").GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (WithRank(op.GetInputDescByName("scores"), 1, scores_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRank function, ", "input[scores] rank must be 1, but got rank[",
                   op.GetInputDescByName("scores").GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (WithRank(op.GetInputDescByName("max_output_size"), 0, max_output_size_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRank function, ", "input[max_output_size] rank must be 0, but got rank[",
                   op.GetInputDescByName("max_output_size").GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (WithRank(op.GetInputDescByName("overlap_threshold"), 0, overlap_threshold_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRank function, ", "input[overlap_threshold] rank must be 0, but got rank[",
                   op.GetInputDescByName("overlap_threshold").GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (WithRank(op.GetInputDescByName("score_threshold"), 0, score_threshold_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRank function, ", "input[score_threshold] rank must be 0, but got rank[",
                   op.GetInputDescByName("score_threshold").GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  int64_t unused_dim = 0;
  if (Merge(overlaps_shape.GetDim(0), scores_shape.GetDim(0), unused_dim) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call Merge function to merge the input[overlaps] 0th dim", "[", overlaps_shape.GetDim(0),
                   "] and the input[scores]'s 0th dim [", scores_shape.GetDim(0), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (Merge(overlaps_shape.GetDim(0), overlaps_shape.GetDim(1), unused_dim) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call Merge function to merge the input[overlaps] 0th dim", "[", overlaps_shape.GetDim(0),
                   "] and the input[overlaps]'s 1th dim [", overlaps_shape.GetDim(1), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  TensorDesc selected_indices_desc = op.GetOutputDescByName("selected_indices");
  Shape selecte_indices_shape;
  Vector(ge::UNKNOWN_DIM, selecte_indices_shape);
  selected_indices_desc.SetDataType(DT_INT32);
  selected_indices_desc.SetShape(selecte_indices_shape);
  if (op.UpdateOutputDesc("selected_indices", selected_indices_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[selected_indices] desc failed"));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(NonMaxSuppressionWithOverlaps, NonMaxSuppressionWithOverlapsInfer);
// ----------------NonMaxSuppressionWithOverlaps END-------------------
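What the two Merge calls above encode, with an assumed example:

// overlaps must be square [N, N] and scores must be [N] for the same N;
// e.g. overlaps [5, 5] with scores [5] passes, while overlaps [5, 4]
// fails the second Merge. (Editor's example, derived from the checks above.)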
// ----------------ScaleAndTranslate-------------------
IMPLEMT_INFERFUNC(ScaleAndTranslate, ScaleAndTranslateInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");
  desc.SetDataType(DT_FLOAT);
  if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update description for output[y] failed."));
    return GRAPH_FAILED;
  }
  return ResizeShapeFn(op, "images", "size", "y");
}

INFER_FUNC_REG(ScaleAndTranslate, ScaleAndTranslateInfer);
// ----------------ScaleAndTranslate END-------------------

// ----------------ScaleAndTranslateGrad-------------------
IMPLEMT_INFERFUNC(ScaleAndTranslateGrad, ScaleAndTranslateGradInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");
  Format input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDesc(0).GetFormat()));
  vector<int64_t> grads_shape = op.GetInputDesc(0).GetShape().GetDims();
  vector<int64_t> org_images_shape = op.GetInputDesc(1).GetShape().GetDims();
  vector<int64_t> y_shape;
  if (input_format == FORMAT_NHWC && grads_shape.size() > 3 && org_images_shape.size() > 2) {
    y_shape.push_back(grads_shape[0]);
    y_shape.push_back(org_images_shape[1]);
    y_shape.push_back(org_images_shape[2]);
    y_shape.push_back(grads_shape[3]);
  } else if (input_format == FORMAT_NCHW && grads_shape.size() > 1 && org_images_shape.size() > 3) {
    y_shape.push_back(grads_shape[0]);
    y_shape.push_back(grads_shape[1]);
    y_shape.push_back(org_images_shape[2]);
    y_shape.push_back(org_images_shape[3]);
  } else {
    if (grads_shape.size() < 4) {
      std::string err_msg =
        ConcatString("the 0th input[grads]'s rank should not be less than 4, ", "current rank is ", grads_shape.size());
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
      return GRAPH_FAILED;
    }
    if (org_images_shape.size() < 2) {
      std::string err_msg = ConcatString("the 1th input[original_images]'s rank should not be less than 2, ",
                                         "current rank is ", org_images_shape.size());
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
      return GRAPH_FAILED;
    }
    y_shape.push_back(grads_shape[0]);
    y_shape.push_back(org_images_shape[1]);
    y_shape.push_back(org_images_shape[2]);
    y_shape.push_back(grads_shape[3]);
    OP_LOGI(TbeGetName(op).c_str(), "Real format is %d", input_format);
  }

  desc.SetShape(ge::Shape(y_shape));
  desc.SetDataType(DT_FLOAT);
  return op.UpdateOutputDesc("y", desc);
}

INFER_FUNC_REG(ScaleAndTranslateGrad, ScaleAndTranslateGradInfer);
// ----------------ScaleAndTranslateGrad END-------------------

// ----------------ResizeBicubicGrad-------------------
IMPLEMT_INFERFUNC(ResizeBicubicGrad, ResizeBicubicGradInfer) {
  TensorDesc desc = op.GetOutputDescByName("y");
  Format input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDesc(0).GetFormat()));
  vector<int64_t> grads_shape = op.GetInputDesc(0).GetShape().GetDims();
  vector<int64_t> org_images_shape = op.GetInputDesc(1).GetShape().GetDims();
  vector<int64_t> y_shape;
  if (input_format == FORMAT_NHWC && grads_shape.size() > 3 && org_images_shape.size() > 2) {
    y_shape.push_back(grads_shape[0]);
    y_shape.push_back(org_images_shape[1]);
    y_shape.push_back(org_images_shape[2]);
    y_shape.push_back(grads_shape[3]);
  } else if (input_format == FORMAT_NCHW && grads_shape.size() > 1 && org_images_shape.size() > 3) {
    y_shape.push_back(grads_shape[0]);
    y_shape.push_back(grads_shape[1]);
    y_shape.push_back(org_images_shape[2]);
    y_shape.push_back(org_images_shape[3]);
  } else {
    std::string str_input_format = ge::TypeUtils::FormatToSerialString(input_format);
    std::string err_msg = ConcatString("only supporting NCHW and NHWC, current format is [", str_input_format, "]");
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
  }
  desc.SetShape(ge::Shape(y_shape));
  auto type = op.GetInputDesc(1).GetDataType();
  desc.SetDataType(type);
  return op.UpdateOutputDesc("y", desc);
}

INFER_FUNC_REG(ResizeBicubicGrad, ResizeBicubicGradInfer);
// ----------------ResizeBicubicGrad END-------------------
}  // namespace ge
@@ -1,44 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUSTOMIZE_OP_PROTO_INC_TRACE_GRAD_OP_H
#define CUSTOMIZE_OP_PROTO_INC_TRACE_GRAD_OP_H

#include "op_proto_macro.h"

namespace ge {
/**
 * @brief Computes the grad of x1 in trace. \n

 * @par Inputs:
 * Two inputs, including:
 * @li y_grad: A tensor. \n
 * @li x_shape: A tensor. Must be one of the following types:
 * int32, int64. \n

 * @par Outputs:
 * x_grad: A Tensor with the same type and shape as y_grad. \n

 * @par Third-party framework compatibility
 * Compatible with the Pytorch operator Trace Backward. \n
 */
REG_CUST_OP(TraceGrad)
  .INPUT(y_grad, TensorType::BasicType())
  .INPUT(x_shape, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(x_grad, TensorType::BasicType())
  .CUST_OP_END_FACTORY_REG(TraceGrad)
}  // namespace ge
#endif
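As a reminder of the math this proto serves (editor's note, not in the commit): since trace(X) is the sum of the diagonal entries of X, the backward rule scatters y_grad onto the diagonal:

// x_grad[i][j] = y_grad if i == j, else 0; that is, x_grad = y_grad * I,
// materialized at the shape given by the x_shape input, which is why
// x_shape (int32/int64) is a required input of TraceGrad.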
@@ -404,4 +404,92 @@ CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------

// -----------------------Qr---------------------------------
IMPLEMT_INFERFUNC(Qr, QrInfer) {
  auto tensor = op.get_input_desc_x();
  Shape input;
  if (WithRankAtLeast(tensor, 2, input, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  Shape batch_shape;
  if (SubShape(input, 0, -2, 1, batch_shape, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  int dim_num = input.GetDimNum();
  int m = input.GetDim(dim_num - 2);
  int n = input.GetDim(dim_num - 1);
  Shape q_shape;
  Shape r_shape;
  auto full_matrices = op.get_attr_full_matrices();

  if (full_matrices) {
    // [..., M, M]; [..., M, N], if full_matrices is true
    Shape m_m_shape;
    Shape m_n_shape;
    Matrix(m, m, m_m_shape);
    Matrix(m, n, m_n_shape);

    Concatenate(batch_shape, m_m_shape, q_shape);
    Concatenate(batch_shape, m_n_shape, r_shape);
  } else {
    // [..., M, P]; [..., P, N], if full_matrices is false
    int p = m > n ? n : m;
    Shape m_p_shape;
    Shape p_n_shape;
    Matrix(m, p, m_p_shape);
    Matrix(p, n, p_n_shape);

    Concatenate(batch_shape, m_p_shape, q_shape);
    Concatenate(batch_shape, p_n_shape, r_shape);
  }

  DataType type = op.GetInputDescByName("x").GetDataType();
  TensorDesc q_desc = op.GetOutputDescByName("q");
  q_desc.SetShape(Shape(q_shape));
  q_desc.SetDataType(type);
  if (op.UpdateOutputDesc("q", q_desc) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Update q desc failed.");
    return GRAPH_FAILED;
  }

  TensorDesc r_desc = op.GetOutputDescByName("r");
  r_desc.SetShape(Shape(r_shape));
  r_desc.SetDataType(type);
  if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Qr, QrInfer);
// -----------------------Qr END---------------------------------
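The Qr shape algebra on an assumed input (editor's sketch):

// x [7, 5, 3] has batch_shape [7], m = 5, n = 3. With full_matrices == true,
// q is [7, 5, 5] and r is [7, 5, 3]; with full_matrices == false,
// p = min(m, n) = 3, so q is [7, 5, 3] and r is [7, 3, 3].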

// -----------------------CholeskyGrad---------------------------------
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x_desc = op_desc->MutableInputDesc(0);

  GeShape y_shape;
  if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(),
            "Op CholeskyGrad first input x tensor make batch square matrix "
            "failed.");
    return GRAPH_FAILED;
  }

  DataType type = x_desc->GetDataType();
  auto y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetShape(y_shape);
  y_desc->SetDataType(type);

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
// -----------------------CholeskyGrad END---------------------------------
}  // namespace ge
@@ -5,9 +5,11 @@
 */

#include "inc/ops/math_ops.h"
#include "inc/ops/ragged_math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/reduce_infer_util.h"

namespace ge {
// ----------------ComplexAbs-------------------
@@ -212,4 +214,113 @@ IMPLEMT_INFERFUNC(IsInf, IsInfInfer) {

INFER_FUNC_REG(IsInf, IsInfInfer);
// ----------------IsInf END------------------------

// ----------------ReduceOp-------------------
static bool InferReduceShapeProcess(const Operator &op, const int64_t input_x_idx, const int64_t output_y_idx,
                                    const int64_t input_axes_idx) {
  bool keep_dims = false;
  op.GetAttr("keep_dims", keep_dims);
  reduce_ops::CommonReduceInferWithInputAxes(op, input_x_idx, output_y_idx, input_axes_idx, keep_dims);
  return true;
}

IMPLEMT_COMMON_INFERFUNC(TypicalReduceInferShape) {
  OP_LOGD(TbeGetName(op), "Enter %s InferShape", TbeGetOpType(op).c_str());
  const int64_t input_x_idx = 0;
  const int64_t output_y_idx = 0;
  const int64_t input_axes_idx = 1;
  if (InferReduceShapeProcess(op, input_x_idx, output_y_idx, input_axes_idx)) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}

COMMON_INFER_FUNC_REG(ReduceSum, TypicalReduceInferShape);
// ----------------ReduceOp END-------------------
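How keep_dims plays out for ReduceSum, on assumed shapes (editor's sketch):

// x [2, 3, 4] with const axes {1}: keep_dims == false infers y [2, 4];
// keep_dims == true infers y [2, 1, 4]. With a non-const axes input the
// common reduce helper has to fall back to dynamic dims and ranges.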
|
||||
// ----------------RaggedRange-------------------
|
||||
IMPLEMT_INFERFUNC(RaggedRange, RaggedRangeInfer) {
|
||||
Shape starts;
|
||||
Shape limits;
|
||||
Shape deltas;
|
||||
if (WithRankAtMost(op.GetInputDesc(0), 1, starts, op) != GRAPH_SUCCESS) {
|
||||
std::string err_msg =
|
||||
ConcatString("failed to call WithRankAtMost function, ", "input[starts] rank must be at most 1D, got rank[",
|
||||
op.GetInputDesc(0).GetShape().GetDimNum(), "]");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (WithRankAtMost(op.GetInputDesc(1), 1, limits, op) != GRAPH_SUCCESS) {
|
||||
std::string err_msg =
|
||||
ConcatString("failed to call WithRankAtMost function, ", "input[limits] rank must be at most 1D, got rank[",
|
||||
op.GetInputDesc(1).GetShape().GetDimNum(), "]");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (WithRankAtMost(op.GetInputDesc(2), 1, deltas, op) != GRAPH_SUCCESS) {
|
||||
std::string err_msg =
|
||||
ConcatString("failed to call WithRankAtMost function, input[deltas] ", "rank must be at most 1D, got rank[",
|
||||
op.GetInputDesc(2).GetShape().GetDimNum(), "]");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int64_t dim = ge::UNKNOWN_DIM;
|
||||
int64_t starts_dim = starts.GetDim(0);
|
||||
int64_t limits_dim = limits.GetDim(0);
|
||||
int64_t deltas_dim = deltas.GetDim(0);
|
||||
if (op.GetInputDesc(0).GetShape().GetDimNum() == 1) {
|
||||
if (Merge(starts_dim, dim, dim) != GRAPH_SUCCESS) {
|
||||
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", starts_dim,
|
||||
"] of input[starts] not equal UNKNOWN_DIM");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
if (op.GetInputDesc(1).GetShape().GetDimNum() == 1) {
|
||||
if (Merge(limits_dim, dim, dim) != GRAPH_SUCCESS) {
|
||||
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", limits_dim,
|
||||
"] of input[limits] not equal UNKNOWN_DIM");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
if (op.GetInputDesc(2).GetShape().GetDimNum() == 1) {
|
||||
if (Merge(deltas_dim, dim, dim) != GRAPH_SUCCESS) {
|
||||
std::string err_msg = ConcatString("failed to call Merge function, the 0th dim[", deltas_dim,
|
||||
"] of input[deltas] not equal UNKNOWN_DIM");
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t rt_nested_splits_dim = ge::UNKNOWN_DIM;
|
||||
if (dim != ge::UNKNOWN_DIM) {
|
||||
rt_nested_splits_dim = dim + 1;
|
||||
} else if (op.GetInputDesc(0).GetShape().GetDimNum() == 0 && op.GetInputDesc(1).GetShape().GetDimNum() == 0 &&
|
||||
op.GetInputDesc(2).GetShape().GetDimNum() == 0) {
|
||||
rt_nested_splits_dim = 2;
|
||||
}
|
||||
|
||||
DataType Tsplits_type;
|
||||
if (op.GetAttr("Tsplits", Tsplits_type) != GRAPH_SUCCESS) {
|
||||
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[Tsplits] failed"));
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
TensorDesc rt_nested_desc = op.GetOutputDescByName("rt_nested_splits");
|
||||
rt_nested_desc.SetShape(Shape({rt_nested_splits_dim}));
|
||||
rt_nested_desc.SetDataType(Tsplits_type);
|
||||
(void)op.UpdateOutputDesc("rt_nested_splits", rt_nested_desc);
|
||||
|
||||
DataType T_type = op.GetInputDescByName("starts").GetDataType();
|
||||
std::vector<int64_t> unknow_dim_vec(1, UNKNOWN_DIM);
|
||||
TensorDesc dense_desc = op.GetOutputDescByName("rt_dense_values");
|
||||
dense_desc.SetShape(Shape(unknow_dim_vec));
|
||||
dense_desc.SetDataType(T_type);
|
||||
(void)op.UpdateOutputDesc("rt_dense_values", dense_desc);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(RaggedRange, RaggedRangeInfer);
|
||||
// ----------------RaggedRange END-------------------
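// Shape sketch (illustrative): for 1-D starts/limits/deltas of common length 3,
// the Merge calls resolve dim=3, so rt_nested_splits is inferred as [dim + 1]
// = [4], while rt_dense_values stays data-dependent and is left as [-1].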
}  // namespace ge
@ -159,4 +159,162 @@ CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixExp-------------------
CUST_COMMON_INFER_FUNC_REG(MatrixExp, OneInOneOutCommonInferShape);
// ----------------MatrixExp END-------------------

// ----------------TraceGrad Begin------------------------
IMPLEMT_COMMON_INFERFUNC(TraceGradInferShape) {
  Shape shape = op.GetInputDescByName("y_grad").GetShape();
  DataType input_dtype = op.GetInputDescByName("y_grad").GetDataType();
  std::vector<std::string> input_infer_depends = {"x_shape"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);
  Tensor tensor_input;
  Shape output_shape;
  if (op.GetInputConstData("x_shape", tensor_input) == GRAPH_SUCCESS) {
    MakeShapeFromShapeTensor(tensor_input, output_shape, op);
  } else {
    output_shape = Shape({UNKNOWN_RANK});
  }
  TensorDesc td = op.GetOutputDescByName("x_grad");
  td.SetShape(output_shape);
  td.SetDataType(input_dtype);
  td.SetFormat(FORMAT_ND);
  (void)op.UpdateOutputDesc("x_grad", td);
  return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(TraceGrad, TraceGradVerify) {
  DataType x_shape_dtype = op.GetInputDescByName("x_shape").GetDataType();
  if ((x_shape_dtype != DT_INT32) && (x_shape_dtype != DT_INT64)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(TraceGrad, TraceGradInferShape);
CUST_VERIFY_FUNC_REG(TraceGrad, TraceGradVerify);
// ---------------TraceGrad END-------------------------------

// ---------------Tril--------------
IMPLEMT_COMMON_INFERFUNC(TrilInferShape) {
  Shape input_shape = op.GetInputDesc(0).GetShape();
  DataType input_dtype = op.GetInputDesc(0).GetDataType();
  TensorDesc td = op.GetOutputDesc(0);
  td.SetShape(ge::Shape(input_shape));
  td.SetDataType(input_dtype);
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

IMPLEMT_VERIFIER(Tril, TrilVerify) { return GRAPH_SUCCESS; }

INFER_FUNC_REG(Tril, TrilInferShape);
VERIFY_FUNC_REG(Tril, TrilVerify);
// ----------------Tril END----------------

// -----------------ScatterNdUpdate-----------------
IMPLEMT_VERIFIER(ScatterNdUpdate, ScatterNdUpdateVerify) {
  if (!CheckTwoInputDtypeSame(op, "var", "updates")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(ScatterNdUpdateInferShape) {
  // main part of shape infer
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  ge::GeShape var_shape = op_desc->MutableInputDesc("var")->GetShape();
  std::vector<std::pair<int64_t, int64_t>> var_shape_range;
  op_desc->MutableInputDesc("var")->GetShapeRange(var_shape_range);
  DataType input_dtype = op_desc->MutableInputDesc("var")->GetDataType();
  GeTensorDescPtr td = op_desc->MutableOutputDesc("var");
  td->SetShape(var_shape);
  td->SetDataType(input_dtype);
  td->SetShapeRange(var_shape_range);
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ScatterNdUpdate, ScatterNdUpdateInferShape);
VERIFY_FUNC_REG(ScatterNdUpdate, ScatterNdUpdateVerify);
// -------------------ScatterNdUpdate END----------------

// -----------------TensorScatterUpdate-----------------
IMPLEMT_VERIFIER(TensorScatterUpdate, TensorScatterUpdateVerify) {
  if (!CheckTwoInputDtypeSame(op, "x", "updates")) {
    return GRAPH_FAILED;
  }
  DataType indices_dtype = op.GetInputDescByName("indices").GetDataType();
  if ((indices_dtype != DT_INT32) && (indices_dtype != DT_INT64)) {
    OP_LOGE("tensor_scatter_update", "The indices type is not int32 or int64, please check!");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_INFERFUNC(TensorScatterUpdate, TensorScatterUpdateInferShape) {
  Shape var_shape = op.GetInputDescByName("x").GetShape();
  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(ge::Shape(var_shape));
  td.SetDataType(input_dtype);
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(TensorScatterUpdate, TensorScatterUpdateInferShape);
VERIFY_FUNC_REG(TensorScatterUpdate, TensorScatterUpdateVerify);
// -------------------TensorScatterUpdate END----------------
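// Shape sketch (illustrative): both update ops above are shape-preserving on
// the target tensor, e.g. var/x of shape [5, 4] infers the output as [5, 4]
// regardless of the indices/updates shapes.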

// -------------------Orgqr----------------
CUST_COMMON_INFER_FUNC_REG(Orgqr, OneInOneOutCommonInferShape);
// -------------------Orgqr END----------------

// -----------------------Trace-----------------------
static bool InferShapeAndTypeTrace(Operator &op, const std::string &inputName, const std::string outputName) {
  TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());
  DataType inputDtype = op.GetInputDescByName(inputName.c_str()).GetDataType();
  ge::Shape inputShape = op.GetInputDescByName(inputName.c_str()).GetShape();

  // set output tensor dim: Trace always yields a scalar (rank-0) output
  std::vector<int64_t> dimVec;
  ge::Shape outputShape = ge::Shape(dimVec);
  vOutputDesc.SetShape(outputShape);
  const std::vector<DataType> unchange_dtype = {DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT,
                                                DT_FLOAT16, DT_INT64, DT_UINT64};
  if (CheckInputDataType(op, "x", unchange_dtype) == true) {
    vOutputDesc.SetDataType(inputDtype);
  } else {
    vOutputDesc.SetDataType(ge::DT_INT64);
  }
  op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);
  return true;
}

IMPLEMT_VERIFIER(Trace, TraceVerify) {
  AscendString op_name;
  CHECK(op.GetName(op_name) != GRAPH_SUCCESS, OP_LOGE("", "GetName failed."), return GRAPH_FAILED);
  ge::Shape shapeX = op.GetInputDescByName("x").GetShape();
  DataType dtypeX = op.GetInputDescByName("x").GetDataType();
  constexpr int64_t shapeDimsLimit = 2;
  if (shapeX.GetDims() != UNKNOWN_RANK && shapeX.GetDimNum() != shapeDimsLimit) {
    OP_LOGE(op_name.GetString(), "the input shape must be 2-D matrix.\n");
    return GRAPH_FAILED;
  }
  const std::vector<DataType> support_list = {DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16,
                                              DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32,
                                              DT_UINT32, DT_INT64, DT_UINT64};
  if (CheckInputDataType(op, "x", support_list) == false) {
    OP_LOGE(TbeGetName(op).c_str(), "dataType [%s] is not supported in Trace.", DTypeStr(dtypeX).c_str());
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(TraceInferShape) {
  if (InferShapeAndTypeTrace(op, "x", "y")) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}

COMMON_INFER_FUNC_REG(Trace, TraceInferShape);
VERIFY_FUNC_REG(Trace, TraceVerify);
// ---------------------Trace END----------------------
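// Shape sketch (illustrative): Trace reduces a [m, n] matrix to a scalar
// (rank-0) output; dtypes outside the pass-through list are promoted to
// DT_INT64 by the infer function above.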
}  // namespace ge
@ -0,0 +1,58 @@
/**
 * Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ---------------Meshgrid-------------------
IMPLEMT_COMMON_INFERFUNC(MeshgridInfer) {
  auto input_size = op.GetInputsSize();
  if (input_size == 0) {
    return GRAPH_SUCCESS;
  }
  std::vector<int64_t> out_dims = std::vector<int64_t>();
  auto out_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
  for (size_t i = 0; i < op.GetInputsSize(); ++i) {
    auto in_shape = op.GetInputDesc(i).GetShape();
    if (in_shape.GetDims() == UNKNOWN_RANK) {
      out_dims = ge::UNKNOWN_RANK;
      break;
    }
    out_dims.push_back(in_shape.GetDim(0));
  }
  std::string indexing;
  if (op.GetAttr("indexing", indexing) == GRAPH_SUCCESS) {
    // swap the first two dims for "xy" indexing; guard against the
    // unknown-rank and single-input cases where fewer than two dims exist
    if (indexing == "xy" && out_dims.size() > 1) {
      std::swap(out_dims[0], out_dims[1]);
    }
  } else {
    OP_LOGW(TbeGetName(op).c_str(), "get attr indexing failed.");
  }
  Shape out_shape = Shape(out_dims);
  for (size_t i = 0; i < op.GetInputsSize(); ++i) {
    TensorDesc output_desc = op.GetDynamicOutputDesc("y", i);
    output_desc.SetShape(out_shape);
    output_desc.SetDataType(out_dtype);
    op.UpdateDynamicOutputDesc("y", i, output_desc);
  }
  return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(Meshgrid, MeshgridInfer);
// ---------------Meshgrid End---------------
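// Shape sketch (illustrative, assuming TF-style Meshgrid semantics): three
// 1-D inputs of sizes (N1,), (N2,), (N3,) give every output shape
// (N1, N2, N3) under indexing="ij"; indexing="xy" swaps the first two dims,
// giving (N2, N1, N3).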
}  // namespace ge
@ -11,6 +11,309 @@
#include "utils/common_shape_fns.h"

namespace ge {
// ---------------AdaptiveAvgPool2D-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape) {
  OP_LOGI(TbeGetName(op).c_str(), "AdaptiveAvgPool2d inferShape begin!");
  const size_t DIM_SIZE2 = 2;
  auto input_tensor_desc = op.GetInputDescByName("x");
  auto shape = input_tensor_desc.GetShape();
  // get output_size
  std::vector<int64_t> output_size_list;
  if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size_list)) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
    return GRAPH_FAILED;
  }
  // check output size
  if (output_size_list.size() != DIM_SIZE2) {
    OP_LOGE(TbeGetName(op).c_str(), "length of output_size must be 2");
    return GRAPH_FAILED;
  }
  std::vector<int64_t> dims_input = shape.GetDims();
  // set output shape
  std::vector<int64_t> dim_vector;
  for (size_t i = 0; i < dims_input.size(); i++) {
    int64_t dims = dims_input[i];
    dim_vector.push_back(dims);
  }
  size_t index0 = dims_input.size() - 2;
  size_t index1 = dims_input.size() - 1;
  if (output_size_list[0] > 0) {
    dim_vector[index0] = output_size_list[0];
  }
  if (output_size_list[1] > 0) {
    dim_vector[index1] = output_size_list[1];
  }
  TensorDesc td = op.GetOutputDescByName("y");
  DataType input_dtype = input_tensor_desc.GetDataType();
  Shape output_shape(dim_vector);
  td.SetShape(output_shape);
  td.SetDataType(input_dtype);
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify);
// ---------------AdaptiveAvgPool2D End---------------

// ---------------AdaptiveAvgPool2DGrad-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape) {
  std::vector<std::string> input_infer_depends = {"orig_input_shape"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);
  DataType input_dtype = op.GetInputDescByName("input_grad").GetDataType();
  Shape output_shape;
  Tensor orig_input_shape_tensor;
  if (op.GetInputConstData("orig_input_shape", orig_input_shape_tensor) != GRAPH_SUCCESS) {
    auto output_desc = op.GetOutputDescByName("output_grad");
    output_desc.SetDataType(input_dtype);
    output_desc.SetShape(Shape(ge::UNKNOWN_RANK));
    return op.UpdateOutputDesc("output_grad", output_desc);
  }
  MakeShapeFromShapeTensor(orig_input_shape_tensor, output_shape, op);
  TensorDesc output_grad = op.GetOutputDescByName("output_grad");
  output_grad.SetShape(output_shape);
  output_grad.SetDataType(input_dtype);
  return op.UpdateOutputDesc("output_grad", output_grad);
}

CUST_INFER_FUNC_REG(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape);
// ---------------AdaptiveAvgPool2DGrad END-------------------

// --------- AdaptiveAvgPool3d ---------------
IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
  map<int, std::string> format2str = {
    {ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
    {ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};

  // verify the dim of output_size
  std::vector<int64_t> output_size;
  if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
    return GRAPH_PARAM_INVALID;
  }
  ge::AscendString op_name;
  (void)op.GetName(op_name);
  auto input_desc = op.GetInputDescByName("x");
  TensorDesc out_desc = op.GetOutputDescByName("y");

  // update data type
  DataType input_type = input_desc.GetDataType();
  out_desc.SetDataType(input_type);

  // update format
  Format input_format = input_desc.GetFormat();
  std::string format_str = format2str[input_format];
  if (input_format != FORMAT_NCHW) {
    OP_LOGE("AdaptiveAvgPool3d",
            "Input format only support NCHW"
            ", input format is [%s]",
            format_str.c_str());
    return GRAPH_FAILED;
  }
  out_desc.SetFormat(input_format);

  std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
  auto input_size_dim_num = input_size_shape.size();
  std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
  auto output_size_num = output_size.size();
  if (output_size_num == 1) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      if (output_size[0] < 0) {
        continue;
      }
      output_shape[i] = output_size[0];
    }
  } else if (output_size_num == 3) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      auto data = output_size[i - input_size_dim_num + 3];
      if (data < 0) {
        continue;
      }
      output_shape[i] = data;
    }
  } else {
    OP_LOGE("AdaptiveAvgPool3d", "Shape of output_size is invalid");
    return GRAPH_FAILED;
  }

  out_desc.SetShape(Shape(output_shape));
  if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3d", "failed to update output desc");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3d, AdaptiveAvgPool3dInferShape);
// --------- AdaptiveAvgPool3d end---------------
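// Shape sketch (illustrative): NCDHW input [2, 3, 8, 8, 8] with
// output_size={4} infers y as [2, 3, 4, 4, 4]; with output_size={4, -1, 2}
// negative entries keep the input dim, giving [2, 3, 4, 8, 2].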

// --------- AdaptiveAvgPool3dGrad ---------------
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify) {
  auto input_grad_desc = op.GetInputDescByName("input_grad");
  auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
  ge::AscendString op_name;
  (void)op.GetName(op_name);

  auto orig_input_shape_dim = orig_input_shape_desc.GetShape().GetDimNum();
  if (orig_input_shape_dim != 1) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input_shape is invalid");
    return GRAPH_PARAM_INVALID;
  }

  auto orig_input_dim_num = orig_input_shape_desc.GetShape().GetShapeSize();
  auto input_grad_dim_num = input_grad_desc.GetShape().GetDimNum();

  if (orig_input_dim_num != static_cast<int64_t>(input_grad_dim_num)) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "Num Dim of orig_input and input_grad should be the same");
    return GRAPH_PARAM_INVALID;
  }

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dGradInferShape) {
  map<int, std::string> format2str = {
    {ge::FORMAT_NCHW, "NCHW"}, {ge::FORMAT_NHWC, "NHWC"}, {ge::FORMAT_HWCN, "HWCN"}, {ge::FORMAT_DHWNC, "DHWNC"},
    {ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};

  auto input_desc = op.GetInputDescByName("input_grad");
  auto orig_input_shape_desc = op.GetInputDescByName("orig_input_shape");
  TensorDesc out_desc = op.GetOutputDescByName("output_grad");
  ge::AscendString op_name;
  (void)op.GetName(op_name);

  // update format
  Format input_format = input_desc.GetFormat();
  std::string format_str = format2str[input_format];
  if (input_format != FORMAT_NCHW) {
    OP_LOGE("AdaptiveAvgPool3dGrad",
            "Input format only support NCHW"
            ", input format is [%s]",
            format_str.c_str());
    return GRAPH_FAILED;
  }
  out_desc.SetFormat(input_format);

  // update data type
  DataType input_type = input_desc.GetDataType();
  out_desc.SetDataType(input_type);

  // infer shape
  Tensor orig_input_size_tensor;
  if (op.GetInputConstData("orig_input_shape", orig_input_size_tensor) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "failed to get tensor from output_size");
    return GRAPH_FAILED;
  }

  int32_t *orig_input_size_data = reinterpret_cast<int32_t *>(orig_input_size_tensor.GetData());
  if (orig_input_size_data == nullptr) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "output_size data is invalid");
    return GRAPH_PARAM_INVALID;
  }

  auto input_size_dim_num = input_desc.GetShape().GetDimNum();
  std::vector<int64_t> output_shape(input_size_dim_num);

  for (uint64_t i = 0; i < input_size_dim_num; ++i) {
    output_shape[i] = orig_input_size_data[i];
  }

  out_desc.SetShape(Shape(output_shape));
  if (op.UpdateOutputDesc("output_grad", out_desc) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3dGrad", "failed to update output desc");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool3dGrad, AdaptiveAvgPool3dGradVerify);
// --------- AdaptiveAvgPool3dGrad end---------------

// --------- AdaptiveMaxPool3d---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape) {
  TensorDesc input = op.GetInputDesc(0);
  TensorDesc output_size = op.GetInputDesc(1);
  TensorDesc output = op.GetOutputDesc(0);
  TensorDesc argmax = op.GetOutputDesc(1);

  const size_t input_num_dims = input.GetShape().GetDimNum();
  const std::vector<int64_t> output_size_shape = output_size.GetShape().GetDims();
  if ((input_num_dims == 4 || input_num_dims == 5) == false) {
    OP_LOGE(TbeGetName(op), "Input dimensions must be equal to 4 or 5.");
    return GRAPH_FAILED;
  }
  if (output_size_shape.size() != 1) {
    OP_LOGE(TbeGetName(op), "output_size dim should be equal to 1.");
    return GRAPH_FAILED;
  }
  if (output_size_shape[0] != 3) {
    OP_LOGE(TbeGetName(op), "output_size shape[0] should be equal to 3.");
    return GRAPH_FAILED;
  }

  DataType input_dtype = input.GetDataType();
  Shape output_shape(UNKNOWN_SHAPE);
  output.SetDataType(input_dtype);
  output.SetShape(output_shape);
  argmax.SetDataType(DT_INT32);
  argmax.SetShape(output_shape);
  (void)op.UpdateOutputDesc("y", output);
  (void)op.UpdateOutputDesc("argmax", argmax);
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3d, AdaptiveMaxPool3dVerify);
// --------- AdaptiveMaxPool3d END---------------

// --------- AdaptiveMaxPool2dGrad---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape) {
  TensorDesc input_grad = op.GetOutputDescByName("x_grad");
  TensorDesc input = op.GetInputDescByName("x");
  DataType input_dtype = input.GetDataType();
  Shape input_shape = input.GetShape();
  input_grad.SetShape(input_shape);
  input_grad.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("x_grad", input_grad) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool2dGrad, AdaptiveMaxPool2dGradVerify);
// --------- AdaptiveMaxPool2dGrad END---------------

// --------- AdaptiveMaxPool3dGrad---------------
CUST_IMPLEMT_INFERFUNC(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape) {
  TensorDesc output_grad = op.GetOutputDescByName("output_grad");
  TensorDesc input = op.GetInputDescByName("x");
  DataType input_dtype = input.GetDataType();
  Shape input_shape = input.GetShape();
  output_grad.SetShape(input_shape);
  output_grad.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("output_grad", output_grad) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveMaxPool3dGrad, AdaptiveMaxPool3dGradVerify);
// --------- AdaptiveMaxPool3dGrad END---------------

// -------------------DataFormatVecPermute---------------------
IMPLEMT_INFERFUNC(DataFormatVecPermute, DataFormatVecPermuteInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
@ -284,4 +587,56 @@ CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
// -------------------MaxPool3DGradWithArgMax---------------------

// -------------------NthElement---------------------
IMPLEMT_INFERFUNC(NthElement, NthElementInfer) {
  std::vector<std::string> input_infer_depends = {"n"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);
  Shape x_shape;
  auto x_tensor = op.get_input_desc_x();
  if (WithRankAtLeast(x_tensor, 1, x_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg =
      ConcatString("failed to call WithRankAtLeast function, ", "input[x] rank must be at least 1D, but got rank[",
                   op.get_input_desc_x().GetShape().GetDimNum(), "]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  Tensor n_tensor;
  int64_t n_dim = 0;
  if (op.GetInputConstData("n", n_tensor) != GRAPH_SUCCESS) {
    n_dim = ge::UNKNOWN_DIM;
  } else {
    if (MakeDimForScalarInput(n_tensor, n_dim, op) != GRAPH_SUCCESS) {
      AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), std::string("failed to call MakeDimForScalarInput function, "
                                                                    "get input n shape failed"));
      return GRAPH_FAILED;
    }
  }

  int64_t existing = x_shape.GetDimNum();
  int64_t last_input_dim = x_shape.GetDim(existing - 1);
  if ((last_input_dim != ge::UNKNOWN_DIM) && (n_dim != ge::UNKNOWN_DIM) && (last_input_dim <= n_dim)) {
    std::string err_msg =
      ConcatString("input[x] last dim value[", last_input_dim, "] must be greater than [", n_dim, "]");
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  Shape output_shape;
  if (SubShape(x_shape, 0, -1, 1, output_shape, op) != GRAPH_SUCCESS) {
    std::string err_msg = ConcatString("failed to call SubShape function, input[x] shape[",
                                       DebugString(x_shape.GetDims()), "], start[0], end[-1], stride[1]");
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  TensorDesc y_tensor = op.GetOutputDescByName("y");
  y_tensor.SetDataType(x_tensor.GetDataType());
  y_tensor.SetShape(output_shape);
  return op.UpdateOutputDesc("y", y_tensor);
}

INFER_FUNC_REG(NthElement, NthElementInfer);
//-------------------NthElement END---------------------
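// Shape sketch (illustrative): x of shape [b, 5] with a const n infers y as
// [b] (the last dim is dropped by SubShape); n must be strictly less than
// that last dim.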
}  // namespace ge
@ -0,0 +1,316 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstring>
#include <vector>
#include <algorithm>
#include <numeric>

#include "inc/ops/pad_ops.h"
#include "graph/utils/node_utils.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/error_util.h"
#include "utils/op_log.h"
#include "utils/op_const.h"

namespace ge {
// ----------------Pad Op Begin-------------------
static graphStatus PadInferShapeAndType(ge::Operator &op, std::vector<int64_t> &paddings) {
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  static const int64_t input_x_idx = 0;
  auto input_desc = op_info->MutableInputDesc(input_x_idx);
  const ge::GeShape &input_shape = input_desc->MutableShape();
  auto input_dtype = input_desc->GetDataType();
  static const int64_t output_y_idx = 0;
  auto output_desc = op_info->MutableOutputDesc(output_y_idx);
  ge::GeShape &output_shape = output_desc->MutableShape();
  output_desc->SetDataType(input_dtype);

  // input shape is -2 (unknown rank), output is -2
  if (input_shape.IsUnknownDimNum()) {
    output_desc->SetShape(input_shape);
    return GRAPH_SUCCESS;
  }

  size_t dim_num = input_shape.GetDimNum();
  if (!input_shape.IsUnknownShape()) {
    // not dynamic shape, will output shape and dtype
    if (dim_num == 0) {
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("input shape cannot be empty"));
      return GRAPH_FAILED;
    }
    if (dim_num * 2 != paddings.size()) {
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
                                          OtherErrMsg("the num of paddings must be double the input dim size"));
      return GRAPH_FAILED;
    }

    // calc the output shape
    output_shape.SetDimNum(dim_num);
    for (size_t dim = 0; dim < dim_num; dim++) {
      output_shape.SetDim(dim, input_shape.GetDim(dim) + paddings[dim * 2] + paddings[dim * 2 + 1]);
    }

    return GRAPH_SUCCESS;
  }

  // input shape is -1, will get the shape and range
  // calc the output shape
  output_shape.SetDimNum(dim_num);
  for (size_t dim = 0; dim < dim_num; dim++) {
    if (input_shape.GetDim(dim) == -1) {
      output_shape.SetDim(dim, input_shape.GetDim(dim));
    } else {
      output_shape.SetDim(dim, input_shape.GetDim(dim) + paddings[dim * 2] + paddings[dim * 2 + 1]);
    }
  }

  // calc the output range
  std::vector<std::pair<int64_t, int64_t>> input_range;
  input_desc->GetShapeRange(input_range);
  MakeUpShapeRange(input_shape, input_range);
  std::vector<std::pair<int64_t, int64_t>> output_range;
  for (size_t dim = 0; dim < dim_num; dim++) {
    auto range_min = input_range[dim].first + paddings[dim * 2] + paddings[dim * 2 + 1];
    auto range_max =
      input_range[dim].second == -1 ? -1 : input_range[dim].second + paddings[dim * 2] + paddings[dim * 2 + 1];
    output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
  }
  output_desc->SetShapeRange(output_range);

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(PadInferShape) {
  OP_LOGD(TbeGetName(op), "InferShape Begin.");
  const vector<string> depend_names = {"paddings"};
  PREPARE_DYNAMIC_SHAPE(depend_names);

  // first get the paddings const
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  // get const paddings data
  std::vector<int64_t> paddings;
  static const int64_t input_paddings_idx = 1;
  if (!ops::GetConstIntData(op, input_paddings_idx, paddings)) {
    OP_LOGW(TbeGetName(op), "the node paddings is not a const node, will set the output dynamic");
    auto input_desc = op_info->MutableInputDesc("x");
    const ge::GeShape &input_shape = input_desc->MutableShape();
    DataType input_dtype = input_desc->GetDataType();
    auto output_desc = op_info->MutableOutputDesc("y");

    // shape_x is UNKNOWN_RANK
    if (IsUnknownRankShape(input_shape)) {
      OP_LOGW(TbeGetName(op), "shape_x is UNKNOWN_RANK. Set output UNKNOWN_RANK");
      output_desc->SetShape(input_shape);
      output_desc->SetDataType(input_dtype);
      return GRAPH_SUCCESS;
    }

    size_t dim_num = input_shape.GetDimNum();
    // shape_x is UNKNOWN_DIM
    if (dim_num == 0) {
      dim_num = 1;
    }
    vector<int64_t> out_shape(dim_num, -1);
    std::vector<std::pair<int64_t, int64_t>> output_range;
    input_desc->GetShapeRange(output_range);
    MakeUpShapeRange(out_shape, output_range);
    for (size_t i = 0; i < dim_num; i++) {
      output_range[i].second = -1;
    }
    output_desc->SetShape(GeShape(out_shape));
    output_desc->SetDataType(input_dtype);
    output_desc->SetShapeRange(output_range);
    return GRAPH_SUCCESS;
  }

  return PadInferShapeAndType(op, paddings);
}

COMMON_INFER_FUNC_REG(Pad, PadInferShape);
// ----------------Pad Op End-------------------
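// Shape sketch (illustrative): x of shape [2, 3] with const paddings
// {1, 1, 2, 2} (i.e. [[1, 1], [2, 2]]) infers y as
// [2 + 1 + 1, 3 + 2 + 2] = [4, 7].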
// ----------------PadV3 Op Begin-------------------
IMPLEMT_COMMON_INFERFUNC(PadV3InferShape) {
  const vector<string> depend_names = {"paddings"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  Tensor paddings_tensor;
  if (ge::GRAPH_SUCCESS != op.GetInputConstData("paddings", paddings_tensor)) {
    OP_LOGW(TbeGetName(op).c_str(), "Get Const Value [paddings] failed, setting shape to UNKNOWN_DIM");
    Shape shape_x = op.GetInputDescByName("x").GetShape();
    vector<int64_t> shape;
    for (size_t dim = 0; dim < shape_x.GetDimNum(); dim++) {
      shape.push_back(UNKNOWN_DIM);
    }
    DataType input_dtype = op.GetInputDescByName("x").GetDataType();
    TensorDesc tensordesc_output = op.GetOutputDescByName("y");
    Shape out_shape(shape);
    tensordesc_output.SetShape(out_shape);
    tensordesc_output.SetDataType(input_dtype);
    (void)op.UpdateOutputDesc("y", tensordesc_output);
    return GRAPH_SUCCESS;
  }

  DataType dtype = op.GetInputDescByName("paddings").GetDataType();

  std::vector<int64_t> paddings;
  if (!GetConstValue(op, paddings_tensor, dtype, paddings)) {
    OP_LOGE(TbeGetName(op).c_str(), "Get Const Value [paddings] failed");
    return GRAPH_FAILED;
  }

  bool paddings_contiguous = true;
  if (op.GetAttr("paddings_contiguous", paddings_contiguous) == GRAPH_FAILED) {
    OP_LOGI(TbeGetName(op).c_str(), "Get attr [paddings_contiguous] failed");
  }

  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  auto input_desc = op_info->MutableInputDesc("x");
  auto input_shape = input_desc->MutableShape().GetDims();
  auto input_shape_max = input_shape.size();
  auto paddings_shape = paddings.size();

  // expand paddings with zeros up to 2 * rank entries
  auto expand_num = input_shape_max * 2 - paddings_shape;
  for (size_t dim = 0; dim < expand_num; dim++) {
    paddings.push_back(0);
  }

  // short paddings are given from the last dim backwards; reverse the pair order
  if (expand_num > 0) {
    std::vector<int64_t> pad_vec;
    for (int i = input_shape_max; i > 0; i--) {
      pad_vec.push_back(paddings[i * 2 - 2]);
      pad_vec.push_back(paddings[i * 2 - 1]);
    }
    paddings = pad_vec;
  }

  if (!paddings_contiguous) {
    // re-pair {b0, ..., bk, e0, ..., ek} into {b0, e0, b1, e1, ...}
    std::vector<int64_t> pads;
    int64_t rank = paddings.size() / 2;
    for (int i = 0; i < rank; i++) {
      pads.push_back(paddings[i]);
      pads.push_back(paddings[i + rank]);
    }
    paddings = pads;
    OP_LOGI(TbeGetName(op).c_str(), "Get attr paddings_contiguous = false");
  } else {
    OP_LOGI(TbeGetName(op).c_str(), "Get attr paddings_contiguous = true[default]");
  }

  return PadInferShapeAndType(op, paddings);
}

COMMON_INFER_FUNC_REG(PadV3, PadV3InferShape);
// ----------------PadV3 Op End-------------------
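// Padding layout sketch (illustrative): with paddings_contiguous=false the
// flat list {b0, b1, ..., e0, e1, ...} is re-paired to {b0, e0, b1, e1, ...}
// before the Pad shape rule above is reused.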

// ----------------PadV3Grad Op Begin-------------------
static graphStatus PadV3GradInferShapeAndType(ge::Operator &op, std::vector<int64_t> &paddings) {
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  auto input_desc = op_info->MutableInputDesc("x");
  auto input_shape = input_desc->MutableShape().GetDims();
  auto input_dtype = input_desc->GetDataType();
  auto output_desc = op_info->MutableOutputDesc("y");
  output_desc->SetDataType(input_dtype);

  if (!IsUnknown(input_shape)) {
    // calc the output shape
    vector<int64_t> output_shape;
    for (size_t dim = 0; dim < input_shape.size(); dim++) {
      output_shape.push_back(input_shape[dim] - paddings[dim * 2] - paddings[dim * 2 + 1]);
    }
    output_desc->SetShape(GeShape(output_shape));

    return GRAPH_SUCCESS;
  }

  // input shape is -2 (unknown rank), output is -2
  if (IsUnknownRankShape(input_shape)) {
    output_desc->SetShape(GeShape(input_shape));

    return GRAPH_SUCCESS;
  }

  // input shape is -1, will get the shape and range
  // calc the output shape
  vector<int64_t> output_shape;
  for (size_t dim = 0; dim < input_shape.size(); dim++) {
    if (input_shape[dim] == -1) {
      output_shape.push_back(input_shape[dim]);
    } else {
      output_shape.push_back(input_shape[dim] - paddings[dim * 2] - paddings[dim * 2 + 1]);
    }
  }
  output_desc->SetShape(GeShape(output_shape));

  // calc the output range
  std::vector<std::pair<int64_t, int64_t>> input_range;
  input_desc->GetShapeRange(input_range);
  MakeUpShapeRange(input_shape, input_range);
  std::vector<std::pair<int64_t, int64_t>> output_range;
  for (size_t dim = 0; dim < input_shape.size(); dim++) {
    auto range_min = input_range[dim].first - paddings[dim * 2] - paddings[dim * 2 + 1];
    if (range_min < 1) {
      range_min = 1;
    }
    auto range_max =
      input_range[dim].second == -1 ? -1 : input_range[dim].second - paddings[dim * 2] - paddings[dim * 2 + 1];
    output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
  }
  output_desc->SetShapeRange(output_range);

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(PadV3GradInferShape) {
  const vector<string> depend_names = {"paddings"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  Tensor paddings_tensor;
  if (ge::GRAPH_SUCCESS != op.GetInputConstData("paddings", paddings_tensor)) {
    OP_LOGW(TbeGetName(op).c_str(), "Get Const Value [paddings] failed, setting shape to UNKNOWN_DIM");
    Shape shape_x = op.GetInputDescByName("x").GetShape();
    vector<int64_t> shape;
    for (size_t dim = 0; dim < shape_x.GetDimNum(); dim++) {
      shape.push_back(UNKNOWN_DIM);
    }
    DataType input_dtype = op.GetInputDescByName("x").GetDataType();
    TensorDesc tensordesc_output = op.GetOutputDescByName("y");
    Shape out_shape(shape);
    tensordesc_output.SetShape(out_shape);
    tensordesc_output.SetDataType(input_dtype);
    (void)op.UpdateOutputDesc("y", tensordesc_output);
    return GRAPH_SUCCESS;
  }

  DataType dtype = op.GetInputDescByName("paddings").GetDataType();

  std::vector<int64_t> paddings;
  if (!GetConstValue(op, paddings_tensor, dtype, paddings)) {
    OP_LOGE(TbeGetName(op).c_str(), "Get Const Value [paddings] failed");
    return GRAPH_FAILED;
  }

  bool paddings_contiguous = true;
  if (op.GetAttr("paddings_contiguous", paddings_contiguous) == GRAPH_FAILED) {
    OP_LOGI(TbeGetName(op).c_str(), "Get attr [paddings_contiguous] failed");
  }
  return PadV3GradInferShapeAndType(op, paddings);
}

COMMON_INFER_FUNC_REG(PadV3Grad, PadV3GradInferShape);
// ----------------PadV3Grad Op End-------------------
}  // namespace ge
@ -8,7 +8,9 @@
#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"
#include "utils/vector_proto_profiling.h"

namespace ge {
// ----------------CumulativeLogsumexp-------------------
@ -162,4 +164,519 @@ IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
// Registered inferfunction
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------

// ----------------SegmentSum-------------------
static bool SegmentSumShapeVerify(const Operator &op, const std::string &input_name,
                                  const std::string &segment_ids_name) {
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  auto input_shape_dims = op_info->MutableInputDesc("x")->MutableShape().GetDims();
  auto segment_ids_shape_dims = op_info->MutableInputDesc("segment_ids")->MutableShape().GetDims();

  return true;
}

IMPLEMT_VERIFIER(SegmentSum, SegmentSumInferShapeVerifier) {
  if (!SegmentSumShapeVerify(op, "x", "segment_ids")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(SegmentSumInferShape) {
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  auto input_x_desc = op_info->MutableInputDesc("x");
  auto output_desc = op_info->MutableOutputDesc("y");
  auto shape_x = input_x_desc->MutableShape().GetDims();
  auto output_shape_dims = input_x_desc->MutableShape().GetDims();
  if (output_shape_dims.empty()) {
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("the input[x]'s shape should not be empty."));
    return GRAPH_FAILED;
  }
  const vector<string> depend_name = {"segment_ids"};
  PREPARE_DYNAMIC_SHAPE(depend_name);
  const std::string segment_ids_name = "segment_ids";
  Tensor segment_ids;
  int64_t first_axis_dims;
  int64_t out_range_first_dims;
  if (GRAPH_SUCCESS != op.GetInputConstData(segment_ids_name.c_str(), segment_ids)) {
    OP_LOGI("SegmentSum", "GetInputConstData %s failed.", segment_ids_name.c_str());
    first_axis_dims = -1;
    out_range_first_dims = 0;
  } else {
    auto data_type = op.GetInputDescByName(segment_ids_name.c_str()).GetDataType();
    std::vector<int64_t> const_data;
    if (!GetConstIntData(segment_ids, data_type, const_data)) {
      std::string err_msg =
        ConcatString("failed to call GetConstIntData function ",
                     "due to invalid data type of input[segment_ids]. data_type is ", DTypeStr(data_type));
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
      return GRAPH_FAILED;
    }
    first_axis_dims = (*std::max_element(const_data.begin(), const_data.end())) + 1;
    out_range_first_dims = first_axis_dims;
  }

  if (IsUnknownRankShape(shape_x)) {
    output_desc->SetShape(GeShape(shape_x));
  } else {
    output_shape_dims[0] = first_axis_dims;
    GeShape output_shape(output_shape_dims);
    output_desc->SetShape(output_shape);
    if (output_shape.IsUnknownShape()) {
      std::vector<std::pair<int64_t, int64_t>> shape_range_x;
      std::vector<std::pair<int64_t, int64_t>> output_shape_range;
      output_shape_range.push_back(std::pair<int64_t, int64_t>(out_range_first_dims, first_axis_dims));
      input_x_desc->GetShapeRange(shape_range_x);
      MakeUpShapeRange(output_shape_dims, shape_range_x);
      for (size_t i = 1; i < output_shape_dims.size(); i++) {
        output_shape_range.push_back(shape_range_x[i]);
      }
      output_desc->SetShapeRange(output_shape_range);
    }
  }
  DataType input_dtype = input_x_desc->GetDataType();
  output_desc->SetDataType(input_dtype);
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(SegmentSum, SegmentSumInferShape);
VERIFY_FUNC_REG(SegmentSum, SegmentSumInferShapeVerifier);
// ----------------SegmentSum END-------------------
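// Shape sketch (illustrative): x of shape [6, 4] with const
// segment_ids=[0, 0, 1, 1, 2, 2] infers y as [max(ids) + 1, 4] = [3, 4];
// non-const segment_ids leaves dim 0 unknown (-1).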

// ----------------Select----------------------
IMPLEMT_VERIFIER(Select, SelectVerify) {
  if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(
      TbeGetName(op),
      string("call function CheckTwoInputDtypeSame failed, data type of input[x1] is not same as input[x2]"));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(SelectInferShape) {
  if (!TwoInOneOutDynamicInferNoBroadcast(op, "x1", "x2", {"y"})) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(
      TbeGetName(op), string("call function TwoInOneOutDynamicInferNoBroadcast failed, update output[y] desc failed"));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Select, SelectInferShape);
VERIFY_FUNC_REG(Select, SelectVerify);
// ---------------Select END-----------------------

// ----------------ReverseV2 Op Begin-----------------
IMPLEMT_COMMON_INFERFUNC(ReverseV2InferShape) {
  const vector<string> depend_names = {"axis"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  if (OneInOneOutDynamicInfer(op, "x", {"y"})) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(ReverseV2, ReverseV2InferShape);
// ----------------ReverseV2 Op End-------------------

// ----------------ScatterNd-------------------
IMPLEMT_COMMON_INFERFUNC(ScatterNdInferShape) {
  vector<string> input_infer_depends = {"shape"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);

  auto output_desc = op_desc->MutableOutputDesc("y");
  auto shape_desc = op_desc->MutableInputDesc("shape");
  std::vector<int64_t> shape_shape = shape_desc->MutableShape().GetDims();
  std::vector<std::pair<int64_t, int64_t>> out_range;
  Tensor shape;
  std::vector<int64_t> const_data;
  if (GRAPH_SUCCESS != op.GetInputConstData("shape", shape)) {
    const_data = {-2};
  } else {
    auto data_type = shape_desc->GetDataType();
    if (!GetConstIntData(shape, data_type, const_data)) {
      USER_GE_LOGE("Invalid data type of shape, data_type is %d.", (int)data_type);
      return GRAPH_FAILED;
    }
  }

  vector<int64_t> shape_dims;
  if (shape_shape.size() == 1 && shape_shape[0] > 0 && IsUnknownRankShape(const_data)) {
    for (int64_t i = 0; i < shape_shape[0]; i++) {
      shape_dims.push_back(-1);
    }
  } else {
    for (size_t i = 0; i < (uint32_t)const_data.size(); ++i) {
      shape_dims.push_back(const_data[i]);
    }
  }

  if (IsUnknownRankShape(shape_dims)) {
    out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
  } else if (IsUnknownVec(shape_dims)) {
    for (size_t i = 0; i < shape_dims.size(); i++) {
      if (shape_dims[i] == -1) {
        out_range.push_back(std::pair<int64_t, int64_t>(1, -1));
      } else {
        out_range.push_back(std::pair<int64_t, int64_t>(shape_dims[i], shape_dims[i]));
      }
    }
  }

  GeShape output_shape(shape_dims);
  output_desc->SetShape(output_shape);
  output_desc->SetShapeRange(out_range);
  output_desc->SetDataType(op_desc->MutableInputDesc("x")->GetDataType());
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(ScatterNd, ScatterNdInferShape);
// ----------------ScatterNd End-------------------
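// Shape sketch (illustrative): a const shape input of [5, 4] infers y as
// [5, 4] regardless of the indices/updates shapes; a non-const shape of known
// length k infers y as k unknown dims [-1, ..., -1].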

// ----------------OneHot---------------------------
IMPLEMT_COMMON_INFERFUNC(OneHotInferShape) {
  const vector<string> depend_names = {"depth"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  // get attr axis
  int32_t axis = -1;
  if (ge::GRAPH_SUCCESS != op.GetAttr("axis", axis)) {
    std::string err_msg = GetInputInvalidErrMsg("Get const axis failed from op of 'OneHot'!\n");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  if (axis < -1) {
    string correct_size = ConcatString("attr axis(", axis, ") must be >= -1");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), correct_size);
    return GRAPH_FAILED;
  }

  // get all Desc info
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  static const int64_t input_x_idx = 0;
  auto input_desc = op_info->MutableInputDesc(input_x_idx);
  const ge::GeShape &input_shape = input_desc->MutableShape();

  static const int64_t input_on_value_idx = 2;
  auto value_desc = op_info->MutableInputDesc(input_on_value_idx);
  DataType value_dtype = value_desc->GetDataType();

  // output desc and set dtype
  static const int64_t output_y_idx = 0;
  auto output_desc = op_info->MutableOutputDesc(output_y_idx);
  output_desc->SetDataType(value_dtype);

  if (input_shape.IsUnknownDimNum()) {
    // input is UnknownRank, set output UnknownRank
    OP_LOGW("OneHot", "input shape is UnknownRank, set output UnknownRank");
    output_desc->SetShape(input_shape);
    return GRAPH_SUCCESS;
  }
  // update axis to positive number
  int32_t dimnum = input_shape.GetDimNum();
  if (axis > dimnum) {
    string correct_size = ConcatString("attr axis(", axis, ") must be <= ", input_shape.GetDimNum());
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), correct_size);
    return GRAPH_FAILED;
  }

  // get depth const value, depth index is 1
  int64_t depth_value = -1;
  static const int64_t input_depth_idx = 1;
  if (!ops::GetConstInt(op, input_depth_idx, depth_value)) {
    OP_LOGW("OneHot", "Get depth const tensor failed, set depth -1");
  }

  // update output shape
  ge::GeShape &output_shape = output_desc->MutableShape();
  output_shape.SetDimNum(dimnum + 1);
  if (-1 == axis) {
    for (int32_t i = 0; i < dimnum; i++) {
      output_shape.SetDim(i, input_shape.GetDim(i));
    }
    output_shape.SetDim(dimnum, depth_value);
  } else {
    while (dimnum > axis) {
      output_shape.SetDim(dimnum, input_shape.GetDim(dimnum - 1));
      dimnum--;
    }
    output_shape.SetDim(axis, depth_value);
    for (int32_t i = 0; i < axis; i++) {
      output_shape.SetDim(i, input_shape.GetDim(i));
    }
  }

  // if output shape is dynamic update output range
  if (output_shape.IsUnknownShape()) {
    output_desc->SetOriginShape(output_shape);
    std::vector<std::pair<int64_t, int64_t>> input_range;
    input_desc->GetShapeRange(input_range);
    MakeUpShapeRange(input_shape, input_range);
    std::pair<int64_t, int64_t> depth_range =
      depth_value == -1 ? std::pair<int64_t, int64_t>(1, -1) : std::pair<int64_t, int64_t>(depth_value, depth_value);
    if (-1 == axis) {
      input_range.insert(input_range.end(), depth_range);
    } else {
      input_range.insert(input_range.begin() + axis, depth_range);
    }
    output_desc->SetShapeRange(input_range);
  }

  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(OneHot, OneHotInferShape);
// ----------------OneHot END----------------------
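// Shape sketch (illustrative): indices of shape [4] with const depth=3 and
// the default axis=-1 infer y as [4, 3]; with axis=0 the depth dim leads,
// giving [3, 4].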

// ----------------UnsortedSegmentSum-------------------
static void GetUnsortedSegmentSumConstValue(const Tensor &const_tensor, const DataType &dtype, int64_t &const_data) {
  if (dtype == ge::DT_INT32) {
    int32_t *const_data_ptr = (int32_t *)const_tensor.GetData();
    const_data = *const_data_ptr;
  } else {
    int64_t *const_data_ptr = (int64_t *)const_tensor.GetData();
    const_data = *const_data_ptr;
  }
}

static void GetRealRange(ge::GeShape shape, std::vector<std::pair<int64_t, int64_t>> &range) {
  if (shape.IsUnknownDimNum()) {
    return;
  }
  if (range.empty()) {
    for (size_t i = 0; i < shape.GetDimNum(); i++) {
      int64_t dim = shape.GetDim(i);
      if (dim == -1) {
        range.push_back(std::pair<int64_t, int64_t>(0, -1));
      } else {
        range.push_back(std::pair<int64_t, int64_t>(dim, dim));
      }
    }
  }
}

IMPLEMT_COMMON_INFERFUNC(UnsortedSegmentSumInferShape) {
  PROFILING_PROTO_INIT(TbeGetName(op).c_str());
  vector<string> input_infer_depends = {"num_segments"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);

  Tensor input_num_segments_tensor;
  int64_t input_num_segments;
  DataType input_num_segments_dtype = op_desc->GetInputDescPtr(2)->GetDataType();

  std::vector<std::pair<int64_t, int64_t>> shape_range_x;
  op_desc->GetInputDescPtr(0)->GetShapeRange(shape_range_x);
  std::vector<std::pair<int64_t, int64_t>> shape_range_seg_id;
  op_desc->GetInputDescPtr(1)->GetShapeRange(shape_range_seg_id);

  std::vector<std::pair<int64_t, int64_t>> out_range;

  if (GRAPH_SUCCESS != op.GetInputConstData("num_segments", input_num_segments_tensor)) {
    input_num_segments = -1;
    out_range.push_back(std::pair<int64_t, int64_t>(0, -1));
  } else {
    GetUnsortedSegmentSumConstValue(input_num_segments_tensor, input_num_segments_dtype, input_num_segments);
    out_range.push_back(std::pair<int64_t, int64_t>(input_num_segments, input_num_segments));
  }

  ge::GeShape shape = op_desc->GetInputDescPtr(0)->GetShape();
  ge::GeShape shape_id = op_desc->GetInputDescPtr(1)->GetShape();

  auto output_desc = op_desc->MutableOutputDesc(0);
  ge::GeShape output_shape = output_desc->MutableShape();
  GetRealRange(shape, shape_range_x);
  GetRealRange(shape_id, shape_range_seg_id);

  int64_t dim_idsize_input = shape_id.GetDimNum();
  int64_t dim_size_input = shape.GetDimNum();
  DataType input_dtype = op_desc->GetInputDescPtr(0)->GetDataType();
  PROFILING_PROTO_AFTER_GET_SHAPE_REG();
  if (shape.IsUnknownDimNum() || shape_id.IsUnknownDimNum()) {
    if (shape.IsUnknownDimNum()) {
      output_desc->SetShape(shape);
      output_desc->SetDataType(input_dtype);
    } else {
      output_desc->SetShape(shape_id);
      output_desc->SetDataType(input_dtype);
    }
    return GRAPH_SUCCESS;
  } else if (dim_idsize_input > 1) {
    size_t rank = dim_size_input - dim_idsize_input + 1;
    size_t idx = 1;
    output_shape.SetDimNum(rank);
    output_shape.SetDim(0, input_num_segments);

    for (int64_t i = dim_idsize_input; i < dim_size_input; i++) {
      int64_t x_dim = shape.GetDim(i);
      output_shape.SetDim(idx, x_dim);
      if ((size_t)i < shape_range_x.size()) {
        out_range.push_back(shape_range_x[i]);
      }
      idx++;
    }
  } else {
    size_t rank = shape.GetDimNum();
    output_shape.SetDimNum(rank);
    output_shape.SetDim(0, input_num_segments);

    for (size_t i = 1; i < rank; i++) {
      int64_t x_dim = shape.GetDim(i);
      output_shape.SetDim(i, x_dim);
      if ((size_t)i < shape_range_x.size()) {
        out_range.push_back(shape_range_x[i]);
      }
    }
  }

  PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
  output_desc->SetShape(output_shape);
  output_desc->SetDataType(input_dtype);
  output_desc->SetShapeRange(out_range);
  PROFILING_PROTO_END();
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(UnsortedSegmentSum, UnsortedSegmentSumInferShape);
// ----------------UnsortedSegmentSum END----------------
|
||||
|
||||
// ----------------Slice Op Begin ----------------------
|
||||
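// Copies a const tensor's contents into an int64 vector, widening int32
// elements so the Slice shape inference below handles a single element type.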
static void GetSliceConstValue(const Tensor &const_tensor, const DataType &dtype, std::vector<int64_t> &const_data) {
  size_t size = 0;
  if (dtype == ge::DT_INT32) {
    int32_t *const_data_ptr = (int32_t *)const_tensor.GetData();
    size = const_tensor.GetSize() / sizeof(int32_t);
    for (size_t i = 0; i < size; ++i) {
      const_data.push_back((int32_t)((*(const_data_ptr + i))));
    }
  } else {
    int64_t *const_data_ptr = (int64_t *)const_tensor.GetData();
    size = const_tensor.GetSize() / sizeof(int64_t);
    for (size_t i = 0; i < size; ++i) {
      const_data.push_back(((int64_t)(*(const_data_ptr + i))));
    }
  }
}

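// Four cases follow, depending on which of offsets/size are const at infer
// time: neither -> every output dim is -1; only size -> dims come from size;
// only offsets -> dims stay -1 but each upper range bound shrinks by the
// offset; both -> dims are size[i], or shape[i] - offsets[i] when size[i] is -1.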
IMPLEMT_COMMON_INFERFUNC(SliceInferShape) {
  const vector<string> depend_names = {"offsets", "size"};
  PREPARE_DYNAMIC_SHAPE(depend_names);

  Tensor input_begin_tensor;
  Tensor input_size_tensor;
  auto input_desc = op.GetInputDescByName("x");
  const Shape shape = input_desc.GetShape();
  DataType input_dtype = input_desc.GetDataType();
  std::vector<int64_t> input_begin;
  std::vector<int64_t> input_size;

  bool has_offsets = true;
  if (op.GetInputConstData("offsets", input_begin_tensor) != GRAPH_SUCCESS) {
    OP_LOGI(TbeGetName(op).c_str(), "Get offsets failed.");
    has_offsets = false;
  } else {
    DataType input_begin_dtype = op.GetInputDescByName("offsets").GetDataType();
    GetSliceConstValue(input_begin_tensor, input_begin_dtype, input_begin);
  }

  bool has_size = true;
  if (op.GetInputConstData("size", input_size_tensor) != GRAPH_SUCCESS) {
    OP_LOGI(TbeGetName(op).c_str(), "Get size failed.");
    has_size = false;
  } else {
    DataType input_size_dtype = op.GetInputDescByName("size").GetDataType();
    GetSliceConstValue(input_size_tensor, input_size_dtype, input_size);
  }

  bool is_unknown_rank = !has_size && !has_offsets && shape.GetDims() == UNKNOWN_RANK;
  if (is_unknown_rank) {
    TensorDesc output_desc = op.GetOutputDescByName("y");
    output_desc.SetDataType(input_dtype);
    Shape outputShape(UNKNOWN_RANK);
    output_desc.SetShape(outputShape);
    OP_LOGD(TbeGetName(op).c_str(), "output_shape:%s", to_string(output_desc.GetShape()).c_str());
    (void)op.UpdateOutputDesc("y", output_desc);
    return GRAPH_SUCCESS;
  }

  auto shape_dims = shape.GetDims();
  if (shape.GetDims() == UNKNOWN_RANK) {
    shape_dims.assign(std::max(input_begin.size(), input_size.size()), -1);
  }

  size_t dimNum = shape_dims.size();
  std::vector<int64_t> outputList;

  vector<pair<int64_t, int64_t>> ranges;
  input_desc.GetShapeRange(ranges);
  if (ranges.empty()) {
    MakeUpShapeRange(shape_dims, ranges);
  }

  if (ranges.size() < dimNum) {
    OP_LOGE(TbeGetName(op).c_str(), "ranges.size is:%ld, smaller than dimNum, dimNum is %ld.", ranges.size(), dimNum);
    return GRAPH_FAILED;
  }

  if (!has_size && !has_offsets) {
    for (size_t i = 0; i < dimNum; ++i) {
      outputList.push_back(-1);
      ranges[i].first = 0;
    }
  } else if (!has_offsets && has_size) {
    for (size_t i = 0; i < dimNum; ++i) {
      if (input_size[i] == -1) {
        outputList.push_back(-1);
        ranges[i].first = 0;
      } else {
        outputList.push_back(input_size[i]);
        ranges[i].first = input_size[i];
        ranges[i].second = input_size[i];
      }
    }
  } else if (has_offsets && !has_size) {
    for (size_t i = 0; i < dimNum; ++i) {
      outputList.push_back(-1);
      ranges[i].first = 0;
      if (ranges[i].second != -1) {
        if (shape_dims[i] != -1) {
          ranges[i].second = std::min(ranges[i].second, shape_dims[i]);
        }
        ranges[i].second -= input_begin[i];
      }
    }
  } else {
    for (size_t i = 0; i < dimNum; ++i) {
      if (input_size[i] == -1) {
        if (shape_dims[i] == -1) {
          outputList.push_back(-1);
        } else {
          outputList.push_back(shape_dims[i] - input_begin[i]);
        }
        ranges[i].first = 0;
      } else {
        outputList.push_back(input_size[i]);
        ranges[i].first = input_size[i];
        ranges[i].second = input_size[i];
      }
    }
  }

  TensorDesc tensordesc_output = op.GetOutputDescByName("y");
  tensordesc_output.SetDataType(input_dtype);
  if (IsUnKnownShape(outputList)) {
    tensordesc_output.SetShapeRange(ranges);
    OP_LOGD(TbeGetName(op).c_str(), "output_ranges:%s", to_string(ranges).c_str());
  }

  Shape outputShape(outputList);
  tensordesc_output.SetShape(outputShape);
  OP_LOGD(TbeGetName(op).c_str(), "output_ranges:%s", to_string(ranges).c_str());
  OP_LOGD(TbeGetName(op).c_str(), "offset:%s", to_string(input_begin).c_str());
  OP_LOGD(TbeGetName(op).c_str(), "size:%s", to_string(input_size).c_str());
  OP_LOGD(TbeGetName(op).c_str(), "output_shape:%s", to_string(tensordesc_output.GetShape()).c_str());
  (void)op.UpdateOutputDesc("y", tensordesc_output);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(Slice, SliceInferShape);
// ----------------Slice Op END ----------------------
}  // namespace ge

@ -1,53 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/trace_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// ----------------TraceGrad Begin------------------------
IMPLEMT_COMMON_INFERFUNC(TraceGradInferShape) {
  Shape shape = op.GetInputDescByName("y_grad").GetShape();
  DataType input_dtype = op.GetInputDescByName("y_grad").GetDataType();
  Tensor tensor_input;
  std::vector<int64_t> dim_vector;
  if (op.GetInputConstData("x_shape", tensor_input) == GRAPH_SUCCESS) {
    uint8_t *input_shape = tensor_input.GetData();
    for (int64_t i = 0; i < 2; i++) {
      dim_vector.push_back(shape.GetDim(*(input_shape + i)));
    }
  }
  Shape output_shape(dim_vector);
  TensorDesc td = op.GetOutputDescByName("x_grad");
  td.SetShape(output_shape);
  td.SetDataType(input_dtype);
  td.SetFormat(FORMAT_ND);
  (void)op.UpdateOutputDesc("x_grad", td);
  return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(TraceGrad, TraceGradVerify) {
  DataType x_shape_dtype = op.GetInputDescByName("x_shape").GetDataType();
  if ((x_shape_dtype != DT_INT32) && (x_shape_dtype != DT_INT64)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(TraceGrad, TraceGradInferShape);
CUST_VERIFY_FUNC_REG(TraceGrad, TraceGradVerify);
// ---------------TraceGrad END-------------------------------
}  // namespace ge

@ -7,7 +7,9 @@
#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"
#include "utils/vector_proto_profiling.h"

namespace ge {
// ------------------DepthToSpace------------------

@ -185,4 +187,149 @@ IMPLEMT_COMMON_INFERFUNC(DepthToSpaceInfer) {
COMMON_INFER_FUNC_REG(DepthToSpace, DepthToSpaceInfer);
VERIFY_FUNC_REG(DepthToSpace, DepthToSpaceVerify);
// -------------------DepthToSpace END-----------------

// -------------------Transpose-----------------
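// Permutes the dims of x by perm. Negative perm entries are normalized by
// adding the input rank; values still out of range are rejected.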
static graphStatus TransposeCommonInferShape(const std::vector<int64_t> &perm_list, Operator &op) {
  PROFILING_PROTO_INIT(TbeGetName(op).c_str());
  auto op_info = OpDescUtils::GetOpDescFromOperator(op);
  const int64_t input_x_idx = 0;
  auto input_desc = op_info->MutableInputDesc(input_x_idx);
  const int64_t output_y_idx = 0;
  auto output_desc = op_info->MutableOutputDesc(output_y_idx);

  auto input_dtype = input_desc->GetDataType();
  const GeShape &input_ge_shape = input_desc->MutableShape();

  int64_t input_shape_len = input_ge_shape.GetDimNum();

  PROFILING_PROTO_AFTER_GET_SHAPE_REG();

  if (IsUnknownRankShape(input_ge_shape)) {
    // unknown-rank shape: set every output dim to -1
    std::vector<int64_t> out_vec(perm_list.size(), -1);
    output_desc->SetShape(GeShape(out_vec));
    output_desc->SetDataType(input_dtype);
    return GRAPH_SUCCESS;
  }

  // infer the shape
  GeShape &output_ge_shape = output_desc->MutableShape();
  output_ge_shape.SetDimNum(input_shape_len);
  for (size_t i = 0; i < perm_list.size(); ++i) {
    // verify perm_list begin
    int64_t perm_value = perm_list[i] < 0 ? perm_list[i] + input_shape_len : perm_list[i];
    if (perm_value >= input_shape_len) {
      std::string err_msg = GetAttrValueErrMsg("perm", ConcatString(perm_value),
                                               ConcatString("less than input shape size[", input_shape_len, "]"));
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
      return GRAPH_FAILED;
    }
    // verify perm_list end

    // set the output shape
    output_ge_shape.SetDim(i, input_ge_shape.GetDim(perm_value));
  }
  PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
  // set output dtype the same as input x
  output_desc->SetDataType(input_dtype);

  // infer the range, when needed
  if (output_ge_shape.IsUnknownShape()) {
    output_desc->SetOriginShape(output_ge_shape);
    std::vector<std::pair<int64_t, int64_t>> input_range;
    std::vector<std::pair<int64_t, int64_t>> output_range;
    input_desc->GetShapeRange(input_range);
    MakeUpShapeRange(input_ge_shape, input_range);
    for (size_t i = 0; i < perm_list.size(); ++i) {
      output_range.push_back(input_range[perm_list[i]]);
    }
    output_desc->SetShapeRange(output_range);
    return GRAPH_SUCCESS;
  }
  PROFILING_PROTO_END();
  return GRAPH_SUCCESS;
}
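// When perm is a const node the exact output shape is computed above;
// otherwise only the output rank and a merged dim range can be inferred.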
IMPLEMT_COMMON_INFERFUNC(TransposeInferShape) {
  const vector<string> depend_names = {"perm"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);

  bool perm_done = true;
  std::vector<int64_t> perm_list;
  static const int64_t perm_input_index = 1;
  if (!(ops::GetConstIntData(op, perm_input_index, perm_list))) {
    perm_done = false;
    OP_LOGW(TbeGetName(op), "Get const perm value failed.");
  }

  // perm is a const node, so infer with TransposeCommonInferShape
  if (perm_done) {
    if (GRAPH_SUCCESS != TransposeCommonInferShape(perm_list, op)) {
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  }

  // perm is not a const node, infer for aicpu
  static const int64_t x_input_index = 0;
  static const int64_t y_output_index = 0;
  auto input_desc = op_desc->MutableInputDesc(x_input_index);
  auto input_shape = input_desc->MutableShape().GetDims();
  auto input_dtype = input_desc->GetDataType();
  auto output_desc = op_desc->MutableOutputDesc(y_output_index);

  // set output dtype the same as input x
  output_desc->SetDataType(input_dtype);

  if (IsUnknownRankShape(input_shape)) {
    auto perm_desc = op_desc->MutableInputDesc("perm");
    auto perm_shape = perm_desc->MutableShape().GetDims();
    if (IsUnknown(perm_shape)) {
      // set output to -2, unknown rank
      OP_LOGW(TbeGetName(op), "the output will be set to -2");
      output_desc->SetShape(GeShape(input_shape));
      output_desc->SetOriginShape(GeShape(input_shape));
      return GRAPH_SUCCESS;
    }

    // perm is not dynamic shape, so update the input shape
    if (perm_shape.empty()) {
      perm_shape.push_back(1);
    }
    input_shape.clear();
    for (auto i = 0; i < perm_shape[0]; ++i) {
      input_shape.push_back(-1);
    }
  }

  // begin to infer shape and range
  std::vector<std::pair<int64_t, int64_t>> input_range;
  std::vector<std::pair<int64_t, int64_t>> output_range;
  vector<int64_t> out_vec;
  input_desc->GetShapeRange(input_range);
  MakeUpShapeRange(input_shape, input_range);

  int64_t range_first = input_range[0].first;
  int64_t range_second = input_range[0].second;

  for (size_t i = 0; i < input_range.size(); ++i) {
    // merge all dim ranges into one shared range
    range_first = std::min(range_first, input_range[i].first);
    range_second =
      (range_second == -1 || input_range[i].second == -1) ? -1 : std::max(range_second, input_range[i].second);
  }

  for (size_t i = 0; i < input_range.size(); ++i) {
    out_vec.push_back(-1);
    output_range.push_back(std::pair<int64_t, int64_t>(range_first, range_second));
  }
  output_desc->SetShape(GeShape(out_vec));
  output_desc->SetOriginShape(GeShape(out_vec));
  output_desc->SetShapeRange(output_range);

  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(Transpose, TransposeInferShape);
// -------------------Transpose END-----------------
}  // namespace ge

@ -77,9 +77,11 @@
namespace ge {
// enum type and string type mapping
const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
  {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"},
  {ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"},
  {ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}, {ge::DT_INT4, "int4"},
  {ge::DT_DOUBLE, "double"}, {ge::DT_COMPLEX64, "complex64"}, {ge::DT_COMPLEX128, "complex128"},
  {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"},
  {ge::DT_INT16, "int16"}, {ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"},
  {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"}, {ge::DT_UINT32, "uint32"},
  {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}, {ge::DT_INT4, "int4"},
  {ge::DT_BF16, "bfloat16"}};

// define the input num of shape

@ -95,6 +95,7 @@ cust_op_lists = [
    "leftshift",
    "lessequal",
    "listdiff",
    "lgamma",
    "log",
    "log1p",
    "lognormalreverse",

@ -102,7 +103,6 @@ cust_op_lists = [
    "lowerbound",
    "lusolve",
    "luunpackgrad",
    "maskedselect",
    "maskedselectgrad",
    "matrixdeterminant",
    "matrixexp",

@ -111,6 +111,7 @@ cust_op_lists = [
    "matrixtriangularsolve",
    "maxpool3dgradwithargmax",
    "maxpool3dwithargmax",
    "meshgrid",
    "mul",
    "mulnonan",
    "multimarginloss",

@ -123,7 +124,73 @@ cust_op_lists = [
    "gatherdgradv2",
    "isnan",
    "maskedselectgrad",
    "slicegrad"
    "slicegrad",
    "orgqr",
    "tracegrad",
    "nonmaxsuppressionwithoverlaps",
    "nthelement",
    "onehot",
    "orgqr",
    "padv3",
    "padv3grad",
    "parameterizedtruncatednormal",
    "pow",
    "qr",
    "raggedrange",
    "randompoisson",
    "reciprocal",
    "reciprocalgrad",
    "reducemean",
    "reduceprod",
    "reducesum",
    "resizearea",
    "resizebicubic",
    "resizebicubicgrad",
    "resizenearestneighborv2",
    "resizenearestneighborv2grad",
    "reversev2",
    "rgbtohsv",
    "rightshift",
    "rsqrtgrad",
    "sampledistortedboundingboxv2",
    "scaleandtranslate",
    "scaleandtranslategrad",
    "scatternd",
    "scatterndupdate",
    "segmentsum",
    "select",
    "sign",
    "sin",
    "sinh",
    "slice",
    "smoothl1loss",
    "smoothl1lossgrad",
    "split",
    "sqrt",
    "sqrtgrad",
    "stack",
    "tanh",
    "tensorscatterupdate",
    "tile",
    "trace",
    "tracegrad",
    "transpose",
    "tril",
    "truncatednormal",
    "uniqueconsecutive",
    "unravelindex",
    "unsortedsegmentsum",
    "unstack",
    "upperbound",
    "xdivy",
    "xlogy",
    "zeroslike",
    "flatten",
    "maxpoolv1",
    "norepeatngram",
    "randint",
    "reversesequence",
    "standardlaplace"
]

@ -0,0 +1,123 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/kernel/aicpu/aicpu_ops/meshgrid_kernels.h"
#include <vector>
#include <string>
#include <map>
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#include "common/kernel_log.h"
#include "common/kernel_errcode.h"
#include "common/tensor.h"
#include "aicpu_sharder/aicpu_sharder.h"
#include "proto/aicpu_tensor.pb.h"

namespace aicpu {

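// Broadcasts each 1-D input across the other meshgrid dimensions: output i has
// shape bcast (the concatenated input lengths), filled by replicating input i
// along every axis but its own. With "xy" indexing the first two grids are
// swapped, mirroring NumPy-style meshgrid semantics.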
template <typename T>
uint32_t MeshgridTask(const std::vector<uintptr_t> &io_addrs_, const std::string &indexing, size_t ndim,
                      const std::vector<int> &bcast) {
  auto shards = [&](const int64_t begin, const int64_t end) {
    for (int i = begin; i < end; ++i) {  // 0 ~ ndim
      auto new_i = i;
      auto s = bcast;
      if (indexing == "xy" && i < 2) {
        new_i = 1 - i;
        auto tmp = s[0];
        s[0] = s[1];
        s[1] = tmp;
      }
      size_t row_ = 1;
      size_t col_ = 1;
      for (int j = 0; j <= new_i; j++) {
        row_ *= s[j];
      }
      for (int j = new_i + 1; j < static_cast<int>(s.size()); j++) {
        col_ *= s[j];
      }

      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>> input_map(reinterpret_cast<T *>(io_addrs_[i]), bcast[i],
                                                                       1);
      const auto &input = Eigen::Tensor<T, 2, Eigen::RowMajor>(input_map);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>> output(reinterpret_cast<T *>(io_addrs_[ndim + i]), row_,
                                                                    col_);

      Eigen::Tensor<T, 2, Eigen::RowMajor> origin(bcast[i], row_ * col_ / bcast[i]);
      for (int c = 0; c < bcast[i]; ++c) {
        for (int r = 0; r < static_cast<int>(row_ * col_ / bcast[i]); ++r) {
          origin(c, r) = input(c, 0);
        }
      }

      for (size_t j = 0; j < row_ * col_ / bcast[i] / col_; ++j) {
        Eigen::array<int64_t, 2> offsets_in = {0, static_cast<int64_t>(col_ * j)};
        Eigen::array<int64_t, 2> offsets_out = {static_cast<int64_t>(bcast[i] * j), 0};
        Eigen::array<int64_t, 2> extents = {static_cast<int64_t>(bcast[i]), static_cast<int64_t>(col_)};
        output.slice(offsets_out, extents) = origin.slice(offsets_in, extents);
      }
    }
  };

  const int64_t perUnitSize = 1;  // shard unit size
  ParallelFor(ndim, perUnitSize, shards);

  return kAicpuKernelStateSucess;
}

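// Dispatches to the MeshgridTask instantiation matching the input dtype; all
// inputs of one Meshgrid call share the dtype parsed in ParseKernelParam().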
uint32_t MeshgridKernel::DoCompute() {
  std::map<int, std::function<uint32_t(std::vector<uintptr_t> &, std::string &, size_t &, std::vector<int> &)>> calls;
  calls[aicpuops::DataType::MS_INT8] = MeshgridTask<int8_t>;
  calls[aicpuops::DataType::MS_INT16] = MeshgridTask<int16_t>;
  calls[aicpuops::DataType::MS_INT32] = MeshgridTask<int32_t>;
  calls[aicpuops::DataType::MS_INT64] = MeshgridTask<int64_t>;
  calls[aicpuops::DataType::MS_FLOAT16] = MeshgridTask<Eigen::half>;
  calls[aicpuops::DataType::MS_FLOAT32] = MeshgridTask<float>;
  calls[aicpuops::DataType::MS_FLOAT64] = MeshgridTask<double>;
  calls[aicpuops::DataType::MS_UINT8] = MeshgridTask<uint8_t>;
  calls[aicpuops::DataType::MS_UINT16] = MeshgridTask<uint16_t>;
  calls[aicpuops::DataType::MS_UINT32] = MeshgridTask<uint32_t>;
  calls[aicpuops::DataType::MS_UINT64] = MeshgridTask<uint64_t>;
  calls[aicpuops::DataType::MS_BOOL] = MeshgridTask<bool>;
  // reject dtypes with no registered task instead of invoking an empty std::function
  if (calls.find(input_type_) == calls.end()) {
    AICPU_LOGE("unsupported input data type for Meshgrid.");
    return kAicpuKernelStateInvalid;
  }
  return calls[input_type_](io_addrs_, indexing_, ndim_, bcast_);
}

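// Extracts the indexing attribute, the input dtype and the per-input lengths
// from the serialized node definition; every input must be a 1-D tensor.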
uint32_t MeshgridKernel::ParseKernelParam() {
  ::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> nodedef_map = node_def_.attrs();
  indexing_ = nodedef_map["indexing"].s();

  input_type_ = static_cast<aicpuops::DataType>(node_def_.inputs(0).tensor_type());

  ndim_ = node_def_.inputs_size();
  bcast_.resize(ndim_);
  for (int n = 0; n < node_def_.inputs_size(); ++n) {
    aicpuops::Tensor input_tensor = node_def_.inputs(n);
    aicpuops::TensorShape input_shape = input_tensor.tensor_shape();
    if (input_shape.dim().size() != 1) {
      AICPU_LOGE("input tensor should be 1-D.");
      return kAicpuKernelStateInvalid;
    }
    bcast_[n] = input_shape.dim(0).size();
  }

  return kAicpuKernelStateSucess;
}
}  // namespace aicpu

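// Exported entry point: the AICPU runtime resolves the op by this symbol name
// and passes in the serialized kernel parameters.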
extern "C" {
|
||||
__attribute__((visibility("default"))) uint32_t Meshgrid(void *param) {
|
||||
aicpu::MeshgridKernel meshgridKernel;
|
||||
return meshgridKernel.Compute(param);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef _AICPU_MESHGRID_KERNELS_H_
#define _AICPU_MESHGRID_KERNELS_H_

#include <vector>
#include <string>
#include "common/kernel_base.h"
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"

namespace aicpu {
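// Custom AICPU Meshgrid kernel: expands N 1-D inputs into N coordinate grids.
// bcast_ holds each input's length; indexing_ is the meshgrid indexing mode
// ("xy" swaps the first two output grids).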
class MeshgridKernel : public KernelBase {
 public:
  explicit MeshgridKernel() : KernelBase("Meshgrid") {}
  ~MeshgridKernel() = default;

  aicpuops::DataType input_type_;
  std::string indexing_;
  size_t ndim_ = 0;
  std::vector<int> bcast_;

 protected:
  uint32_t DoCompute() override;
  uint32_t ParseKernelParam() override;
};
}  // namespace aicpu
#endif  // _AICPU_MESHGRID_KERNELS_H_

@ -22,13 +22,13 @@
/* clang-format off */

namespace ge {
REG_OP(Meshgrid)
REG_CUST_OP(Meshgrid)
  .DYNAMIC_INPUT(x, TensorType({DT_INT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
                                DT_UINT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
  .DYNAMIC_OUTPUT(y, TensorType({DT_INT8, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
                                 DT_UINT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL}))
  .ATTR(indexing, String, "")
  .OP_END_FACTORY_REG(Meshgrid)
  .CUST_OP_END_FACTORY_REG(Meshgrid)

REG_CUST_OP(SliceGrad)
  .INPUT(dy, TensorType({DT_BOOL, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,

@ -104,5 +104,11 @@ REG_CUST_OP(LogSpace)
  .REQUIRED_ATTR(steps, Int)
  .REQUIRED_ATTR(base, Int)
  .CUST_OP_END_FACTORY_REG(LogSpace)

REG_CUST_OP(Expand)
  .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8, DT_BOOL}))
  .INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8, DT_BOOL}))
  .CUST_OP_END_FACTORY_REG(Expand)
}  // namespace ge
#endif  // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ARRAY_OPS_H_

@ -87,6 +87,25 @@ REG_CUST_OP(Gcd)
  .INPUT(x2, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  .CUST_OP_END_FACTORY_REG(Gcd)

REG_CUST_OP(Orgqr)
  .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
  .INPUT(tau, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
  .CUST_OP_END_FACTORY_REG(Orgqr)

REG_CUST_OP(TraceGrad)
  .INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                             DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(x_shape, TensorType({DT_INT64}))
  .OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                              DT_UINT32, DT_UINT64, DT_UINT8}))
  .CUST_OP_END_FACTORY_REG(TraceGrad)

REG_CUST_OP(Lgamma)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT32}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(Lgamma)
}  // namespace ge
#endif  // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_

@ -55,6 +55,20 @@ REG_CUST_OP(AdaptiveAvgPool2DGrad)
  .OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2DGrad)

REG_CUST_OP(AdaptiveMaxPool3d)
  .INPUT(x, TensorType::RealNumberType())
  .INPUT(output_size, TensorType({DT_INT32}))
  .OUTPUT(y, TensorType::RealNumberType())
  .OUTPUT(argmax, TensorType({DT_INT32}))
  .CUST_OP_END_FACTORY_REG(AdaptiveMaxPool3d)

REG_CUST_OP(AdaptiveMaxPool3dGrad)
  .INPUT(input_grad, TensorType::RealNumberType())
  .INPUT(x, TensorType::RealNumberType())
  .INPUT(argmax, TensorType({DT_INT32}))
  .OUTPUT(output_grad, TensorType::RealNumberType())
  .CUST_OP_END_FACTORY_REG(AdaptiveMaxPool3dGrad)

REG_CUST_OP(MultiMarginLossGrad)
  .INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))

@ -19,6 +19,7 @@

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

@ -29,5 +30,12 @@ REG_OP(KVCacheMgr)
  .INPUT(index, TensorType({DT_INT32}))
  .OUTPUT(past, TensorType({DT_FLOAT16}))
  .OP_END_FACTORY_REG(KVCacheMgr)

REG_CUST_OP(NoRepeatNGram)
  .INPUT(state_seq, TensorType({DT_INT32}))
  .INPUT(log_probs, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(out, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(ngram_size, Int)
  .CUST_OP_END_FACTORY_REG(NoRepeatNGram)
}  // namespace ge
#endif  // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_OTHER_OPS_H_

@ -39,5 +39,14 @@ REG_CUST_OP(Dropout2D)
  .OUTPUT(mask, TensorType({DT_BOOL}))
  .REQUIRED_ATTR(keep_prob, Float)
  .CUST_OP_END_FACTORY_REG(Dropout2D)

REG_CUST_OP(StandardLaplace)
  .INPUT(shape, TensorType({DT_INT32, DT_INT64}))
  .INPUT(seed, TensorType({DT_INT64}))
  .INPUT(seed2, TensorType({DT_INT64}))
  .OUTPUT(output, TensorType({DT_FLOAT}))
  .REQUIRED_ATTR(seed, Int)
  .REQUIRED_ATTR(seed2, Int)
  .CUST_OP_END_FACTORY_REG(StandardLaplace)
}  // namespace ge
#endif  // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_

@ -378,6 +378,7 @@ constexpr const char kNameDynamicShape[] = "DynamicShape";
constexpr const char kNameGather[] = "Gather";
constexpr const char kNameUnsqueeze[] = "Unsqueeze";
constexpr const char kNamePadV3[] = "PadV3";
constexpr const char kNamePadV3Grad[] = "PadV3Grad";
constexpr const char kNamePadV2[] = "PadV2";
constexpr const char kNameGlobalAvgPool[] = "GlobalAveragePool";
constexpr const char kNameAdaptiveMaxPool2D[] = "AdaptiveMaxPool2D";

@ -211,11 +211,11 @@ OUTPUT_MAP(Size) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Size, kNameSize, ADPT_DESC(Size))

// Meshgrid
INPUT_MAP(Meshgrid) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(Meshgrid) = {{1, DYN_INPUT_DESC(x)}};
ATTR_MAP(Meshgrid) = {{"indexing", ATTR_DESC(indexing, AnyTraits<std::string>())}};
DYN_OUTPUT_MAP(Meshgrid) = {{0, DYN_OUTPUT_DESC(y)}};
REG_ADPT_DESC(Meshgrid, prim::kPrimMeshgrid->name(), ADPT_DESC(Meshgrid))
CUST_INPUT_MAP(Meshgrid) = EMPTY_INPUT_MAP;
CUST_DYN_INPUT_MAP(Meshgrid) = {{1, DYN_INPUT_DESC(x)}};
CUST_ATTR_MAP(Meshgrid) = {{"indexing", ATTR_DESC(indexing, AnyTraits<std::string>())}};
CUST_DYN_OUTPUT_MAP(Meshgrid) = {{0, DYN_OUTPUT_DESC(y)}};
REG_ADPT_DESC(Meshgrid, prim::kPrimMeshgrid->name(), CUST_ADPT_DESC(Meshgrid))
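// The CUST_ variants bind the adapter to the REG_CUST_OP proto above,
// presumably so the framework dispatches Meshgrid to the custom AICPU kernel
// added in this commit.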

// SliceGrad
CUST_INPUT_MAP(SliceGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(begin)}, {4, INPUT_DESC(size)}};

@ -306,4 +306,30 @@ CUST_ATTR_MAP(LogSpace) = {{"steps", ATTR_DESC(steps, AnyTraits<int64_t>())},
                           {"base", ATTR_DESC(base, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(LogSpace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LogSpace, prim::kPrimLogSpace->name(), CUST_ADPT_DESC(LogSpace));

// UniqueConsecutive
INPUT_MAP(UniqueConsecutive) = {{1, INPUT_DESC(x)}};
ATTR_MAP(UniqueConsecutive) = {{"return_idx", ATTR_DESC(return_idx, AnyTraits<bool>())},
                               {"return_counts", ATTR_DESC(return_counts, AnyTraits<bool>())},
                               {"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}};
OUTPUT_MAP(UniqueConsecutive) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(idx)}, {2, OUTPUT_DESC(count)}};
REG_ADPT_DESC(UniqueConsecutive, prim::kPrimUniqueConsecutive->name(), ADPT_DESC(UniqueConsecutive));

// UpperBound
INPUT_MAP(UpperBound) = {{1, INPUT_DESC(sorted_x)}, {2, INPUT_DESC(values)}};
ATTR_MAP(UpperBound) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UpperBound) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UpperBound, prim::kPrimUpperBound->name(), ADPT_DESC(UpperBound));

// UnravelIndex
INPUT_MAP(UnravelIndex) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(dims)}};
ATTR_MAP(UnravelIndex) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnravelIndex) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UnravelIndex, prim::kPrimUnravelIndex->name(), ADPT_DESC(UnravelIndex));

// NoRepeatNGram
CUST_INPUT_MAP(NoRepeatNGram) = {{1, INPUT_DESC(state_seq)}, {2, INPUT_DESC(log_probs)}};
CUST_ATTR_MAP(NoRepeatNGram) = {{"ngram_size", ATTR_DESC(ngram_size, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(NoRepeatNGram) = {{0, OUTPUT_DESC(out)}};
REG_ADPT_DESC(NoRepeatNGram, prim::kPrimNoRepeatNGram->name(), CUST_ADPT_DESC(NoRepeatNGram));
}  // namespace mindspore::transform

@ -20,6 +20,7 @@
#include "inc/ops/array_ops.h"
#include "inc/ops/selection_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_array_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_other_ops.h"
#include "inc/ops/transformation_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"

@ -107,9 +108,9 @@ DECLARE_OP_USE_OUTPUT(QueueData)
DECLARE_OP_ADAPTER(Size)
DECLARE_OP_USE_OUTPUT(Size)

DECLARE_OP_ADAPTER(Meshgrid)
DECLARE_OP_USE_DYN_INPUT(Meshgrid)
DECLARE_OP_USE_DYN_OUTPUT(Meshgrid)
DECLARE_CUST_OP_ADAPTER(Meshgrid)
DECLARE_CUST_OP_USE_DYN_INPUT(Meshgrid)
DECLARE_CUST_OP_USE_DYN_OUTPUT(Meshgrid)

DECLARE_CUST_OP_ADAPTER(SliceGrad)
DECLARE_CUST_OP_USE_OUTPUT(SliceGrad)

@ -149,4 +150,16 @@ DECLARE_CUST_OP_USE_OUTPUT(MvlgammaGrad)

DECLARE_CUST_OP_ADAPTER(LogSpace)
DECLARE_CUST_OP_USE_OUTPUT(LogSpace)

DECLARE_OP_ADAPTER(UniqueConsecutive)
DECLARE_OP_USE_OUTPUT(UniqueConsecutive)

DECLARE_OP_ADAPTER(UpperBound)
DECLARE_OP_USE_OUTPUT(UpperBound)

DECLARE_OP_ADAPTER(UnravelIndex)
DECLARE_OP_USE_OUTPUT(UnravelIndex)

DECLARE_CUST_OP_ADAPTER(NoRepeatNGram)
DECLARE_CUST_OP_USE_OUTPUT(NoRepeatNGram)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

@ -189,4 +189,27 @@ ATTR_MAP(ExtractGlimpse) = {{"noise", ATTR_DESC(noise, AnyTraits<std::string>())
                            {"uniform_noise", ATTR_DESC(uniform_noise, AnyTraits<bool>())}};
OUTPUT_MAP(ExtractGlimpse) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ExtractGlimpse, prim::kPrimExtractGlimpse->name(), ADPT_DESC(ExtractGlimpse));

// ScaleAndTranslateGrad
INPUT_MAP(ScaleAndTranslateGrad) = {
  {1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}, {3, INPUT_DESC(scale)}, {4, INPUT_DESC(translation)}};
ATTR_MAP(ScaleAndTranslateGrad) = {{"kernel_type", ATTR_DESC(kernel_type, AnyTraits<std::string>())},
                                   {"antialias", ATTR_DESC(antialias, AnyTraits<bool>())}};
OUTPUT_MAP(ScaleAndTranslateGrad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ScaleAndTranslateGrad, prim::kPrimScaleAndTranslateGrad->name(), ADPT_DESC(ScaleAndTranslateGrad));

// ResizeBicubicGrad
INPUT_MAP(ResizeBicubicGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
ATTR_MAP(ResizeBicubicGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())},
                               {"half_pixel_centers", ATTR_DESC(half_pixel_centers, AnyTraits<bool>())}};
OUTPUT_MAP(ResizeBicubicGrad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ResizeBicubicGrad, prim::kPrimResizeBicubicGrad->name(), ADPT_DESC(ResizeBicubicGrad));

// ScaleAndTranslate
INPUT_MAP(ScaleAndTranslate) = {
  {1, INPUT_DESC(images)}, {2, INPUT_DESC(size)}, {3, INPUT_DESC(scale)}, {4, INPUT_DESC(translation)}};
ATTR_MAP(ScaleAndTranslate) = {{"kernel_type", ATTR_DESC(kernel_type, AnyTraits<std::string>())},
                               {"antialias", ATTR_DESC(antialias, AnyTraits<bool>())}};
OUTPUT_MAP(ScaleAndTranslate) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ScaleAndTranslate, prim::kPrimScaleAndTranslate->name(), ADPT_DESC(ScaleAndTranslate));
}  // namespace mindspore::transform

@ -77,4 +77,13 @@ DECLARE_OP_USE_OUTPUT(AdjustHue)

DECLARE_OP_ADAPTER(ExtractGlimpse)
DECLARE_OP_USE_OUTPUT(ExtractGlimpse)

DECLARE_OP_ADAPTER(ScaleAndTranslateGrad)
DECLARE_OP_USE_OUTPUT(ScaleAndTranslateGrad)

DECLARE_OP_ADAPTER(ResizeBicubicGrad)
DECLARE_OP_USE_OUTPUT(ResizeBicubicGrad)

DECLARE_OP_ADAPTER(ScaleAndTranslate)
DECLARE_OP_USE_OUTPUT(ScaleAndTranslate)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_

@ -100,4 +100,10 @@ CUST_INPUT_MAP(LuSolve) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(lu_data)}, {3, INP
CUST_ATTR_MAP(LuSolve) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(LuSolve) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(LuSolve, prim::kPrimLuSolve->name(), CUST_ADPT_DESC(LuSolve));

// Qr
INPUT_MAP(Qr) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Qr) = {{"full_matrices", ATTR_DESC(full_matrices, AnyTraits<bool>())}};
OUTPUT_MAP(Qr) = {{0, OUTPUT_DESC(q)}, {1, OUTPUT_DESC(r)}};
REG_ADPT_DESC(Qr, prim::kPrimQr->name(), ADPT_DESC(Qr));
}  // namespace mindspore::transform

@ -60,4 +60,7 @@ DECLARE_CUST_OP_USE_OUTPUT(LuUnpackGrad)

DECLARE_CUST_OP_ADAPTER(LuSolve)
DECLARE_CUST_OP_USE_OUTPUT(LuSolve)

DECLARE_OP_ADAPTER(Qr)
DECLARE_OP_USE_OUTPUT(Qr)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LINALG_OPS_DECLARE_H_

@ -309,4 +309,34 @@ CUST_INPUT_MAP(Gcd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
CUST_ATTR_MAP(Gcd) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Gcd) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Gcd, prim::kPrimGcd->name(), CUST_ADPT_DESC(Gcd));

// Orgqr
CUST_INPUT_MAP(Orgqr) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(tau)}};
CUST_ATTR_MAP(Orgqr) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Orgqr) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Orgqr, prim::kPrimOrgqr->name(), CUST_ADPT_DESC(Orgqr));

// RaggedRange
INPUT_MAP(RaggedRange) = {{1, INPUT_DESC(starts)}, {2, INPUT_DESC(limits)}, {3, INPUT_DESC(deltas)}};
ATTR_MAP(RaggedRange) = {{"Tsplits", ATTR_DESC(Tsplits, AnyTraits<GEType>())}};
OUTPUT_MAP(RaggedRange) = {{0, OUTPUT_DESC(rt_nested_splits)}, {1, OUTPUT_DESC(rt_dense_values)}};
REG_ADPT_DESC(RaggedRange, prim::kPrimRaggedRange->name(), ADPT_DESC(RaggedRange));

// Imag
INPUT_MAP(Imag) = {{1, INPUT_DESC(input)}};
ATTR_MAP(Imag) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Imag) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(Imag, prim::kPrimImag->name(), ADPT_DESC(Imag));

// Lgamma
CUST_INPUT_MAP(Lgamma) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Lgamma) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(Lgamma) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Lgamma, prim::kPrimLgamma->name(), CUST_ADPT_DESC(Lgamma));

// Real
INPUT_MAP(Real) = {{1, INPUT_DESC(input)}};
ATTR_MAP(Real) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Real) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(Real, prim::kPrimReal->name(), ADPT_DESC(Real));
}  // namespace mindspore::transform

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_

#include "inc/ops/math_ops.h"
#include "inc/ops/ragged_math_ops.h"
#include "inc/ops/spectral_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_math_ops.h"
#include "mindspore/ccsrc/include/common/utils/utils.h"

@ -146,4 +147,19 @@ DECLARE_CUST_OP_USE_OUTPUT(Heaviside)

DECLARE_CUST_OP_ADAPTER(Gcd)
DECLARE_CUST_OP_USE_OUTPUT(Gcd)

DECLARE_CUST_OP_ADAPTER(Orgqr)
DECLARE_CUST_OP_USE_OUTPUT(Orgqr)

DECLARE_OP_ADAPTER(RaggedRange)
DECLARE_OP_USE_OUTPUT(RaggedRange)

DECLARE_OP_ADAPTER(Imag)
DECLARE_OP_USE_OUTPUT(Imag)

DECLARE_CUST_OP_ADAPTER(Lgamma)
DECLARE_CUST_OP_USE_OUTPUT(Lgamma)

DECLARE_OP_ADAPTER(Real)
DECLARE_OP_USE_OUTPUT(Real)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_

@ -251,4 +251,16 @@ ATTR_MAP(FillDiagonal) = {{"fill_value", ATTR_DESC(fill_value, AnyTraits<float>(
                          {"wrap", ATTR_DESC(wrap, AnyTraits<bool>())}};
OUTPUT_MAP(FillDiagonal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FillDiagonal, kNameFillDiagonal, ADPT_DESC(FillDiagonal));

// Trace
INPUT_MAP(Trace) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Trace) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Trace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Trace, prim::kPrimTrace->name(), ADPT_DESC(Trace));

// TraceGrad
CUST_INPUT_MAP(TraceGrad) = {{1, INPUT_DESC(y_grad)}, {2, INPUT_DESC(x_shape)}};
CUST_ATTR_MAP(TraceGrad) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(TraceGrad) = {{0, OUTPUT_DESC(x_grad)}};
REG_ADPT_DESC(TraceGrad, prim::kPrimTraceGrad->name(), CUST_ADPT_DESC(TraceGrad));
}  // namespace mindspore::transform

@ -19,6 +19,7 @@

#include "mindspore/ccsrc/include/common/utils/utils.h"
#include "inc/ops/matrix_calculation_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_math_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"

@ -126,4 +127,10 @@ DECLARE_OP_USE_OUTPUT(Eye)

DECLARE_OP_ADAPTER(FillDiagonal)
DECLARE_OP_USE_OUTPUT(FillDiagonal)

DECLARE_OP_ADAPTER(Trace)
DECLARE_OP_USE_OUTPUT(Trace)

DECLARE_CUST_OP_ADAPTER(TraceGrad)
DECLARE_CUST_OP_USE_OUTPUT(TraceGrad)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATRIX_CALCULATION_OPS_DECLARE_H_

@ -376,4 +376,10 @@ ATTR_MAP(FractionalAvgPool) = {{"pooling_ratio", ATTR_DESC(pooling_ratio, AnyTra
|
|||
OUTPUT_MAP(FractionalAvgPool) = {
|
||||
{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(row_pooling_sequence)}, {2, OUTPUT_DESC(col_pooling_sequence)}};
|
||||
REG_ADPT_DESC(FractionalAvgPool, prim::kPrimFractionalAvgPool->name(), ADPT_DESC(FractionalAvgPool));
|
||||
|
||||
// NthElement
|
||||
INPUT_MAP(NthElement) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(n)}};
|
||||
ATTR_MAP(NthElement) = {{"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(NthElement) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(NthElement, prim::kPrimNthElement->name(), ADPT_DESC(NthElement));
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -130,4 +130,7 @@ DECLARE_OP_USE_OUTPUT(FractionalMaxPoolGrad)
|
|||
|
||||
DECLARE_OP_ADAPTER(FractionalAvgPool)
|
||||
DECLARE_OP_USE_OUTPUT(FractionalAvgPool)
|
||||
|
||||
DECLARE_OP_ADAPTER(NthElement)
|
||||
DECLARE_OP_USE_OUTPUT(NthElement)
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_
|
||||
|
|
|
@ -64,4 +64,11 @@ INPUT_MAP(PadV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}, {3, INPUT_DES
ATTR_MAP(PadV2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(PadV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadV2, kNamePadV2, ADPT_DESC(PadV2))

// PadV3Grad
INPUT_MAP(PadV3Grad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
ATTR_MAP(PadV3Grad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())},
                       {"paddings_contiguous", ATTR_DESC(paddings_contiguous, AnyTraits<bool>())}};
OUTPUT_MAP(PadV3Grad) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadV3Grad, kNamePadV3Grad, ADPT_DESC(PadV3Grad));
}  // namespace mindspore::transform

@ -44,4 +44,7 @@ DECLARE_OP_USE_OUTPUT(PadV3)

DECLARE_OP_ADAPTER(PadV2)
DECLARE_OP_USE_OUTPUT(PadV2)

DECLARE_OP_ADAPTER(PadV3Grad)
DECLARE_OP_USE_OUTPUT(PadV3Grad)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_PAD_OPS_DECLARE_H_

@ -117,4 +117,11 @@ CUST_INPUT_MAP(Dropout2D) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Dropout2D) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}};
CUST_OUTPUT_MAP(Dropout2D) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mask)}};
REG_ADPT_DESC(Dropout2D, kNameDropout2D, CUST_ADPT_DESC(Dropout2D))

// StandardLaplace
CUST_INPUT_MAP(StandardLaplace) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(seed)}, {3, INPUT_DESC(seed2)}};
CUST_ATTR_MAP(StandardLaplace) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
                                  {"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(StandardLaplace) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(StandardLaplace, prim::kPrimStandardLaplace->name(), CUST_ADPT_DESC(StandardLaplace));
}  // namespace mindspore::transform

@ -64,4 +64,7 @@ DECLARE_OP_USE_OUTPUT(NonDeterministicInts)

DECLARE_CUST_OP_ADAPTER(Dropout2D)
DECLARE_CUST_OP_USE_OUTPUT(Dropout2D)

DECLARE_CUST_OP_ADAPTER(StandardLaplace)
DECLARE_CUST_OP_USE_OUTPUT(StandardLaplace)
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_