Add regop and adapters for custom aicpu

This commit is contained in:
panzhihui 2023-07-14 16:04:45 +08:00
parent 0faf4bd9da
commit a2adaa3917
119 changed files with 8432 additions and 1214 deletions

View File

@ -183,6 +183,8 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/casting"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/namespace"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/braces"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/braces"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" ""
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "whitespace/ending_newline"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "build/include_subdir"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "runtime/references"

View File

@ -415,3 +415,12 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
mindspore/mindspore/lite/src/litert/kernel/opencl/kernel/conv2d.cc:mindspore::kernel::UseWinograd4x4To6x6
mindspore/mindspore/lite/src/litert/kernel/opencl/kernel/fullconnection.cc:mindspore::kernel::FullConnectionOpenCLKernel::CheckSpecs
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/concat_proto.cc:ge::ConcatInferShapeCommon
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/elewise_calculation_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_pooling_ops_proto.cc:ge::CUST_IMPLEMT_VERIFIER
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/concat_proto.cc:ge::ConcatInferShapeCommon
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/im2col_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC

View File

@ -36,7 +36,7 @@ input1.name=dy
input1.type=DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output0.name=z
output0.type=DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
[AdaptiveAvgPool2d]
[AdaptiveAvgPool2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
@ -807,25 +807,6 @@ output1.type=DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output1.name=values
output2.type=DT_INT32,DT_INT64
output2.name=dense_shape
[Cumprod]
opInfo.engine=DNN_VM_AICPU
opInfo.flagPartial=False
opInfo.computeCost=100
opInfo.flagAsync=False
opInfo.opKernelLib=CUSTAICPUKernel
opInfo.kernelSo=libcust_cpu_kernels.so
opInfo.functionName=RunCpuKernel
opInfo.workspaceSize=1024
opInfo.opsFlag=OPS_FLAG_CLOSE
opInfo.userDefined=True
opInfo.subTypeOfInferShape=1
opInfo.formatAgnostic=False
input0.type=DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_UINT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
input0.name=x
input1.type=DT_INT32,DT_INT64
input1.name=axis
output0.type=DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_UINT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output0.name=y
[CumulativeLogsumexp]
opInfo.engine=DNN_VM_AICPU
opInfo.flagPartial=False
@ -4076,3 +4057,709 @@ opInfo.opsFlag=OPS_FLAG_CLOSE
opInfo.userDefined=True
opInfo.subTypeOfInferShape=1
opInfo.formatAgnostic=False
[ArgMin]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=dimension
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[Diag]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
[Betainc]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=a
input0.type=DT_DOUBLE,DT_FLOAT
input1.name=b
input1.type=DT_DOUBLE,DT_FLOAT
input2.name=x
input2.type=DT_DOUBLE,DT_FLOAT
output0.name=z
output0.type=DT_DOUBLE,DT_FLOAT
[DivNoNan]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=x2
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Expm1]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[ArgMaxWithValue]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=indice
output0.type=DT_INT32
output1.name=values
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Bucketize]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input
input0.type=DT_DOUBLE,DT_FLOAT,DT_INT32,DT_INT64
output0.name=output
output0.type=DT_INT32
[DiagPart]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
[CheckNumerics]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[CumProd]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=axis
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Cos]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AdaptiveAvgPool2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AdjustSaturation]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=image
input0.type=DT_FLOAT,DT_FLOAT16
input1.name=scale
input1.type=DT_FLOAT
output0.name=y
output0.type=DT_FLOAT,DT_FLOAT16
[ACosGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=y
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=dy
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=z
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AffineGridGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=y_grad
input0.type=DT_FLOAT,DT_FLOAT16
input1.name=x_size
input1.type=DT_INT32,DT_INT64
output0.name=x_grad
output0.type=DT_FLOAT,DT_FLOAT16
[AddN]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Div]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
input1.name=y
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
output0.name=output
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
[BiasAddGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[ArgMinWithValue]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=index
output0.type=DT_INT32
output1.name=values
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[AdaptiveAvgPool3D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[AdaptiveAvgPool2DGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input_grad
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=orig_input_shape
input1.type=DT_INT64
output0.name=output_grad
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Hypot]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_DOUBLE,DT_FLOAT
input1.name=x2
input1.type=DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT
[Lcm]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_INT32,DT_INT64
input1.name=x2
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[MatrixExp]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Log1p]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[LessEqual]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=x2
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_BOOL
[Heaviside]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=values
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[MaskedSelect]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_FLOAT,DT_INT32
input1.name=mask
input1.type=DT_BOOL
output0.name=y
output0.type=DT_FLOAT,DT_INT32
[Multinomial]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=num_sample
input1.type=DT_INT32
input2.name=count
input2.type=DT_UINT64
input3.name=state
input3.type=DT_UINT64
output0.name=output
output0.type=DT_INT32,DT_INT64
[MatrixDeterminant]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT
[FillDiagonal]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input_x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[Gcd]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_INT32,DT_INT64
input1.name=x2
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[MatrixTriangularSolve]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=matrix
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
input1.name=rhs
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
[FloorDiv]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
input1.name=y
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
output0.name=output
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
[LuUnpack]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=LU_data
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input1.name=LU_pivots
input1.type=DT_INT8, DT_INT16,DT_INT32,DT_INT64,DT_UINT8
output0.name=pivots
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output1.name=L
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output1.name=U
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[LuUnpackGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=L_grad
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input1.name=U_grad
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input2.name=LU_data
input2.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=L_data_grad
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output1.name=U_data_grad
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[IsInf]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_BOOL
[Dropout2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_BOOL,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_BOOL,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output1.name=mask
output1.type=DT_BOOL

View File

@ -0,0 +1,44 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape) {
std::vector<std::string> input_infer_depends = {"orig_input_shape"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
DataType input_dtype = op.GetInputDescByName("input_grad").GetDataType();
Shape output_shape;
Tensor orig_input_shape_tensor;
if (op.GetInputConstData("orig_input_shape", orig_input_shape_tensor) != GRAPH_SUCCESS) {
auto output_desc = op.GetOutputDescByName("output_grad");
output_desc.SetDataType(input_dtype);
output_desc.SetShape(Shape(ge::UNKNOWN_RANK));
return op.UpdateOutputDesc("output_grad", output_desc);
}
MakeShapeFromShapeTensor(orig_input_shape_tensor, output_shape, op);
TensorDesc output_grad = op.GetOutputDescByName("output_grad");
output_grad.SetShape(output_shape);
output_grad.SetDataType(input_dtype);
return op.UpdateOutputDesc("output_grad", output_grad);
}
CUST_INFER_FUNC_REG(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape);
} // namespace ge

View File

@ -0,0 +1,65 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ---------------AdaptiveAvgPool2D-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape) {
OP_LOGI(TbeGetName(op).c_str(), " AdaptiveAvgPool2d inferShape begin!");
const size_t DIM_SIZE2 = 2;
auto input_tensor_desc = op.GetInputDescByName("x");
auto shape = input_tensor_desc.GetShape();
// get output_size
std::vector<int64_t> ouput_size_list;
if (GRAPH_SUCCESS != op.GetAttr("output_size", ouput_size_list)) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ouput_size_list failed!");
return GRAPH_FAILED;
}
// check output size
if (ouput_size_list.size() != DIM_SIZE2) {
OP_LOGE(TbeGetName(op).c_str(), "length of output_size must be 2");
return GRAPH_FAILED;
}
std::vector<int64_t> dims_input = shape.GetDims();
// set output shape
std::vector<int64_t> dim_vector;
for (size_t i = 0; i < dims_input.size(); i++) {
int64_t dims = dims_input[i];
dim_vector.push_back(dims);
}
size_t index0 = dims_input.size() - 2;
size_t index1 = dims_input.size() - 1;
dim_vector[index0] = ouput_size_list[0];
dim_vector[index1] = ouput_size_list[1];
TensorDesc td = op.GetOutputDescByName("y");
DataType input_dtype = input_tensor_desc.GetDataType();
Shape output_shape(dim_vector);
td.SetShape(output_shape);
td.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify);
// ---------------AdaptiveAvgPool2D End---------------
} // namespace ge

View File

@ -26,15 +26,13 @@ IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
{ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};
// verify the dim of output_size
auto output_size_desc = op.GetInputDescByName("output_size");
auto output_size_dim = output_size_desc.GetShape().GetDimNum();
ge::AscendString op_name;
(void)op.GetName(op_name);
if (output_size_dim != 1) {
OP_LOGE("AdaptiveAvgPool3d", "Num Dim of output_szie is invalid");
std::vector<int64_t> output_size;
if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
return GRAPH_PARAM_INVALID;
}
ge::AscendString op_name;
(void)op.GetName(op_name);
auto input_desc = op.GetInputDescByName("x");
TensorDesc out_desc = op.GetOutputDescByName("y");
@ -56,38 +54,20 @@ IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
auto input_size_dim_num = input_size_shape.size();
std::vector<int64_t> output_shape(input_size_dim_num);
for (uint64_t i = 0; i < input_size_dim_num - 3; ++i) {
output_shape[i] = input_size_shape[i];
}
Tensor output_size_tensor;
if (op.GetInputConstData("output_size", output_size_tensor) != GRAPH_SUCCESS) {
OP_LOGE("AdaptiveAvgPool3d", "failed to get tensor from output_size");
return GRAPH_FAILED;
}
int32_t *output_size_data = reinterpret_cast<int32_t *>(output_size_tensor.GetData());
if (output_size_data == nullptr) {
OP_LOGE("AdaptiveAvgPool3d", "output_size data is invalid");
return GRAPH_PARAM_INVALID;
}
auto output_size_num = output_size_desc.GetShape().GetShapeSize();
std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
auto output_size_num = output_size.size();
if (output_size_num == 1) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
if (output_size_data[0] < 0) {
OP_LOGE("AdaptiveAvgPool3d", "Value of output_size can\'t be negative");
return GRAPH_PARAM_INVALID;
if (output_size[0] < 0) {
continue;
}
output_shape[i] = output_size_data[0];
output_shape[i] = output_size[0];
}
} else if (output_size_num == 3) {
for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
auto data = output_size_data[i - input_size_dim_num + 3];
auto data = output_size[i - input_size_dim_num + 3];
if (data < 0) {
OP_LOGE("AdaptiveAvgPool3d", "Value of output_size can\'t be negative");
return GRAPH_PARAM_INVALID;
continue;
}
output_shape[i] = data;
}

View File

@ -26,14 +26,14 @@ CUST_IMPLEMT_INFERFUNC(AdjustContrast, AdjustContrastInfer) {
GeShape shape;
std::string err_msg;
auto contrast_factor_desc = op_desc->MutableInputDesc(1);
if (WithRank(contrast_factor_desc, 0, shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(contrast_factor_desc, 0, shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(contrast_factor_desc->GetShape().GetDims()), "scalar");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
auto images_desc = op_desc->MutableInputDesc(0);
if (WithRankAtLeast(images_desc, 3, shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtLeast(images_desc, 3, shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "inc/arg_max_op.h"
#include "custom_op_proto/cust_elewise_calculation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
@ -98,5 +98,5 @@ IMPLEMT_COMMON_INFERFUNC(ArgMaxInferShape) {
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(ArgMaxV2, ArgMaxInferShape);
CUST_COMMON_INFER_FUNC_REG(ArgMax, ArgMaxInferShape);
} // namespace ge

View File

@ -0,0 +1,108 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/elewise_calculation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
namespace ge {
IMPLEMT_COMMON_INFERFUNC(ArgMinInferShape) {
// get all input desc
const vector<string> depend_names = {"dimension"};
PREPARE_DYNAMIC_SHAPE(depend_names);
auto op_info_arg = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info_arg->MutableInputDesc("x");
auto const_desc = op_info_arg->MutableInputDesc("dimension");
auto y_desc = op_info_arg->MutableOutputDesc("y");
// get and set output dtype
ge::DataType dtype;
if (op.GetAttr("dtype", dtype) == GRAPH_SUCCESS) {
y_desc->SetDataType(dtype);
} else {
OP_LOGW(TbeGetName(op).c_str(), "get attr dtype failed.");
y_desc->SetDataType(DT_INT32);
}
// get x shape
auto x_shape = input_desc->MutableShape().GetDims();
// if x_shape == -2, set output -2
if (IsUnknownRankShape(x_shape)) {
y_desc->SetShape(GeShape(x_shape));
return GRAPH_SUCCESS;
}
// if x_shape.size() < 2, set output scalar
if (x_shape.size() < 2) {
vector<int64_t> output_shape;
y_desc->SetShape(GeShape(output_shape));
return GRAPH_SUCCESS;
}
// read dimension const value
vector<int64_t> dimension_value;
auto dimension_idx = static_cast<uint32_t>(op_info_arg->GetInputIndexByName("dimension"));
const GeTensor *dimension_tensor = OpDescUtils::GetInputConstData(op, dimension_idx);
if (dimension_tensor != nullptr) {
auto const_dtype = const_desc->GetDataType();
GetConstValue(op, dimension_tensor, const_dtype, dimension_value);
// verify dimension_value
if (dimension_value.size() != 1) {
string error_msg = ConcatString("the element size of input[dimension] should be equal to 1, but get ",
dimension_value.size(), ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), error_msg);
return GRAPH_FAILED;
}
int64_t dimension = dimension_value[0] < 0 ? dimension_value[0] + x_shape.size() : dimension_value[0];
if (dimension >= static_cast<int64_t>(x_shape.size())) {
string error_msg = ConcatString("the value of input[dimension] must be range at input shape size,",
" but get input[dimension] value ", dimension_value[0], ", input[x] shape size ",
x_shape.size(), ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), error_msg);
return GRAPH_FAILED;
}
vector<int64_t> output_shape(x_shape);
output_shape.erase(output_shape.begin() + dimension);
y_desc->SetShape(GeShape(output_shape));
// when output is dynamic will update range
if (IsUnknown(output_shape)) {
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(x_shape, input_range);
input_range.erase(input_range.begin() + dimension);
y_desc->SetShapeRange(input_range);
}
return GRAPH_SUCCESS;
}
// dimension is not const, set all output is -1, range is [1, -1]
std::vector<std::pair<int64_t, int64_t>> output_range;
vector<int64_t> output_shape;
for (size_t item = 0; item < (x_shape.size() - 1); ++item) {
output_shape.push_back(-1);
}
MakeUpShapeRange(output_shape, output_range);
y_desc->SetShape(GeShape(output_shape));
y_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(ArgMin, ArgMinInferShape);
} // namespace ge

View File

@ -21,6 +21,161 @@
#include "utils/util.h"
namespace ge {
// ----------------Expand Begin-------------------
template <typename T>
static bool ExpandCalDim(const Tensor &data, std::vector<int64_t> &vec_dim, std::vector<int64_t> &x_dims,
std::vector<std::pair<int64_t, int64_t>> &range_vector) {
int64_t len_x = x_dims.size();
int64_t len_shape = data.GetSize() / sizeof(T);
int64_t diff = abs(len_x - len_shape);
const char *op_name = "Expand";
std::string xShape = to_string(x_dims);
OP_LOGD(op_name, "Get shape of [expand's x] %s", xShape.c_str());
const T *pdata = reinterpret_cast<const T *>(data.GetData());
std::vector<int64_t> shape_dims;
for (int64_t i = 0; i < len_shape; i++) {
T dim = pdata[i];
shape_dims.push_back(dim);
}
std::string shapeVal = to_string(shape_dims);
OP_LOGD(op_name, "Get constValue val of [expand's shape] %s", shapeVal.c_str());
const bool is_shape_less = (len_shape < len_x);
for (int64_t i = 0; i < diff; i++) {
T dim = 0;
if (is_shape_less) {
dim = x_dims[i];
} else {
dim = pdata[i];
}
if (dim == -1) {
range_vector.push_back(std::make_pair(1, -1));
} else {
range_vector.push_back(std::make_pair(dim, dim));
}
vec_dim.push_back(dim);
}
int64_t upb = len_shape;
if (is_shape_less) {
upb = len_x;
}
for (int64_t i = diff; i < upb; i++) {
int64_t idx = i - diff;
T dim = 0;
if (is_shape_less) {
idx = i;
dim = pdata[i - diff];
} else {
dim = pdata[i];
}
if (dim == -1 || x_dims[idx] == -1) {
vec_dim.push_back(-1);
range_vector.push_back(std::make_pair(1, -1));
continue;
}
if (dim == 0) {
vec_dim.push_back(0);
range_vector.push_back(std::make_pair(0, 0));
continue;
}
if ((x_dims[idx] != dim) && (x_dims[idx] != 1) && (dim != 1)) {
return false;
}
if (x_dims[idx] > dim) {
vec_dim.push_back(x_dims[idx]);
range_vector.push_back(std::make_pair(x_dims[idx], x_dims[idx]));
} else {
vec_dim.push_back(dim);
range_vector.push_back(std::make_pair(dim, dim));
}
}
return true;
}
IMPLEMT_INFERFUNC(Expand, ExpandInferShape) {
const char *op_name = "Expand";
OP_LOGD(op_name, "ExpandInferShape start.");
const vector<string> const_names = {"shape"};
PREPARE_DYNAMIC_SHAPE(const_names);
TensorDesc tensordesc_input = op.GetInputDescByName("x");
Shape x_shape = tensordesc_input.GetShape();
std::vector<int64_t> x_dims = x_shape.GetDims();
DataType x_dtype = tensordesc_input.GetDataType();
Tensor data;
std::vector<int64_t> vec_dim;
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
tensordesc_output.SetDataType(x_dtype);
TensorDesc tensordesc_shape = op.GetInputDescByName("shape");
size_t dim_num = tensordesc_shape.GetShape().GetDimNum();
std::vector<int64_t> empty_dim_vec = tensordesc_shape.GetShape().GetDims();
for (size_t i = 0; i < dim_num; i++) {
if (empty_dim_vec[i] == 0) {
tensordesc_output.SetShape(ge::Shape(empty_dim_vec));
return op.UpdateOutputDesc("y", tensordesc_output);
}
}
std::vector<std::pair<int64_t, int64_t>> range_vector;
if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) {
OP_LOGD(op_name, "Get constValue failed of [shape]");
vector<int64_t> shape_dims = tensordesc_shape.GetShape().GetDims();
size_t dim_num = shape_dims.size();
if (dim_num > 1) {
OP_LOGE(op_name, "The dim numbers of shape [%zu] are more than one.", dim_num);
return GRAPH_FAILED;
}
int64_t max_len = x_dims.size();
if (shape_dims[0] > max_len) {
max_len = shape_dims[0];
}
for (int64_t item = 0; item < max_len; ++item) {
vec_dim.push_back(-1);
range_vector.push_back(std::make_pair(1, -1));
}
} else {
OP_LOGD(op_name, "Get constValue succeeded of [shape]");
vector<int64_t> shape_dims = tensordesc_shape.GetShape().GetDims();
if (shape_dims.size() > 1) {
OP_LOGE(op_name, "The dim numbers of shape [%zu] are more than one.", shape_dims.size());
return GRAPH_FAILED;
}
DataType data_type = tensordesc_shape.GetDataType();
if (data_type == DT_INT32) {
if (!ExpandCalDim<int32_t>(data, vec_dim, x_dims, range_vector)) {
OP_LOGE(op_name, "Data shape are not compatible!");
return GRAPH_FAILED;
}
} else if (data_type == DT_INT64) {
if (!ExpandCalDim<int64_t>(data, vec_dim, x_dims, range_vector)) {
OP_LOGE(op_name, "Data shape are not compatible!");
return GRAPH_FAILED;
}
} else {
OP_LOGE(op_name, "Data type not supported!");
return GRAPH_PARAM_INVALID;
}
}
tensordesc_output.SetShape(ge::Shape(vec_dim));
tensordesc_output.SetShapeRange(range_vector);
(void)op.UpdateOutputDesc("y", tensordesc_output);
OP_LOGD(op_name, "ExpandInferShape finish.");
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Expand, ExpandInferShape);
// ----------------Expand END---------------------
// ---------------SliceGrad-------------------
CUST_IMPLEMT_INFERFUNC(SliceGrad, SliceGradInfer) {
TensorDesc x_desc = op.GetInputDescByName("x");
@ -33,4 +188,278 @@ CUST_IMPLEMT_INFERFUNC(SliceGrad, SliceGradInfer) {
CUST_INFER_FUNC_REG(SliceGrad, SliceGradInfer);
// ---------------SliceGrad End---------------
// ---------------MaskedSelectGrad-------------------
CUST_IMPLEMT_INFERFUNC(MaskedSelectGrad, MaskedSelectGradInfer) {
TensorDesc x_desc = op.GetInputDescByName("x");
if (op.UpdateOutputDesc("dx", x_desc) != GRAPH_SUCCESS) {
OP_LOGE("MaskedSelectGrad", "Update output desc failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(MaskedSelectGrad, MaskedSelectGradInfer);
// ---------------MaskedSelectGrad End---------------
// -------------------------------IdentityN Begin-------------------------------
// //
IMPLEMT_INFERFUNC(IdentityN, IdentityNInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
for (size_t i = 0; i < op.GetInputsSize(); i++) {
auto input_desc = op_desc->MutableInputDesc(i);
auto input_dims = input_desc->MutableShape().GetDims();
auto output_desc = op_desc->MutableOutputDesc(i);
auto intput_dtype = input_desc->GetDataType();
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
output_desc->SetShape(GeShape(input_dims));
output_desc->SetOriginShape(GeShape(input_dims));
output_desc->SetDataType(intput_dtype);
output_desc->SetShapeRange(input_range);
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(IdentityN, IdentityNInfer);
// -------------------------------IdentityN End-------------------------------
// //
// -------------------------------LowerBound------------------------------- //
IMPLEMT_INFERFUNC(LowerBound, LowerBoundInfer) {
TensorDesc sorted_x_desc = op.GetInputDescByName("sorted_x");
TensorDesc values_desc = op.GetInputDescByName("values");
Shape unused_shape;
if (WithRank(sorted_x_desc, 2, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
ConcatString("call WithRank failed, ", GetShapeErrMsg(0, DebugString(sorted_x_desc.GetShape().GetDims()), "2D")));
return GRAPH_FAILED;
}
if (WithRank(values_desc, 2, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
ConcatString("call WithRank failed, ", GetShapeErrMsg(1, DebugString(values_desc.GetShape().GetDims()), "2D")));
return GRAPH_FAILED;
}
DataType out_type;
if (op.GetAttr("out_type", out_type) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get attr [out_type] failed.");
return GRAPH_FAILED;
}
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetDataType(out_type);
y_desc.SetShape(values_desc.GetShape());
if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update [y] desc failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(LowerBound, LowerBoundInfer);
// -------------------------------LowerBound END-------------------------------
// //
// -------------------------------ListDiff------------------------------- //
IMPLEMT_INFERFUNC(ListDiff, ListDiffInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
auto y_desc = op_desc->MutableInputDesc(1);
Shape unused_shape;
std::string error_msg;
if (WithRank(x_desc, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string error_msg = GetShapeErrMsg(0, DebugString(x_desc->GetShape().GetDims()), "1D");
error_msg = string("failed to call WithRank function, ") + error_msg;
return GRAPH_FAILED;
}
if (WithRank(y_desc, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string error_msg = GetShapeErrMsg(1, DebugString(y_desc->GetShape().GetDims()), "1D");
error_msg = string("failed to call WithRank function, ") + error_msg;
return GRAPH_FAILED;
}
DataType output_type = x_desc->GetDataType();
DataType index_type;
if (op.GetAttr("out_idx", index_type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("failed to get attr[out_idx]."));
return GRAPH_FAILED;
}
GeShape result({ge::UNKNOWN_DIM});
auto output_desc = op_desc->MutableOutputDesc(0);
output_desc->SetShape(GeShape(result));
output_desc->SetDataType(output_type);
auto index_desc = op_desc->MutableOutputDesc(1);
index_desc->SetShape(GeShape(result));
index_desc->SetDataType(index_type);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(ListDiff, ListDiffInfer);
// -------------------------------ListDiff END------------------------------- //
// ----------------HammingWindow Begin---------------------
IMPLEMT_COMMON_INFERFUNC(HammingWindowInferShape) {
std::vector<int64_t> input_dim = op.GetInputDesc(0).GetShape().GetDims();
if (input_dim.size() != 1) {
OP_LOGE(TbeGetName(op).c_str(), "Tensor length input must be 1D.");
return GRAPH_FAILED;
}
Tensor length_tensor;
int64_t length_data;
if (op.GetInputConstData("length", length_tensor) == GRAPH_SUCCESS) {
uint8_t *length = length_tensor.GetData();
length_data = static_cast<int64_t>(*length);
} else {
length_data = UNKNOWN_DIM;
}
std::vector<int64_t> output_dim;
if (length_data != UNKNOWN_DIM && length_data < 0) {
OP_LOGE(TbeGetName(op).c_str(), "Non-negative window length required, got [%ld].", length_data);
return GRAPH_FAILED;
}
if (length_data != 0) {
output_dim.push_back(length_data);
}
ge::Shape output_shape = ge::Shape(output_dim);
Operator::OpInt dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
dtype = 0;
}
DataType output_dtype = static_cast<DataType>(dtype);
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetShape(output_shape);
output_desc.SetDataType(output_dtype);
op.UpdateOutputDesc("y", output_desc);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(HammingWindow, HammingWindowInferShape);
// ----------------HammingWindow End---------------------
// ----------------Mvlgamma Begin-------------------
CUST_IMPLEMT_INFERFUNC(Mvlgamma, MvlgammaInferShape) {
const char *op_name = "Mvlgamma";
OP_LOGD(op_name, "MvlgammaInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("x");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();
TensorDesc tensordesc_output1 = op.GetOutputDescByName("y");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));
(void)op.UpdateOutputDesc("y", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaInferShape end.");
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(Mvlgamma, MvlgammaVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(Mvlgamma, MvlgammaInferShape);
CUST_VERIFY_FUNC_REG(Mvlgamma, MvlgammaVerify);
// ----------------Mvlgamma END---------------------
// ----------------MvlgammaGrad Begin-------------------
CUST_IMPLEMT_INFERFUNC(MvlgammaGrad, MvlgammaGradInferShape) {
const char *op_name = "MvlgammaGrad";
OP_LOGD(op_name, "MvlgammaGradInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("y_grad");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();
TensorDesc tensordesc_output1 = op.GetOutputDescByName("x_grad");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));
(void)op.UpdateOutputDesc("x_grad", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaGradInferShape end.");
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(MvlgammaGrad, MvlgammaGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(MvlgammaGrad, MvlgammaGradInferShape);
CUST_VERIFY_FUNC_REG(MvlgammaGrad, MvlgammaGradVerify);
// ----------------MvlgammaGrad END---------------------
// --------------------------LogSpace---------------------
static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
int64_t steps = 0;
int64_t steps_ori = 100;
if (ge::GRAPH_SUCCESS != op.GetAttr(attr_num_steps.c_str(), steps)) {
steps = steps_ori;
}
if (steps < 0) {
return false;
}
return true;
}
CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
AscendString opName;
op.GetName(opName);
if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input start size must be 1.");
return GRAPH_FAILED;
}
if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input end size must be 1.");
return GRAPH_FAILED;
}
DataType input_type_start = op.GetInputDescByName("start").GetDataType();
DataType input_type_end = op.GetInputDescByName("end").GetDataType();
if (input_type_start != input_type_end) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogSpaceInferShape) {
AscendString opName1;
op.GetName(opName1);
TensorDesc v_output_desc = op.GetOutputDescByName("y");
int64_t steps;
int64_t num_rows = 1;
op.GetAttr("steps", steps);
if (!CheckSteps(op, "steps")) {
OP_LOGE(opName1.GetString(), "the attr 'steps' should be greater than or equal to 0.");
return GRAPH_FAILED;
}
std::vector<int64_t> dim_vec;
dim_vec.push_back(num_rows);
dim_vec.push_back(steps);
v_output_desc.SetShape(ge::Shape(dim_vec));
int64_t dtype = 1;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
v_output_desc.SetDataType(DT_FLOAT16);
} else {
if (dtype == 1) {
v_output_desc.SetDataType(DT_FLOAT16);
}
if (dtype == 0) {
v_output_desc.SetDataType(DT_FLOAT);
}
}
(void)op.UpdateOutputDesc("y", v_output_desc);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------
} // namespace ge

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "inc/bartlett_window_op.h"
#include "custom_op_proto/cust_spectral_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
@ -22,14 +22,14 @@
namespace ge {
// ---------BartlettWindow-------------------
IMPLEMT_COMMON_INFERFUNC(BartlettWindowInferShape) {
int dtype;
Operator::OpType dtype;
Format input_format = op.GetInputDescByName("window_length").GetFormat();
TensorDesc out_desc = op.GetOutputDescByName("y");
Shape unused;
std::string err_msg;
// Set output shape
if (WithRankAtMost(op.GetInputDescByName("window_length"), 1, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtMost(op.GetInputDescByName("window_length"), 1, unused, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(op.GetInputDescByName("window_length").GetShape().GetDims()), "at most 1D");
err_msg = std::string("failed to call WithRankAtMost, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
@ -39,19 +39,19 @@ IMPLEMT_COMMON_INFERFUNC(BartlettWindowInferShape) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor tensor;
Shape output_shape({UNKNOWN_DIM});
if (op.GetInputConstData("window_length", tensor) == GRAPH_SUCCESS) {
auto tensor_data = reinterpret_cast<int64_t *>(tensor.GetData());
if (*tensor_data == 0) {
std::vector<int64_t> dim_vector = {};
Shape output_shape(dim_vector);
out_desc.SetShape(output_shape);
} else {
std::vector<int64_t> dim_vector = {};
dim_vector.push_back(*tensor_data);
Shape output_shape(dim_vector);
out_desc.SetShape(output_shape);
int64_t length;
if (MakeDimForScalarInput(tensor, length, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (Vector(length, output_shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(),
string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}
}
out_desc.SetShape(output_shape);
// Set output dtype
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {

View File

@ -0,0 +1,66 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(Betainc, BetaincInfer) {
const int num_inputs = 3;
Shape output(UNKNOWN_RANK);
int num_scalars = 0;
Shape some_non_scalar;
for (int i = 0; i < num_inputs; ++i) {
TensorDesc input_desc = op.GetInputDesc(i);
Shape input_shape = input_desc.GetShape();
if (!RankKnown(input_shape)) {
some_non_scalar = input_shape;
} else if (input_shape.GetDimNum() == 0) {
++num_scalars;
} else {
if (Merge(output, input_shape, output, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call Merge function to merge", i, "th input shape",
DebugString(input_shape.GetDims()), " and output[z] shape", DebugString(output.GetDims()));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
some_non_scalar = output;
}
}
if (num_scalars == num_inputs - 1) {
output = some_non_scalar;
} else if (num_scalars == num_inputs) {
TensorDesc a_desc = op.GetInputDescByName("a");
output = a_desc.GetShape();
}
DataType a_type = op.GetInputDescByName("a").GetDataType();
TensorDesc z_desc = op.GetOutputDescByName("z");
z_desc.SetShape(output);
z_desc.SetDataType(a_type);
if (op.UpdateOutputDesc("z", z_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to update output[z] desc."));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Betainc, BetaincInfer);
} // namespace ge

View File

@ -0,0 +1,63 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(Bincount, BincountInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends({"size"});
GeShape unused;
auto size_desc = op_desc->MutableInputDesc(1);
if (WithRank(size_desc, 0, unused, op) != GRAPH_SUCCESS) {
std::string err_msg = GetShapeErrMsg(1, DebugString(size_desc->GetShape().GetDims()), "scalar");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
Tensor tensor;
int64_t bins = 0;
if (op.GetInputConstData("size", tensor) != GRAPH_SUCCESS) {
bins = UNKNOWN_DIM;
}
if (bins != UNKNOWN_DIM) {
if (MakeDimForScalarInput(tensor, bins, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[size]."));
return GRAPH_FAILED;
}
}
Shape bins_shape;
if (Vector(bins, bins_shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}
auto bins_desc = op_desc->MutableOutputDesc(0);
bins_desc->SetShape(GeShape(bins_shape.GetDims()));
bins_desc->SetDataType(op_desc->MutableInputDesc(2)->GetDataType());
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Bincount, BincountInfer);
} // namespace ge

View File

@ -0,0 +1,43 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/bitwise_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// -----------------------------LeftShift-----------------------------
IMPLEMT_VERIFIER(LeftShift, LeftShiftVerify) {
if (!CheckTwoInputDtypeSame(op, "x", "y")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_INFERFUNC(LeftShift, LeftShiftInferShape) {
Shape data_shape = op.GetInputDescByName("x").GetShape();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
TensorDesc td = op.GetOutputDescByName("z");
td.SetShape(ge::Shape(data_shape));
td.SetDataType(input_dtype);
(void)op.UpdateOutputDesc("z", td);
return BROADCAST_INFER("x", "y", "z")(op);
}
INFER_FUNC_REG(LeftShift, LeftShiftInferShape);
VERIFY_FUNC_REG(LeftShift, LeftShiftVerify);
// -----------------------------LeftShift END-----------------------------
} // namespace ge

View File

@ -26,29 +26,34 @@ CUST_IMPLEMT_VERIFIER(BlackmanWindow, BlackmanWindowVerify) { return GRAPH_SUCCE
IMPLEMT_COMMON_INFERFUNC(BlackmanWindowInferShape) {
Shape shape;
Shape unused;
Operator::OpType dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get attr dtype failed.");
}
if (WithRank(op.GetInputDesc(0), 0, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(0), 0, unused, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
std::vector<std::string> input_infer_depends = {"window_length"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor window_length_tensor;
if (op.GetInputConstData("window_length", window_length_tensor) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
auto output_desc = op.GetOutputDescByName("y");
output_desc.SetDataType(dtype);
output_desc.SetShape(Shape(ge::UNKNOWN_SHAPE));
return op.UpdateOutputDesc("y", output_desc);
}
int64_t length;
if (MakeDimForScalarInput(window_length_tensor, length, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (MakeDimForScalarInput(window_length_tensor, length, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (Vector(length, shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}
Operator::OpType type;
if (op.GetAttr("dtype", type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("get attr[dtype] failed"));
return GRAPH_FAILED;
}
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetDataType(type);
output_desc.SetDataType(dtype);
output_desc.SetShape(shape);
return op.UpdateOutputDesc("y", output_desc);
}

View File

@ -0,0 +1,84 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/pad_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
template <typename T>
static void CaclDims(const Tensor &data, std::vector<int64_t> &vec_dim) {
int32_t size = data.GetSize() / sizeof(T);
for (int32_t i = 0; i < size; i++) {
T dim = *((T *)data.GetData() + i);
vec_dim.push_back(dim);
}
}
// -------------------BroadcastTo-----------------------
IMPLEMT_INFERFUNC(BroadcastTo, BroadcastToInferShape) {
const vector<string> depend_names = {"shape"};
PREPARE_DYNAMIC_SHAPE(depend_names);
Tensor data;
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) {
OP_LOGI(TbeGetName(op).c_str(), "Get constValue failed of [shape]");
auto shape_desc = op_info->MutableInputDesc("shape");
vector<int64_t> shapedims = shape_desc->MutableShape().GetDims();
size_t dim_num = shapedims.size();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
if (dim_num > 1) {
std::string err_msg = ConcatString("the rank[", dim_num, "] of input[shape] should not be more than 1");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
std::vector<int64_t> shape_vector;
std::vector<std::pair<int64_t, int64_t>> range_vector;
for (int64_t item = 0; item < shapedims[0]; ++item) {
shape_vector.push_back(-1);
range_vector.push_back(std::make_pair(1, -1));
}
auto output_desc = op_info->MutableOutputDesc("y");
output_desc->SetShape(GeShape(shape_vector));
output_desc->SetShapeRange(range_vector);
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
DataType data_type = data.GetTensorDesc().GetDataType();
std::vector<int64_t> vec_dim;
if (data_type == DT_INT32) {
CaclDims<int32_t>(data, vec_dim);
} else if (data_type == DT_INT64) {
CaclDims<int64_t>(data, vec_dim);
} else {
return GRAPH_PARAM_INVALID;
}
OP_LOGI(TbeGetName(op).c_str(), "the op infer shape and dtype");
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
auto output_desc = op_info->MutableOutputDesc("y");
output_desc->SetShape(GeShape(vec_dim));
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(BroadcastTo, BroadcastToInferShape);
// --------------------BroadcastTo END-----------------------
} // namespace ge

View File

@ -0,0 +1,56 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ---------Bucketize------------------
IMPLEMT_INFERFUNC(Bucketize, BucketizeInfer) {
std::string str_name = TbeGetName(op);
const char *opname = str_name.c_str();
OP_LOGD(opname, "Enter Bucketize inferfunction!");
// set output shape
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK(op_desc == nullptr, OP_LOGE(opname, "op desc is null."), return GRAPH_FAILED);
std::vector<int64_t> x_shape = op_desc->MutableInputDesc("x")->MutableShape().GetDims();
auto output_desc = op_desc->MutableOutputDesc(0);
output_desc->SetShape(GeShape(x_shape));
// set output dtype
DataType dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("dtype");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[dtype] failed"));
return GRAPH_FAILED;
}
if ((dtype != DT_INT32) && (dtype != DT_INT64)) {
std::string err_msg = GetInputInvalidErrMsg("dtype");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("The attr [dtype] must be one of DT_INT32 or DT_INT64"));
return GRAPH_FAILED;
}
output_desc->SetDataType(dtype);
return GRAPH_SUCCESS;
}
IMPLEMT_VERIFIER(Bucketize, BucketizeVerify) { return GRAPH_SUCCESS; }
INFER_FUNC_REG(Bucketize, BucketizeInfer);
VERIFY_FUNC_REG(Bucketize, BucketizeVerify);
// ---------Bucketize End-------------------
} // namespace ge

View File

@ -14,30 +14,32 @@
* limitations under the License.
*/
#include "inc/log_normal_reverse.h"
#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"
namespace ge {
// ----------------LogNormalReverse-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogNormalReverseInferShape) {
TensorDesc v_output_desc = op.GetOutputDescByName("y");
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();
v_output_desc.SetShape(shape_input);
v_output_desc.SetDataType(input_dtype);
v_output_desc.SetFormat(input_format);
if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(),
"Op CholeskyGrad first input x tensor make batch square matrix "
"failed.");
return GRAPH_FAILED;
}
DataType type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LogNormalReverse, LogNormalReverseInferShape);
// ----------------LogNormalReverse-------------------
INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
} // namespace ge

View File

@ -22,8 +22,7 @@ namespace ge {
// -----------------------CholeskyInverse---------------------
IMPLEMT_COMMON_INFERFUNC(CholeskyInverseInferShape) {
TensorDesc out_desc = op.GetOutputDescByName("x");
out_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
if (op.UpdateOutputDesc("x", out_desc) != GRAPH_SUCCESS) {
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;

View File

@ -0,0 +1,43 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(Cholesky, CholeskyInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op Cholesky first input x's tensor make batch square matrix failed.");
return GRAPH_FAILED;
}
DataType type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(Cholesky, CholeskyInfer);
} // namespace ge

View File

@ -0,0 +1,199 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_log.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(CombinedNonMaxSuppression, CombinedNonMaxSuppressionInfer) {
DYNAMIC_SHAPE_NOT_SUPPORTED(op);
Shape boxes;
Shape scores;
Shape max_output_size_per_class;
Shape max_total_size;
Shape unused_shape;
std::vector<std::string> input_infer_depends = {"max_total_size", "max_output_size_per_class"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
if (WithRank(op.GetInputDesc(0), 4, boxes, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(1), 3, scores, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
GetShapeErrMsg(1, DebugString(op.GetInputDesc(1).GetShape().GetDims()), "3D"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(2), 0, max_output_size_per_class, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(2, DebugString(op.GetInputDesc(2).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(3), 0, max_total_size, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(3, DebugString(op.GetInputDesc(3).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(4), 0, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(4, DebugString(op.GetInputDesc(4).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(5), 0, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(5, DebugString(op.GetInputDesc(5).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
int64_t unused = 0;
int64_t dim1 = boxes.GetDim(0);
int64_t dim2 = scores.GetDim(0);
if (Merge(dim1, dim2, unused) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
ConcatString("call Merge function failed to merge 0th dim of input[boxes]"
" and input[scores], ",
dim1, " and ", dim2));
return GRAPH_FAILED;
}
int64_t dim3 = boxes.GetDim(1);
int64_t dim4 = scores.GetDim(1);
if (Merge(dim3, dim4, unused) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
ConcatString("call Merge function failed to merge 1th dim of input[boxes]"
" and input[scores], ",
dim3, " and ", dim4));
return GRAPH_FAILED;
}
if (boxes.GetDim(3) != 4) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
ConcatString("invalid 3th dim value[", boxes.GetDim(3), "], it should be 4"));
return GRAPH_FAILED;
}
Shape boxes_shape = op.GetInputDesc(0).GetShape();
Shape scores_shape = op.GetInputDesc(1).GetShape();
if (ValueKnown(boxes_shape, 2) && ValueKnown(scores_shape, 2)) {
if (boxes_shape.GetDim(2) != 1 && boxes_shape.GetDim(2) != scores_shape.GetDim(2)) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), ConcatString("2th dim of input[boxes] and input[scores] are not equal, ", boxes_shape.GetDim(2),
" and ", scores_shape.GetDim(2)));
return GRAPH_FAILED;
}
}
Tensor maxTotalSizeTensor;
Tensor maxOutputSizePerClassTensor;
if ((op.GetInputConstData("max_total_size", maxTotalSizeTensor) != GRAPH_SUCCESS) ||
(op.GetInputConstData("max_output_size_per_class", maxOutputSizePerClassTensor) != GRAPH_SUCCESS)) {
Shape out_shape0({-1, -1, 4});
Shape out_shape1({-1, -1});
Shape out_shape2({-1, -1});
Shape out_shape3({-1});
op.GetOutputDesc(0).SetShape(out_shape0);
op.GetOutputDesc(1).SetShape(out_shape1);
op.GetOutputDesc(2).SetShape(out_shape2);
op.GetOutputDesc(3).SetShape(out_shape3);
return GRAPH_SUCCESS;
}
int64_t maxTotalSize;
if (MakeDimForScalarInput(maxTotalSizeTensor, maxTotalSize, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), ConcatString("call MakeDimForScalarInput failed to get value from input[max_total_size] tensor"));
return GRAPH_FAILED;
}
if (maxTotalSize <= 0) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), ConcatString("invalid value[", maxTotalSize, "] of input[max_total_size], should be > 0"));
return GRAPH_FAILED;
}
int64_t maxOutputSizePerClass;
if (MakeDimForScalarInput(maxOutputSizePerClassTensor, maxOutputSizePerClass, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
ConcatString("call MakeDimForScalarInput failed to get value from input[max_output_size_per_class] tensor"));
return GRAPH_FAILED;
}
int64_t output_size;
bool pad_per_class;
if (op.GetAttr("pad_per_class", pad_per_class) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("get attr[pad_per_class] failed"));
return GRAPH_FAILED;
}
if (!pad_per_class) {
output_size = maxTotalSize;
} else {
if (maxOutputSizePerClass <= 0) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op),
ConcatString("invalid value[", maxOutputSizePerClass, "] of input[max_output_size_per_class], should be > 0"));
return GRAPH_FAILED;
}
if (maxTotalSize <= maxOutputSizePerClass * scores_shape.GetDim(2)) {
output_size = maxTotalSize;
} else {
output_size = maxOutputSizePerClass * scores_shape.GetDim(2);
}
}
int64_t batch_dim = boxes.GetDim(0);
Shape shape1({batch_dim, output_size, 4});
Shape shape2({batch_dim, output_size});
Shape shape3({batch_dim, output_size});
Shape shape4({batch_dim});
TensorDesc desc1 = op.GetOutputDescByName("nmsed_boxes");
desc1.SetShape(shape1);
desc1.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_boxes", desc1) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_boxes] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc2 = op.GetOutputDescByName("nmsed_scores");
desc2.SetShape(shape2);
desc2.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_scores", desc2) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_scores] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc3 = op.GetOutputDescByName("nmsed_classes");
desc3.SetShape(shape3);
desc3.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_classes", desc3) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_classes] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc4 = op.GetOutputDescByName("valid_detections");
desc4.SetShape(shape4);
desc4.SetDataType(DT_INT32);
if (op.UpdateOutputDesc("valid_detections", desc4) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(),
std::string("update output[valid_detections] desc failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(CombinedNonMaxSuppression, CombinedNonMaxSuppressionInfer);
} // namespace ge

View File

@ -0,0 +1,418 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/split_combination_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/op_common_util.h"
#include "register/infer_axis_slice_registry.h"
#include "utils/op_const.h"
namespace ge {
static void JoinShapeRanges(vector<pair<int64_t, int64_t>> &dest_ranges,
const vector<pair<int64_t, int64_t>> &src_ranges) {
auto dest_size = dest_ranges.size();
auto src_size = src_ranges.size();
if (dest_size != src_size) {
return;
}
for (size_t i = 0; i < dest_size; i++) {
dest_ranges[i].first = std::max(dest_ranges[i].first, src_ranges[i].first);
dest_ranges[i].second = std::min(dest_ranges[i].second, src_ranges[i].second);
}
}
static vector<pair<int64_t, int64_t>> GetShapeRangesWithUnKnowConcatDim(Operator &op, int64_t num_concat) {
vector<pair<int64_t, int64_t>> input_shape_ranges;
vector<vector<pair<int64_t, int64_t>>> all_input_shape_ranges;
vector<pair<int64_t, int64_t>> output_shape_ranges;
bool has_shape_ranges = false;
for (int32_t i = 0; i < num_concat; i++) {
const auto input_desc = op.GetDynamicInputDesc("x", i);
(void)input_desc.GetShapeRange(input_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "input shape range:%s", to_string(input_shape_ranges).c_str());
if (input_shape_ranges.empty()) {
auto shape_dims = input_desc.GetShape().GetDims();
MakeUpShapeRange(shape_dims, input_shape_ranges);
} else {
has_shape_ranges = true;
}
all_input_shape_ranges.push_back(input_shape_ranges);
}
if (has_shape_ranges) {
output_shape_ranges = all_input_shape_ranges[0];
for (size_t i = 1; i < all_input_shape_ranges.size(); i++) {
if (output_shape_ranges.size() != all_input_shape_ranges[i].size()) {
continue;
}
for (size_t j = 0; j < output_shape_ranges.size(); j++) {
output_shape_ranges[j].first = std::max(output_shape_ranges[j].first, all_input_shape_ranges[i][j].first);
if (output_shape_ranges[j].second == -1 || all_input_shape_ranges[i][j].second == -1) {
output_shape_ranges[j].second = -1;
} else {
output_shape_ranges[j].second = output_shape_ranges[j].second + all_input_shape_ranges[i][j].second;
}
}
}
}
return output_shape_ranges;
}
bool JoinShapes(vector<int64_t> &dst_shape, const vector<int64_t> &src_shape, int64_t axis) {
if (dst_shape == src_shape) {
return true;
}
if (dst_shape.empty() || IsUnknownRankShape(dst_shape)) {
dst_shape = src_shape;
return true;
}
if (!IsUnknownRankShape(src_shape)) {
if (dst_shape.size() != src_shape.size()) {
return false;
}
auto shape_dims = dst_shape.size();
for (size_t i = 0; i < shape_dims; i++) {
if (dst_shape[i] == src_shape[i]) {
continue;
}
if (axis != static_cast<int64_t>(i) && dst_shape[i] != UNKNOWN_DIM && src_shape[i] != UNKNOWN_DIM) {
return false;
}
if (src_shape[i] != UNKNOWN_DIM) {
dst_shape[i] = src_shape[i];
}
}
}
return true;
}
bool ConcatInferShapeCommonStatic(Operator &op, const int64_t dynamic_input_start_idx, int64_t num_concat,
int64_t axis) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc(dynamic_input_start_idx);
auto output_desc = op_info->MutableOutputDesc(0);
const GeShape &input_shape = input_desc->MutableShape();
GeShape &output_shape = output_desc->MutableShape();
output_shape = input_shape;
if (output_shape.IsUnknownShape() || num_concat == 1) {
// dynamic case or the input only one will use dynamic infer func
return false;
}
if (output_shape.IsScalar()) {
// scalar to shape [1]
output_shape.SetDimNum(1);
output_shape.SetDim(0, 1);
}
size_t output_dim = output_shape.GetDimNum();
if ((axis < -static_cast<int64_t>(output_dim)) || (axis >= static_cast<int64_t>(output_dim))) {
// axes is valid
return false;
}
if (axis < 0) {
axis += output_dim;
}
int64_t concat_dim_size = output_shape.GetDim(axis);
for (int64_t input_idx = 1; input_idx < num_concat; input_idx++) {
auto input_i_desc = op_info->MutableInputDesc(input_idx + dynamic_input_start_idx);
const GeShape &input_i_shape = input_i_desc->MutableShape();
if (input_i_shape.IsScalar() && output_dim == 1) {
concat_dim_size += 1;
continue;
}
if (input_i_shape.IsUnknownShape()) {
// dynamic case
return false;
}
if (input_i_shape.GetDimNum() != output_dim) {
// input shape size is not equal output
return false;
}
// check whether the non concat dim is equal
for (int64_t check_dim = 0; check_dim < static_cast<int64_t>(output_dim); check_dim++) {
if (check_dim != axis && input_i_shape.GetDim(check_dim) != output_shape.GetDim(check_dim)) {
return false;
}
}
concat_dim_size += input_i_shape.GetDim(axis);
}
output_shape.SetDim(axis, concat_dim_size);
// set data type
output_desc->SetDataType(input_desc->GetDataType());
return true;
}
static graphStatus ConcatInferShapeCommon(Operator &op, const int64_t dy_input_start_idx, int64_t num_concat,
int64_t axis, bool unknown_axis) {
if (num_concat <= 0) {
std::string err_msg = GetAttrValueErrMsg("num_concat", std::to_string(num_concat), ConcatString("num_concat > 0"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// try static infershape directly
if (!unknown_axis) {
if (ConcatInferShapeCommonStatic(op, dy_input_start_idx, num_concat, axis)) {
return GRAPH_SUCCESS;
}
}
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
size_t dim_num = 0;
std::vector<GeTensorDescPtr> input_x_desc;
const string input_name = "x";
string input_name_i = "x63";
for (int64_t input_idx = 0; input_idx < num_concat; input_idx++) {
input_name_i = input_name + std::to_string(input_idx);
auto input_desc = op_info->MutableInputDesc(input_name_i);
if (!input_desc) {
std::string err_msg = GetInputInvalidErrMsg(input_name_i.c_str());
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
input_x_desc.emplace_back(input_desc);
}
bool all_unknown_rank_shape = true;
for (const auto &desc : input_x_desc) {
dim_num = std::max(dim_num, desc->MutableShape().GetDimNum());
all_unknown_rank_shape = IsUnknownRankShape(desc->MutableShape()) && all_unknown_rank_shape;
}
if (all_unknown_rank_shape) {
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
output_desc->SetShape(ge::GeShape(UNKNOWN_RANK));
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());
return GRAPH_SUCCESS;
}
if (unknown_axis) {
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
vector<int64_t> dimVector(dim_num, -1);
output_desc->SetShape(ge::GeShape(dimVector));
auto output_shape_ranges = GetShapeRangesWithUnKnowConcatDim(op, num_concat);
if (!output_shape_ranges.empty()) {
output_desc->SetShapeRange(output_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "output shape range:%s", to_string(output_shape_ranges).c_str());
}
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());
return GRAPH_SUCCESS;
}
if ((axis < -static_cast<int64_t>(dim_num)) || (axis >= static_cast<int64_t>(dim_num))) {
OP_LOGE(TbeGetName(op).c_str(), "the parameter [axis] should be in the range of [%ld, %ld], but actually is %ld",
-static_cast<int64_t>(dim_num), static_cast<int64_t>(dim_num), axis);
return GRAPH_FAILED;
}
int64_t non_negative_axis = axis;
if (non_negative_axis < 0) {
non_negative_axis += dim_num;
}
vector<int64_t> output_shape_dims;
for (const auto &desc : input_x_desc) {
auto input_shape_dims = desc->MutableShape().GetDims();
if (!JoinShapes(output_shape_dims, input_shape_dims, non_negative_axis)) {
vector<vector<int64_t>> shapes = {output_shape_dims, input_shape_dims};
std::string err_msg =
OtherErrMsg(ConcatString("the input shape dims should be equal except merge axis,"
"shapes:",
ops::to_string(shapes), "axis:", std::to_string(axis)));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
}
int32_t size = 0;
for (const auto &desc : input_x_desc) {
if (IsUnknownRankShape(desc->MutableShape())) {
size = -1;
break;
}
auto dim_value = desc->MutableShape().GetDim(non_negative_axis);
if (dim_value == -1) {
size = -1;
break;
}
if (size != -1) {
size += dim_value;
}
}
output_shape_dims[non_negative_axis] = size;
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
output_desc->SetShape(ge::GeShape(output_shape_dims));
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());
if (IsUnKnownShape(output_shape_dims)) {
vector<pair<int64_t, int64_t>> input_shape_ranges;
vector<pair<int64_t, int64_t>> output_shape_ranges;
pair<int64_t, int64_t> output_concat_dim_range(0, 0);
for (const auto &input_desc : input_x_desc) {
if (IsUnknownRankShape(input_desc->MutableShape())) {
output_concat_dim_range = {0, -1};
continue;
}
input_shape_ranges.clear();
input_desc->GetShapeRange(input_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "input shape range:%s", to_string(input_shape_ranges).c_str());
if (input_shape_ranges.empty()) {
MakeUpShapeRange(input_desc->MutableShape(), input_shape_ranges);
}
if (static_cast<int64_t>(input_shape_ranges.size()) > non_negative_axis) {
output_concat_dim_range.first += input_shape_ranges[non_negative_axis].first;
if (input_shape_ranges[non_negative_axis].second == -1 || output_concat_dim_range.second == -1) {
output_concat_dim_range.second = -1;
} else {
output_concat_dim_range.second += input_shape_ranges[non_negative_axis].second;
}
}
if (output_shape_ranges.empty()) {
output_shape_ranges = input_shape_ranges;
} else {
JoinShapeRanges(output_shape_ranges, input_shape_ranges);
}
}
if (output_concat_dim_range.second != 0 && static_cast<uint64_t>(non_negative_axis) < output_shape_ranges.size()) {
output_shape_ranges[non_negative_axis] = output_concat_dim_range;
}
output_desc->SetShapeRange(output_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "output shape range:%s", to_string(output_shape_ranges).c_str());
}
return GRAPH_SUCCESS;
}
static graphStatus ConcatInputsVerify(const Operator &op) {
std::vector<std::string> inputs;
const string input_name = "x";
string input_name_i = "x63";
int64_t N;
if (op.GetAttr("N", N) == GRAPH_FAILED) {
OP_LOGE(TbeGetName(op), "get attr N failed");
return GRAPH_FAILED;
}
for (int64_t input_idx = 0; input_idx < N; input_idx++) {
input_name_i = input_name + std::to_string(input_idx);
inputs.emplace_back(input_name_i);
}
if (!CheckInputDtypeSame(op, inputs)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// ----------------Concat OP Begin-------------------
IMPLEMT_VERIFIER(Concat, ConcatVerify) { return ConcatInputsVerify(op); }
IMPLEMT_COMMON_INFERFUNC(ConcatInferShape) {
const vector<string> depend_names = {"concat_dim"};
PREPARE_DYNAMIC_SHAPE(depend_names);
int64_t N;
if (op.GetAttr("N", N) == GRAPH_FAILED) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[N] failed"));
return GRAPH_FAILED;
}
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto shape_idx = static_cast<uint32_t>(op_desc->GetInputIndexByName("concat_dim"));
const GeTensor *data = OpDescUtils::GetInputConstData(op, shape_idx);
bool is_unknown_axis = data == nullptr;
OP_LOGD(TbeGetName(op), "concat_dim is unknown[%s].", is_unknown_axis ? "true" : "false");
int64_t axis = 0;
if (!is_unknown_axis) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
DataType dtype = op_info->MutableInputDesc(0)->GetDataType();
std::vector<int64_t> const_vec;
if (!GetConstValue(op, data, dtype, const_vec)) {
is_unknown_axis = true;
OP_LOGW(TbeGetName(op), "Get concat_dim value failed.");
} else {
axis = const_vec[0];
}
}
return ConcatInferShapeCommon(op, 1, N, axis, is_unknown_axis);
}
IMPLEMT_COMMON_INFER_AXIS_TYPE_INFO(InferAxisType4Concat) {
OP_LOGD(TbeGetName(op), "Infer axis type for %s begin.", TbeGetOpType(op).c_str());
std::vector<int64_t> allowed_split_inputs = {};
std::vector<int64_t> allowed_split_outputs = {0};
std::vector<int64_t> excepted_axes = {};
bool concat_dim_done = true;
static const int64_t concat_dim_input_index = 0;
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK(op_desc == nullptr,
INFER_AXIS_TYPE_ERR_REPORT(TbeGetName(op), "Failed to get desc from operator. Please check the node info."),
return GRAPH_FAILED);
vector<string> depends = op_desc->GetOpInferDepends();
if (find(depends.begin(), depends.end(), "concat_dim") == depends.end()) {
depends.emplace_back("concat_dim");
op_desc->SetOpInferDepends(depends);
}
if (!(ops::GetConstIntData(op, concat_dim_input_index, excepted_axes))) {
concat_dim_done = false;
OP_LOGD(TbeGetName(op), "Get const concat_dim value failed. Can not infer axis type.");
}
CHECK(!concat_dim_done, OP_LOGD(TbeGetName(op), "Concat dim is not const node. Can not infer axis type."),
return GRAPH_SUCCESS);
int64_t attr_n;
CHECK(op.GetAttr("N", attr_n) == GRAPH_FAILED, OP_LOGD(TbeGetName(op), "Get attr N failed. Can not infer axis type."),
return GRAPH_SUCCESS);
for (auto i = 0; i < attr_n; ++i) {
allowed_split_inputs.emplace_back(i + 1);
}
return InferElementwiseAxisTypeHelper(op, axis_type, allowed_split_inputs, allowed_split_outputs, excepted_axes);
}
COMMON_INFER_FUNC_REG(Concat, ConcatInferShape);
VERIFY_FUNC_REG(Concat, ConcatVerify);
INFER_VALUE_RANGE_DEFAULT_REG(Concat);
INFER_AXIS_TYPE_INFO_REG(Concat, InferAxisType4Concat);
// ----------------Concat OP End-------------------
} // namespace ge

View File

@ -25,41 +25,41 @@ CUST_IMPLEMT_INFERFUNC(CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTen
GeShape x_dense_shape_shape;
auto x_dense_shape_desc = op_desc->MutableInputDesc(0);
if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_dense_shape must be 1-D.");
return GRAPH_FAILED;
}
GeShape x_batch_pointers_shape;
auto x_batch_pointers_desc = op_desc->MutableInputDesc(1);
if (WithRank(x_batch_pointers_desc, 1, x_batch_pointers_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_batch_pointers_desc, 1, x_batch_pointers_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_batch_pointers must be 1-D.");
return GRAPH_FAILED;
}
GeShape x_row_pointers_shape;
auto x_row_pointers_desc = op_desc->MutableInputDesc(2);
if (WithRank(x_row_pointers_desc, 1, x_row_pointers_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_row_pointers_desc, 1, x_row_pointers_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_row_pointers must be 1-D.");
return GRAPH_FAILED;
}
GeShape x_col_indices_shape;
auto x_col_indices_desc = op_desc->MutableInputDesc(3);
if (WithRank(x_col_indices_desc, 1, x_col_indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_col_indices_desc, 1, x_col_indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_col_indices must be 1-D.");
return GRAPH_FAILED;
}
GeShape x_values_shape;
auto x_values_desc = op_desc->MutableInputDesc(4);
if (WithRank(x_values_desc, 1, x_values_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_values_desc, 1, x_values_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_values must be 1-D.");
return GRAPH_FAILED;
}
GeShape unused;
if (Merge(x_col_indices_shape, x_values_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (Merge(x_col_indices_shape, x_values_shape, unused, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}

View File

@ -0,0 +1,30 @@
/**
* Copyright (c) 2022-2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/selection_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------Cumprod-------------------
IMPLEMT_COMMON_INFERFUNC(CumprodInferShape) {
TensorDesc desc = op.GetInputDescByName("x");
return op.UpdateOutputDesc("y", desc);
}
COMMON_INFER_FUNC_REG(Cumprod, CumprodInferShape);
// ----------------Cumprod END-------------------
} // namespace ge

View File

@ -0,0 +1,402 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/elewise_calculation_ops.h"
#include "custom_op_proto/cust_math_ops.h"
#include "register/infer_axis_slice_registry.h"
#include <string>
#include <vector>
#include "utils/op_attr.h"
#include "utils/op_log.h"
#include "utils/op_const.h"
#include "utils/util.h"
#include "utils/error_util.h"
#include "utils/reduce_infer_util.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/node_utils_ex.h"
#include "graph/utils/op_desc_utils.h"
#include "register/infer_data_slice_registry.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/axis_type_info.h"
namespace ge {
IMPLEMT_COMMON_INFERFUNC(TwoInOneOutCommonInferShape) {
bool is_dynamic_output = true;
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, 0, 1, 0, is_dynamic_output)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(OneInOneOutCommonInferShape) {
static const int64_t input_x_idx = 0;
static const int64_t output_y_idx = 0;
if (OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
// ----------------------------------OneInOneOutCommonInfer-----------------------------
COMMON_INFER_FUNC_REG(CheckNumerics, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Conj, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Cos, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Expm1, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log1p, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log, OneInOneOutCommonInferShape);
// ----------------------------------OneInOneOutCommonInfer END-----------------------------
// ----------------------------------TowInOneOutCommonInfer-----------------------------
COMMON_INFER_FUNC_REG(Div, TwoInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(DivNoNan, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Gcd, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Heaviside, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Hypot, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Lcm, TwoInOneOutCommonInferShape);
// ----------------------------------TowInOneOutCommonInfer END-----------------------------
// --------------AcosGrad----------------
IMPLEMT_VERIFIER(AcosGrad, AcosGradVerify) {
if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AcosGrad, AcosGradVerify);
COMMON_INFER_FUNC_REG(AcosGrad, TwoInOneOutCommonInferShape);
INFER_AXIS_TYPE_INFO_REG(AcosGrad, InferAxisType4ElementwiseOp);
// ------------AcosGrad END----------------
// ----------------AcoshGrad-------------------
IMPLEMT_VERIFIER(AcoshGrad, AcoshGradVerify) {
if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AcoshGrad, AcoshGradVerify);
IMPLEMT_COMMON_INFERFUNC(AcoshGradInferShape) {
bool is_dynamic_output = true;
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(AcoshGrad, AcoshGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AcoshGrad, InferAxisType4ElementwiseOp);
// --------------AcoshGrad END-----------------
// ----------------AsinGrad---------------
IMPLEMT_VERIFIER(AsinGrad, AsinGradVerify) {
if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AsinGrad, AsinGradVerify);
IMPLEMT_COMMON_INFERFUNC(AsinGradInferShape) {
bool is_dynamic_output = true;
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(AsinGrad, AsinGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AsinGrad, InferAxisType4ElementwiseOp);
// --------------AsinGrad END-------------
// ----------------AsinhGrad-------------------
IMPLEMT_VERIFIER(AsinhGrad, AsinhGradVerify) {
if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AsinhGrad, AsinhGradVerify);
IMPLEMT_COMMON_INFERFUNC(AsinhGradInferShape) {
bool is_dynamic_output = true;
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(AsinhGrad, AsinhGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AsinhGrad, InferAxisType4ElementwiseOp);
// --------------AsinhGrad END-----------------
// ----------------AddN-------------------
int64_t GetAddNConstValue(const ge::Operator &op) {
int64_t tensor_num;
if (ge::GRAPH_SUCCESS != op.GetAttr("N", tensor_num)) {
std::string err_msg = GetInputInvalidErrMsg("N");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
}
return tensor_num;
}
int64_t AddNInferClassify(ge::Operator &op, int64_t &tensor_num) {
const int64_t infer_condition_one_one = 11;
const int64_t infer_condition_one_two = 12;
const int64_t infer_condition_two = 2;
const int64_t infer_condition_three = 3;
int64_t empty_num = 0;
int64_t static_num = 0;
int64_t dynamic_shape_num = 0;
int64_t dynamic_dim_num = 0;
for (int64_t i = 0; i < tensor_num; i++) {
vector<int64_t> tempVector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
if (tempVector.empty()) {
empty_num++;
} else if (std::find(tempVector.begin(), tempVector.end(), ge::UNKNOWN_DIM) != tempVector.end()) {
dynamic_shape_num++;
} else if (std::find(tempVector.begin(), tempVector.end(), ge::UNKNOWN_DIM_NUM) != tempVector.end()) {
dynamic_dim_num++;
} else {
static_num++;
}
}
if (tensor_num == empty_num + dynamic_dim_num) {
if (tensor_num == empty_num) {
return infer_condition_one_one;
} else {
return infer_condition_one_two;
}
} else if (tensor_num == static_num || tensor_num == empty_num + static_num ||
tensor_num == static_num + dynamic_dim_num || tensor_num == empty_num + static_num + dynamic_dim_num) {
return infer_condition_two;
} else {
return infer_condition_three;
}
}
IMPLEMT_COMMON_INFERFUNC(AddNInferShape) {
/*
add_n has four type inputs:
1.empty 2.static shape 3.-1 4.-2
The combinations bring 15 scenes, and the 15 scenes can be classify into 4 categories:
1.input with no range and output no need range, and it can be divided half:
1.1 all input is empty
1.2 input only contains empty and -2 shape
2.input contains static shape and with no -1 shape
3.input contains -1 shape
*/
int64_t tensor_num = GetAddNConstValue(op);
int64_t infer_classify = AddNInferClassify(op, tensor_num);
// condition 1: all input shape is empty
if (infer_classify == 11) {
std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(shape_vector));
y_desc.SetDataType(x_dtype);
(void)op.UpdateOutputDesc("y", y_desc);
// condition 2: all input is -2 or only empty and -2
} else if (infer_classify == 12) {
std::vector<int64_t> shape_vector = {-2};
DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(shape_vector));
y_desc.SetDataType(x_dtype);
(void)op.UpdateOutputDesc("y", y_desc);
// condition 3: contains static shape and no -1 shape
} else if (infer_classify == 2) {
DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
for (int64_t i = 0; i < tensor_num; i++) {
std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
if (!shape_vector.empty() && !IsUnknownRankShape(shape_vector)) {
shape_vector = temp_vector;
break;
}
}
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(ge::Shape(shape_vector));
y_desc.SetDataType(x_dtype);
std::vector<std::pair<int64_t, int64_t>> out_range;
MakeUpShapeRange(shape_vector, out_range);
y_desc.SetShapeRange(out_range);
(void)op.UpdateOutputDesc("y", y_desc);
// condition 4: contains -1 shape, range need to choose the intersection
} else {
Shape out_shape = op.GetDynamicInputDesc("x", 0).GetShape();
DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
std::vector<int64_t> out_vector;
std::vector<std::pair<int64_t, int64_t>> out_range;
// Init the output shape and range
for (int64_t i = 0; i < tensor_num; i++) {
std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) {
out_vector = temp_vector;
op.GetDynamicInputDesc("x", i).GetShapeRange(out_range);
MakeUpShapeRange(out_vector, out_range);
break;
}
}
// compute the shape dims and range intersection
for (int64_t i = 0; i < tensor_num; i++) {
std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
if (temp_vector.empty() || IsUnknownRankShape(temp_vector)) {
continue;
}
std::vector<std::pair<int64_t, int64_t>> temp_range;
op.GetDynamicInputDesc("x", i).GetShapeRange(temp_range);
MakeUpShapeRange(temp_vector, temp_range);
for (size_t j = 0; j < temp_vector.size(); j++) {
// two condition: const == const; const > -1
if (temp_vector[j] >= out_vector[j]) {
out_vector[j] = temp_vector[j];
// update range: left choose the max value
if (temp_range[j].first >= out_range[j].first) {
out_range[j].first = temp_range[j].first;
}
// update range: right choose the miner value but when it was > 0
if ((temp_range[j].second <= out_range[j].second && temp_range[j].second > 0) ||
(out_range[j].second == -1 && temp_range[j].second != -1)) {
out_range[j].second = temp_range[j].second;
}
}
}
}
TensorDesc y_desc = op.GetOutputDescByName("y");
out_shape = Shape(out_vector);
y_desc.SetShape(out_shape);
y_desc.SetDataType(x_dtype);
y_desc.SetShapeRange(out_range);
(void)op.UpdateOutputDesc("y", y_desc);
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(AddN, AddNInferShape);
INFER_AXIS_TYPE_INFO_REG(AddN, InferAxisType4ElementwiseOp);
// ----------------AddN END-------------------
// --------------------------------BiasAdd-------------------------------------
IMPLEMT_VERIFIER(BiasAdd, BiasAddVerify) {
std::string data_format;
if (op.GetAttr("data_format", data_format) == GRAPH_FAILED) {
std::string err_msg = GetInputInvalidErrMsg("data_format");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
}
if (data_format != "NHWC" && data_format != "NCHW" && data_format != "NDHWC" && data_format != "NCDHW") {
string expected_format_list = ConcatString("NHWC, NCHW, NDHWC, NCDHW");
std::string err_msg = GetInputFormatNotSupportErrMsg(TbeGetName(op).c_str(), expected_format_list, data_format);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(BiasAddInferShape) {
const int64_t input_x_idx = 0;
const int64_t output_y_idx = 0;
if (!OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(BiasAdd, BiasAddInferShape);
VERIFY_FUNC_REG(BiasAdd, BiasAddVerify);
INFER_AXIS_TYPE_INFO_REG(BiasAdd, InferAxisType4BroadcastOp);
// ----------------------------------BiasAdd END-----------------------------
// --------------MulNoNan--------------
IMPLEMT_VERIFIER(MulNoNan, MulNoNanVerify) {
DataType input_type_x1 = op.GetInputDescByName("x1").GetDataType();
DataType input_type_x2 = op.GetInputDescByName("x2").GetDataType();
if (input_type_x1 != input_type_x2) {
string err_msg1 =
ConcatString("the dtype of input_type_x1 and input_type_x2 must be same! input_type_x1:", input_type_x1,
", input_type_x2:", input_type_x2);
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(MulNoNan, MulNoNanVerify);
IMPLEMT_COMMON_INFERFUNC(MulNoNanInferShape) {
bool is_dynamic_output = true;
if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(MulNoNan, MulNoNanInferShape);
// ------------MulNoNan END--------------
// -------------------LessEqual---------------------
IMPLEMT_VERIFIER(LessEqual, LessEqualVerify) {
if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(LessEqualInferShape) {
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
return GRAPH_FAILED;
}
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims();
if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) {
if (!InferShapeRangeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
return GRAPH_FAILED;
}
}
op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(LessEqual, LessEqualInferShape);
VERIFY_FUNC_REG(LessEqual, LessEqualVerify);
// --------------------LessEqual END-----------------------
// --------------------Mul-----------------------
IMPLEMT_VERIFIER(Mul, MulVerify) {
if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Mul, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(Mul, MulVerify);
// --------------------Mul END-----------------------
// -------------------FloorDiv-----------------------
IMPLEMT_VERIFIER(FloorDiv, FloorDivVerify) {
if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(FloorDiv, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(FloorDiv, FloorDivVerify);
// ----------------FloorDiv END------------------------
} // namespace ge

View File

@ -0,0 +1,51 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/nn_pooling_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(FractionalAvgPoolGrad, FractionalAvgPoolGradInfer) {
Tensor tensor;
if (op.GetInputConstData("orig_input_tensor_shape", tensor) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
ConcatString("get const data from input[orig_input_tensor_shape] failed"));
return GRAPH_FAILED;
}
Shape result;
if (MakeShapeFromShapeTensor(tensor, result, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
ConcatString("call MakeShapeFromShapeTensor function failed to make",
" shape from input[orig_input_tensor_shape] data"));
return GRAPH_FAILED;
}
DataType y_type = op.GetInputDescByName("out_backprop").GetDataType();
TensorDesc out_desc = op.GetOutputDescByName("y");
out_desc.SetShape(Shape(result));
out_desc.SetDataType(y_type);
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed."));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(FractionalAvgPoolGrad, FractionalAvgPoolGradInfer);
} // namespace ge

View File

@ -0,0 +1,83 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/nn_pooling_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(FractionalAvgPool, FractionalAvgPoolInfer) {
Shape input;
if (WithRank(op.GetInputDesc(0), 4, input, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), ConcatString("Call WithRank function failed, ",
GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D")));
return GRAPH_FAILED;
}
std::vector<float> pooling_ratio;
op.GetAttr("pooling_ratio", pooling_ratio);
if (pooling_ratio.size() != 4) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
GetAttrValueErrMsg("pooling_ratio", ConcatString(pooling_ratio.size()), "4"));
return GRAPH_PARAM_INVALID;
}
auto x_dims = op.GetInputDesc(0).GetShape().GetDims();
std::vector<int64_t> dims;
dims.reserve(4);
for (int i = 0; i < 4; ++i) {
int64_t val = ge::UNKNOWN_DIM;
if (x_dims[i] != ge::UNKNOWN_DIM) {
val = static_cast<int64_t>(x_dims[i] / pooling_ratio[i]);
if (val < 0) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), ConcatString("size computed for ", i, "th dim is ", val, ", should be >= 0"));
return GRAPH_PARAM_INVALID;
}
}
OP_LOGI(TbeGetName(op).c_str(), "i = %d, x_dims[i] = %ld, pooling_ratio[i] = %f, val = %ld", i, x_dims[i],
pooling_ratio[i], val);
dims.push_back(val);
}
Shape out(dims);
Shape row_pooling_sequence;
(void)Vector(dims[1] + 1, row_pooling_sequence);
Shape col_pooling_sequence;
(void)Vector(dims[2] + 1, col_pooling_sequence);
DataType type = op.GetInputDescByName("x").GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(out);
y_desc.SetDataType(type);
op.UpdateOutputDesc("y", y_desc);
TensorDesc row_desc = op.GetOutputDescByName("row_pooling_sequence");
row_desc.SetShape(row_pooling_sequence);
row_desc.SetDataType(DT_INT64);
op.UpdateOutputDesc("row_pooling_sequence", row_desc);
TensorDesc col_desc = op.GetOutputDescByName("col_pooling_sequence");
col_desc.SetShape(col_pooling_sequence);
col_desc.SetDataType(DT_INT64);
op.UpdateOutputDesc("col_pooling_sequence", col_desc);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(FractionalAvgPool, FractionalAvgPoolInfer);
} // namespace ge

View File

@ -1,73 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/geqrf_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ---------------Geqrf-------------------
CUST_IMPLEMT_INFERFUNC(Geqrf, GeqrfInfer) {
auto tensor = op.get_input_desc_x();
Shape input;
if (WithRank(tensor, 2, input, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
int dim_num = input.GetDimNum();
int m = input.GetDim(dim_num - 2);
int n = input.GetDim(dim_num - 1);
Shape r_shape;
Shape tau_shape;
int p = m > n ? n : m;
Matrix(m, n, r_shape);
Vector(p, tau_shape);
DataType type = op.GetInputDescByName("x").GetDataType();
TensorDesc r_desc = op.GetOutputDescByName("r");
r_desc.SetShape(Shape(r_shape));
r_desc.SetDataType(type);
if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
return GRAPH_FAILED;
}
TensorDesc tau_desc = op.GetOutputDescByName("tau");
tau_desc.SetShape(Shape(tau_shape));
tau_desc.SetDataType(type);
if (op.UpdateOutputDesc("tau", tau_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update tau desc failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(Geqrf, GeqrfInfer);
CUST_IMPLEMT_VERIFIER(Geqrf, GeqrfVerify) {
DataType type = op.GetInputDescByName("x").GetDataType();
if (type != DT_FLOAT16 && type != DT_FLOAT && type != DT_DOUBLE && type != DT_COMPLEX64 && type != DT_COMPLEX128) {
OP_LOGE(TbeGetName(op).c_str(), "Expect a floating point or complex tensor as input.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_VERIFY_FUNC_REG(Geqrf, GeqrfVerify);
// ---------------Geqrf End---------------
} // namespace ge

View File

@ -1,64 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/hamming_window_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ----------------HammingWindow Begin---------------------
IMPLEMT_COMMON_INFERFUNC(HammingWindowInferShape) {
std::vector<int64_t> input_dim = op.GetInputDesc(0).GetShape().GetDims();
if (input_dim.size() != 1) {
OP_LOGE(TbeGetName(op).c_str(), "Tensor length input must be 1D.");
return GRAPH_FAILED;
}
Tensor length_tensor;
int64_t length_data;
if (op.GetInputConstData("length", length_tensor) == GRAPH_SUCCESS) {
uint8_t *length = length_tensor.GetData();
length_data = static_cast<int64_t>(*length);
} else {
length_data = UNKNOWN_DIM;
}
std::vector<int64_t> output_dim;
if (length_data != UNKNOWN_DIM && length_data < 0) {
OP_LOGE(TbeGetName(op).c_str(), "Non-negative window length required, got [%ld].", length_data);
return GRAPH_FAILED;
}
if (length_data != 0) {
output_dim.push_back(length_data);
}
ge::Shape output_shape = ge::Shape(output_dim);
Operator::OpInt dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
dtype = 0;
}
DataType output_dtype = static_cast<DataType>(dtype);
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetShape(output_shape);
output_desc.SetDataType(output_dtype);
op.UpdateOutputDesc("y", output_desc);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(HammingWindow, HammingWindowInferShape);
// ----------------HammingWindow End---------------------
} // namespace ge

View File

@ -0,0 +1,202 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "graph/common_error_codes.h"
namespace ge {
const std::string ATTR_NAME_DATA_SLICE = "_data_slice";
static bool CheckListEmptyAndValue(const std::string &op_name, const std::vector<int64_t> &list,
const std::string &attr_name) {
if (list.size() < 1) {
OP_LOGE(op_name.c_str(), "The %s dose not have enough elements(%lu)!", attr_name.c_str(), list.size());
return false;
}
return true;
}
static std::vector<int64_t> GetAttrValue(const Operator &op, const std::string &key_name) {
std::vector<int64_t> list;
if (op.GetAttr(key_name.c_str(), list) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ConstValue failed!");
}
return list;
}
// -----------------Im2col Op-------------------------
IMPLEMT_VERIFIER(Im2col, Im2colVerify) {
std::vector<int64_t> ksize;
ksize = GetAttrValue(op, "ksizes");
if (ksize.size() < 2) {
OP_LOGE(TbeGetName(op).c_str(), "The ksizes dose not have enough elements(%lu)!", ksize.size());
return GRAPH_FAILED;
}
std::vector<int64_t> stride;
stride = GetAttrValue(op, "strides");
if (!CheckListEmptyAndValue(TbeGetName(op), stride, "strides")) {
return GRAPH_FAILED;
}
std::vector<int64_t> dilation;
dilation = GetAttrValue(op, "dilations");
if (!CheckListEmptyAndValue(TbeGetName(op), dilation, "dilations")) {
return GRAPH_FAILED;
}
std::string padding_mode;
if (op.GetAttr("padding_mode", padding_mode) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get padding_mode failed!");
return GRAPH_FAILED;
}
if (padding_mode != "CALCULATED" && padding_mode != "SAME" && padding_mode != "VALID") {
OP_LOGE(TbeGetName(op).c_str(), "padding_mode only support CALCULATED, SAME and VALID!");
return GRAPH_FAILED;
}
std::vector<int64_t> pad;
pad = GetAttrValue(op, "pads");
if (!CheckListEmptyAndValue(TbeGetName(op), pad, "pads")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(Im2colInferShape) {
OP_LOGI(TbeGetName(op).c_str(), "Enter op_proto inferfunction!");
std::vector<int64_t> ksize;
ksize = GetAttrValue(op, "ksizes");
std::vector<int64_t> stride;
stride = GetAttrValue(op, "strides");
std::vector<int64_t> dilation;
dilation = GetAttrValue(op, "dilations");
std::string padding_mode;
if (op.GetAttr("padding_mode", padding_mode) == GRAPH_FAILED) {
OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ConstValue padding_mode failed!");
return GRAPH_FAILED;
}
std::vector<int64_t> pad;
pad = GetAttrValue(op, "pads");
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
GeTensorDescPtr desc_in_ptr = op_desc->MutableInputDesc("x");
GeTensorDescPtr desc_out_ptr = op_desc->MutableOutputDesc("y");
auto dtype = desc_in_ptr->GetDataType();
auto shape_in = desc_in_ptr->GetShape();
auto x_format = desc_in_ptr->GetOriginFormat();
if (x_format != FORMAT_NHWC && x_format != FORMAT_NCHW) {
OP_LOGE(TbeGetName(op).c_str(), "Attr x_format only support NHWC, NCHW.");
return GRAPH_FAILED;
}
std::map<char, int> idx_map{{'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}};
if (x_format == FORMAT_NCHW) {
idx_map = {{'N', 0}, {'C', 1}, {'H', 2}, {'W', 3}};
}
int64_t in_n = shape_in.GetDim(idx_map['N']);
int64_t in_h = shape_in.GetDim(idx_map['H']);
int64_t in_w = shape_in.GetDim(idx_map['W']);
int64_t in_c = shape_in.GetDim(idx_map['C']);
if (ksize.size() != 2) {
OP_LOGE(TbeGetName(op).c_str(), "The size of ksizes must be 2 when x_format only support NHWC, NCHW.");
return GRAPH_FAILED;
}
int64_t filter_h = ksize[0];
int64_t filter_w = ksize[1];
int64_t stride_h = stride[0];
int64_t stride_w = stride[0];
if (stride.size() == 2) {
stride_h = stride[0];
stride_w = stride[1];
} else if (stride.size() != 1) {
OP_LOGE(TbeGetName(op).c_str(), "The size of strides must be 1 or 2 when x_format only support NHWC, NCHW.");
return GRAPH_FAILED;
}
if (stride_h == 0 || stride_w == 0) {
OP_LOGE(TbeGetName(op).c_str(), "The stride_h or stride_w should not 0");
return GRAPH_FAILED;
}
int64_t dilation_h = dilation[0];
int64_t dilation_w = dilation[0];
if (dilation.size() == 2) {
dilation_h = dilation[0];
dilation_w = dilation[1];
} else if (dilation.size() != 1) {
OP_LOGE(TbeGetName(op).c_str(), "The size of dilations must be 1 or 2 when x_format only support NHWC, NCHW.");
return GRAPH_FAILED;
}
int64_t effective_filter_h = (filter_h - 1) * dilation_h + 1;
int64_t effective_filter_w = (filter_w - 1) * dilation_w + 1;
int64_t out_h{0};
int64_t out_w{0};
int64_t out_c{0};
if (padding_mode == "VALID") {
out_h = (in_h - effective_filter_h + stride_h) / stride_h;
out_w = (in_w - effective_filter_w + stride_w) / stride_w;
} else if (padding_mode == "SAME") {
out_h = (in_h + stride_h - 1) / stride_h;
out_w = (in_w + stride_w - 1) / stride_w;
} else if (padding_mode == "CALCULATED") {
int64_t pad_h_top;
int64_t pad_h_bottom;
int64_t pad_w_before;
int64_t pad_w_after;
if (pad.size() == 1) {
pad_h_top = pad[0];
pad_h_bottom = pad[0];
pad_w_before = pad[0];
pad_w_after = pad[0];
} else if (pad.size() == 4) {
pad_h_top = pad[0];
pad_h_bottom = pad[1];
pad_w_before = pad[2];
pad_w_after = pad[3];
} else {
OP_LOGE(TbeGetName(op).c_str(), "The size of pads must be 1 or 4 when x_format only support NHWC, NCHW.");
return GRAPH_FAILED;
}
out_h = (in_h + pad_h_top + pad_h_bottom - (dilation_h * (filter_h - 1) + 1)) / stride_h + 1;
out_w = (in_w + pad_w_before + pad_w_after - (dilation_w * (filter_w - 1) + 1)) / stride_w + 1;
} else {
OP_LOGE(TbeGetName(op).c_str(), "The padding_mode only support VALID, SAME and CALCULATED.");
return GRAPH_FAILED;
}
out_c = in_c * filter_h * filter_w;
std::vector<int64_t> out_dim{in_n, out_h, out_w, out_c};
if (x_format == FORMAT_NCHW) {
out_dim = {in_n, out_c, out_h, out_w};
}
desc_out_ptr->SetShape(ge::GeShape(out_dim));
desc_out_ptr->SetDataType(dtype);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(Im2col, Im2colInferShape);
VERIFY_FUNC_REG(Im2col, Im2colVerify);
// -----------------Im2col END-------------------------
} // namespace ge

View File

@ -0,0 +1,142 @@
/**
* Copyright (c) 2022-202 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/image_ops_shape_fns.h"
namespace ge {
// ----------------AdjustHue Start-------------------
IMPLEMT_INFERFUNC(AdjustHue, AdjustHueInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto images_desc = op_desc->MutableInputDesc(0);
GeShape out;
auto ret = WithRankAtLeast(images_desc, 3, out, op);
if (ret != GRAPH_SUCCESS) {
std::string err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
std::vector<std::pair<int64_t, int64_t>> range;
if (images_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(out);
y_desc->SetShapeRange(range);
y_desc->SetDataType(images_desc->GetDataType());
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(AdjustHue, AdjustHueInfer);
// ----------------AdjustHue End-------------------
// ----------------AdjustSaturation Start-------------------
static graphStatus AdjustSaturationCommInferShape(const Operator &op) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto images_desc = op_desc->MutableInputDesc(0);
GeShape out;
auto ret = WithRankAtLeast(images_desc, 3, out, op);
if (ret != GRAPH_SUCCESS) {
std::string err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
std::vector<std::pair<int64_t, int64_t>> range;
if (images_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(out);
y_desc->SetShapeRange(range);
y_desc->SetDataType(images_desc->GetDataType());
return GRAPH_SUCCESS;
}
IMPLEMT_INFERFUNC(AdjustSaturation, AdjustSaturationInfer) { return AdjustSaturationCommInferShape(op); }
INFER_FUNC_REG(AdjustSaturation, AdjustSaturationInfer);
// ----------------AdjustSaturation END-------------------
// ----------------ExtractGlimpse-------------------
IMPLEMT_INFERFUNC(ExtractGlimpse, ExtractGlimpseInfer) {
Shape x_shape;
auto ret = WithRank(op.GetInputDesc(0), 4, x_shape, op);
if (ret != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op), "input x must be 4-D");
return GRAPH_PARAM_INVALID;
}
Shape offsets_shape;
ret = WithRank(op.GetInputDesc(2), 2, offsets_shape, op);
if (ret != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op), "input offsets must be 2-D");
return GRAPH_PARAM_INVALID;
}
auto x_dims = op.GetInputDesc(0).GetShape().GetDims();
auto offsets_dims = op.GetInputDesc(2).GetShape().GetDims();
CHECK(x_dims.size() < 4 || offsets_dims.size() < 2, OP_LOGE(TbeGetName(op), "invalid x_dims or offsets_dims."),
return GRAPH_FAILED);
int64_t batch_dim;
if (Merge(x_dims[0], offsets_dims[0], batch_dim) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op), "x dim-0 or offsets dim-0 is invalid");
return GRAPH_PARAM_INVALID;
}
if (offsets_dims[1] != 2) {
OP_LOGE(TbeGetName(op), "offsets dim-1 must be 2");
return GRAPH_PARAM_INVALID;
}
bool uniform_noise = false;
if (op.GetAttr("uniform_noise", uniform_noise) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op), "get attr uniform_noise failed");
return GRAPH_FAILED;
}
std::string noise;
if (op.GetAttr("noise", noise) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op), "get attr noise failed");
return GRAPH_FAILED;
}
if (uniform_noise && (!noise.empty() && noise != "uniform")) {
OP_LOGE(TbeGetName(op), "The uniform_noise and noise should not be specified at the same time");
return GRAPH_FAILED;
}
TensorDesc desc = op.GetOutputDescByName("y");
desc.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
auto channel_dim = x_dims[3];
TensorDesc input_td = op.GetInputDesc(0);
if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
channel_dim = x_dims[1];
}
return SetOutputToSizedImage(op, batch_dim, "size", channel_dim, "y");
}
INFER_FUNC_REG(ExtractGlimpse, ExtractGlimpseInfer);
// ----------------ExtractGlimpse-------------------
} // namespace ge

View File

@ -1,45 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/index_fill.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------IndexFill-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
TensorDesc v_output_desc = op.GetOutputDescByName("y");
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
// shape of output y is the same as input x
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();
v_output_desc.SetShape(shape_input);
v_output_desc.SetDataType(input_dtype);
v_output_desc.SetFormat(input_format);
if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Registered inferfunction
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------
} // namespace ge

View File

@ -0,0 +1,407 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/linalg_ops.h"
#include "custom_op_proto/cust_linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/linalg_ops_shape_fns.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------MatrixSolve-------------------
IMPLEMT_INFERFUNC(MatrixSolve, MatrixSolveInfer) {
auto matrix_tensor = op.get_input_desc_matrix();
auto rhs_tensor = op.get_input_desc_rhs();
Shape result;
if (MatrixSolve(matrix_tensor, rhs_tensor, true, result, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixSolve Call MatrixSolve Infer Shape fns Failed.");
return GRAPH_FAILED;
}
DataType type = op.GetInputDescByName("matrix").GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(result));
y_desc.SetDataType(type);
op.UpdateOutputDesc("y", y_desc);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(MatrixSolve, MatrixSolveInfer);
// ----------------MatrixSolve End-------------------
// ----------------MatrixDeterminant-------------------
IMPLEMT_INFERFUNC(MatrixDeterminant, MatrixDeterminantInfer) {
auto tensor = op.get_input_desc_x();
Shape s;
if (WithRankAtLeast(tensor, 2, s, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "The rank of x must be at least 2.");
return GRAPH_FAILED;
}
int64_t existing = s.GetDimNum();
int64_t dim1 = s.GetDim(existing - 1);
int64_t dim2 = s.GetDim(existing - 2);
int64_t unused_dim = 0;
if (Merge(dim1, dim2, unused_dim) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Merge two dimension failed.");
return GRAPH_FAILED;
}
Shape result;
if (SubShape(s, 0, -2, 1, result, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixDeterminant Get SubShape Failed.");
return GRAPH_FAILED;
}
DataType type = op.GetInputDescByName("x").GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(result));
y_desc.SetDataType(type);
op.UpdateOutputDesc("y", y_desc);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(MatrixDeterminant, MatrixDeterminantInfer);
// ----------------MatrixDeterminant End-------------------
// ----------------MatrixTriangularSolve-------------------
IMPLEMT_INFERFUNC(MatrixTriangularSolve, MatrixTriangularSolveInfer) {
auto matrix_tensor = op.get_input_desc_matrix();
auto rhs_tensor = op.get_input_desc_rhs();
Shape result;
if (MatrixSolve(matrix_tensor, rhs_tensor, true, result, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixTriangularSolve Call MatrixSolve Infer Shape fns Failed.");
return GRAPH_FAILED;
}
DataType type = op.GetInputDescByName("matrix").GetDataType();
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(result));
y_desc.SetDataType(type);
op.UpdateOutputDesc("y", y_desc);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(MatrixTriangularSolve, MatrixTriangularSolveInfer);
// ----------------MatrixTriangularSolve END-------------------
// ---------------Geqrf-------------------
CUST_IMPLEMT_INFERFUNC(Geqrf, GeqrfInfer) {
auto tensor = op.get_input_desc_x();
Shape input;
if (WithRank(tensor, 2, input, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
int dim_num = input.GetDimNum();
int m = input.GetDim(dim_num - 2);
int n = input.GetDim(dim_num - 1);
Shape r_shape;
Shape tau_shape;
int p = m > n ? n : m;
Matrix(m, n, r_shape);
Vector(p, tau_shape);
DataType type = op.GetInputDescByName("x").GetDataType();
TensorDesc r_desc = op.GetOutputDescByName("r");
r_desc.SetShape(Shape(r_shape));
r_desc.SetDataType(type);
if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
return GRAPH_FAILED;
}
TensorDesc tau_desc = op.GetOutputDescByName("tau");
tau_desc.SetShape(Shape(tau_shape));
tau_desc.SetDataType(type);
if (op.UpdateOutputDesc("tau", tau_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Update tau desc failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(Geqrf, GeqrfInfer);
CUST_IMPLEMT_VERIFIER(Geqrf, GeqrfVerify) {
DataType type = op.GetInputDescByName("x").GetDataType();
if (type != DT_FLOAT16 && type != DT_FLOAT && type != DT_DOUBLE && type != DT_COMPLEX64 && type != DT_COMPLEX128) {
OP_LOGE(TbeGetName(op).c_str(), "Expect a floating point or complex tensor as input.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_VERIFY_FUNC_REG(Geqrf, GeqrfVerify);
// ---------------Geqrf End---------------
// ---------------LuUnpack---------------
CUST_IMPLEMT_INFERFUNC(LuUnpack, LuUnpackInferShape) {
Shape LU_data;
if (WithRankAtLeast(op.GetInputDesc(0), 2, LU_data, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "LU_data rank must be at least 2.");
return GRAPH_FAILED;
}
int64_t existing = LU_data.GetDimNum();
int64_t dim1 = LU_data.GetDim(existing - 2);
int64_t dim2 = LU_data.GetDim(existing - 1);
Shape batch_shape;
if (SubShape(LU_data, 0, -2, 1, batch_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op LuUnpack Get SubShape Failed.");
return GRAPH_FAILED;
}
Shape L_shape;
vector<int64_t> L_dims;
L_dims.reserve(2);
if (dim1 >= dim2) {
L_dims.push_back(dim1);
L_dims.push_back(dim2);
} else {
L_dims.push_back(dim1);
L_dims.push_back(dim1);
}
Shape L_sec_shape(L_dims);
if (Concatenate(batch_shape, L_sec_shape, L_shape) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Concatenate L_shape failed!");
return GRAPH_FAILED;
}
Shape U_shape;
vector<int64_t> U_dims;
U_dims.reserve(2);
if (dim1 >= dim2) {
U_dims.push_back(dim2);
U_dims.push_back(dim2);
} else {
U_dims.push_back(dim1);
U_dims.push_back(dim2);
}
Shape U_sec_shape(U_dims);
if (Concatenate(batch_shape, U_sec_shape, U_shape) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Concatenate U_shape failed!");
return GRAPH_FAILED;
}
Shape pivots_shape;
vector<int64_t> pivots_dims;
pivots_dims.reserve(2);
pivots_dims.push_back(dim1);
pivots_dims.push_back(dim1);
Shape pivots_sec_shape(pivots_dims);
if (Concatenate(batch_shape, pivots_sec_shape, pivots_shape) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Concatenate pivots_shape failed!");
return GRAPH_FAILED;
}
TensorDesc L_desc = op.GetOutputDescByName("L");
L_desc.SetShape(Shape(L_shape));
DataType L_type = op.GetInputDescByName("LU_data").GetDataType();
L_desc.SetDataType(L_type);
if (L_desc.GetDataType() != L_type) {
OP_LOGE(TbeGetName(op).c_str(), "the type of L must be the same as the type of LU_data.");
return GRAPH_FAILED;
}
if (op.UpdateOutputDesc("L", L_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "fail to update output L.");
return GRAPH_FAILED;
}
TensorDesc U_desc = op.GetOutputDescByName("U");
U_desc.SetShape(Shape(U_shape));
DataType U_type = op.GetInputDescByName("LU_data").GetDataType();
U_desc.SetDataType(U_type);
if (U_desc.GetDataType() != U_type) {
OP_LOGE(TbeGetName(op).c_str(), "the type of U must be the same as the type of LU_data.");
return GRAPH_FAILED;
}
if (op.UpdateOutputDesc("U", U_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "fail to update output U.");
return GRAPH_FAILED;
}
TensorDesc pivots_desc = op.GetOutputDescByName("pivots");
pivots_desc.SetShape(Shape(pivots_shape));
DataType pivots_type = op.GetInputDescByName("LU_data").GetDataType();
pivots_desc.SetDataType(pivots_type);
if (pivots_desc.GetDataType() != pivots_type) {
OP_LOGE(TbeGetName(op).c_str(), "the type of pivots must be the same as the type of LU_data.");
return GRAPH_FAILED;
}
if (op.UpdateOutputDesc("pivots", pivots_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "fail to update output pivots.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(LuUnpack, LuUnpackVerify) {
DataType LU_data_type = op.GetInputDescByName("LU_data").GetDataType();
DataType LU_pivots_type = op.GetInputDescByName("LU_pivots").GetDataType();
if (LU_data_type != DT_FLOAT16 && LU_data_type != DT_FLOAT && LU_data_type != DT_DOUBLE && LU_data_type != DT_INT8 &&
LU_data_type != DT_UINT8 && LU_data_type != DT_INT16 && LU_data_type != DT_INT32 && LU_data_type != DT_INT64) {
std::string err_msg;
err_msg = ConcatString("Op LuUnpack first input LU_data_type's data type should be of the follows: ",
"DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE,",
"but this type is ", LU_data_type, ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (LU_pivots_type != DT_INT8 && LU_pivots_type != DT_UINT8 && LU_pivots_type != DT_INT16 &&
LU_pivots_type != DT_INT32 && LU_pivots_type != DT_INT64) {
std::string err_msg;
err_msg =
ConcatString("Op LuUnpack first input LU_data_type's data type should be of the follows: ",
"DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64,", "but this type is ", LU_pivots_type, ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(LuUnpack, LuUnpackInferShape);
CUST_VERIFY_FUNC_REG(LuUnpack, LuUnpackVerify);
// ---------------LuUnpack END---------------
// ---------------LuUnpackGrad---------------
IMPLEMT_COMMON_INFERFUNC(LuUnpackGradInferShape) {
TensorDesc L_grad;
TensorDesc U_grad;
TensorDesc LU_data;
if (op.TryGetInputDesc("LU_data", LU_data) == GRAPH_FAILED) {
OP_LOGE(TbeGetName(op).c_str(), "LU_data can not be null.");
return GRAPH_FAILED;
}
TensorDesc L_data_grad = op.GetOutputDescByName("L_data_grad");
L_data_grad.SetDataType(op.GetInputDescByName("LU_data").GetDataType());
L_data_grad.SetShape(op.GetInputDescByName("LU_data").GetShape());
if (op.UpdateOutputDesc("L_data_grad", L_data_grad) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "fail to update output L_data_grad.");
return GRAPH_FAILED;
}
TensorDesc U_data_grad = op.GetOutputDescByName("U_data_grad");
U_data_grad.SetDataType(op.GetInputDescByName("LU_data").GetDataType());
U_data_grad.SetShape(op.GetInputDescByName("LU_data").GetShape());
if (op.UpdateOutputDesc("U_data_grad", U_data_grad) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "fail to update output U_data_grad.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(LuUnpackGrad, LuUnpackGradVerify) {
DataType LU_data_type = op.GetInputDescByName("LU_data").GetDataType();
Shape LU_data_shape = op.GetInputDescByName("LU_data").GetShape();
TensorDesc L_grad;
TensorDesc U_grad;
if (op.TryGetInputDesc("L_grad", L_grad) == GRAPH_SUCCESS) {
DataType L_grad_type = op.GetInputDescByName("L_grad").GetDataType();
Shape L_data_shape = op.GetInputDescByName("L_grad").GetShape();
auto L_data_dim1 = L_data_shape.GetDim(-2);
auto L_data_dim2 = L_data_shape.GetDim(-1);
auto LU_data_dim1 = LU_data_shape.GetDim(-2);
auto LU_data_dim2 = LU_data_shape.GetDim(-1);
int64_t LU_data_min = std::min(LU_data_dim1, LU_data_dim2);
if (LU_data_dim1 != L_data_dim1) {
OP_LOGE(TbeGetName(op).c_str(), "L_grad's data dim[-2] and LU_data's dim[-2] should be same.");
return GRAPH_FAILED;
}
if (LU_data_min != L_data_dim2) {
OP_LOGE(TbeGetName(op).c_str(), "L_grad's data dim[-1] and LU_data's minimum dim should be same.");
return GRAPH_FAILED;
}
if (LU_data_type != L_grad_type) {
OP_LOGE(TbeGetName(op).c_str(), "L_grad's data type and LU_data's type should be same.");
return GRAPH_FAILED;
}
}
if (op.TryGetInputDesc("U_grad", U_grad) == GRAPH_SUCCESS) {
DataType U_grad_type = op.GetInputDescByName("U_grad").GetDataType();
Shape U_data_shape = op.GetInputDescByName("U_grad").GetShape();
auto U_data_dim1 = U_data_shape.GetDim(-2);
auto U_data_dim2 = U_data_shape.GetDim(-1);
auto LU_data_dim1 = LU_data_shape.GetDim(-2);
auto LU_data_dim2 = LU_data_shape.GetDim(-1);
int64_t LU_data_min = std::min(LU_data_dim1, LU_data_dim2);
if (U_data_dim2 != LU_data_dim2) {
OP_LOGE(TbeGetName(op).c_str(), "U_grad's data dim[-1] and LU_data's dim[-1] should be same.");
return GRAPH_FAILED;
}
if (LU_data_min != U_data_dim1) {
OP_LOGE(TbeGetName(op).c_str(), "U_grad's data dim[-2] and LU_data's minimum dim should be same.");
return GRAPH_FAILED;
}
if (LU_data_type != U_grad_type) {
OP_LOGE(TbeGetName(op).c_str(), "U_grad's data type and LU_data's type should be same.");
return GRAPH_FAILED;
}
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LuUnpackGrad, LuUnpackGradInferShape);
CUST_VERIFY_FUNC_REG(LuUnpackGrad, LuUnpackGradVerify);
// ---------------LuUnpackGrad End---------------
// -----------------------LuSolve---------------------------------
IMPLEMT_COMMON_INFERFUNC(LuSolveInferShape) {
Shape b_shape = op.GetInputDescByName("x").GetShape();
Shape lu_shape = op.GetInputDescByName("lu_data").GetShape();
size_t b_dims = b_shape.GetDimNum();
size_t lu_dims = lu_shape.GetDimNum();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
std::vector<int64_t> dim_vector;
if (b_dims >= lu_dims) {
Shape output_shape = b_shape;
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(input_format);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
} else {
for (size_t i = 0; i <= lu_dims - b_dims - 1; i++) {
dim_vector.push_back(lu_shape.GetDim(i));
}
for (size_t i = 0; i <= b_dims - 1; i++) {
dim_vector.push_back(b_shape.GetDim(i));
}
Shape output_shape(dim_vector);
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(input_format);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
}
CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
DataType input_type_x = op.GetInputDescByName("x").GetDataType();
DataType input_type_y = op.GetInputDescByName("lu_data").GetDataType();
if (input_type_x != input_type_y) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------
} // namespace ge

View File

@ -1,88 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/logspace.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// --------------------------LogSpace---------------------
static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
int64_t steps = 0;
int64_t steps_ori = 100;
if (ge::GRAPH_SUCCESS != op.GetAttr(attr_num_steps.c_str(), steps)) {
steps = steps_ori;
}
if (steps < 0) {
return false;
}
return true;
}
CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
AscendString opName;
op.GetName(opName);
if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input start size must be 1.");
return GRAPH_FAILED;
}
if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input end size must be 1.");
return GRAPH_FAILED;
}
DataType input_type_start = op.GetInputDescByName("start").GetDataType();
DataType input_type_end = op.GetInputDescByName("end").GetDataType();
if (input_type_start != input_type_end) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogSpaceInferShape) {
AscendString opName1;
op.GetName(opName1);
TensorDesc v_output_desc = op.GetOutputDescByName("y");
int64_t steps;
int64_t num_rows = 1;
op.GetAttr("steps", steps);
if (!CheckSteps(op, "steps")) {
OP_LOGE(opName1.GetString(), "the attr 'steps' should be greater than or equal to 0.");
return GRAPH_FAILED;
}
std::vector<int64_t> dim_vec;
dim_vec.push_back(num_rows);
dim_vec.push_back(steps);
v_output_desc.SetShape(ge::Shape(dim_vec));
int64_t dtype = 1;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
v_output_desc.SetDataType(DT_FLOAT16);
} else {
if (dtype == 1) {
v_output_desc.SetDataType(DT_FLOAT16);
}
if (dtype == 0) {
v_output_desc.SetDataType(DT_FLOAT);
}
}
(void)op.UpdateOutputDesc("y", v_output_desc);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------
} // namespace ge

View File

@ -1,68 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/lu_solve_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// -----------------------LuSolve---------------------------------
IMPLEMT_COMMON_INFERFUNC(LuSolveInferShape) {
Shape b_shape = op.GetInputDescByName("x").GetShape();
Shape lu_shape = op.GetInputDescByName("lu_data").GetShape();
size_t b_dims = b_shape.GetDimNum();
size_t lu_dims = lu_shape.GetDimNum();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
std::vector<int64_t> dim_vector;
if (b_dims >= lu_dims) {
Shape output_shape = b_shape;
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(input_format);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
} else {
for (size_t i = 0; i <= lu_dims - b_dims - 1; i++) {
dim_vector.push_back(lu_shape.GetDim(i));
}
for (size_t i = 0; i <= b_dims - 1; i++) {
dim_vector.push_back(b_shape.GetDim(i));
}
Shape output_shape(dim_vector);
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(output_shape);
td.SetDataType(input_dtype);
td.SetFormat(input_format);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
}
CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
DataType input_type_x = op.GetInputDescByName("x").GetDataType();
DataType input_type_y = op.GetInputDescByName("lu_data").GetDataType();
if (input_type_x != input_type_y) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------
} // namespace ge

View File

@ -0,0 +1,215 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------ComplexAbs-------------------
IMPLEMT_INFERFUNC(ComplexAbs, ComplexAbsInfer) {
TensorDesc x_desc = op.GetInputDescByName("x");
DataType x_type = x_desc.GetDataType();
DataType out_type;
switch (x_type) {
case DT_COMPLEX64:
out_type = DT_FLOAT;
break;
case DT_COMPLEX128:
out_type = DT_DOUBLE;
break;
default:
OP_LOGE("ComplexAbs", "Invalid input dtype: %s", DTypeStr(x_type).c_str());
return GRAPH_FAILED;
}
x_desc.SetDataType(out_type);
return op.UpdateOutputDesc("y", x_desc);
}
INFER_FUNC_REG(ComplexAbs, ComplexAbsInfer);
// ----------------ComplexAbs End-------------------
// ----------------ComplexAbs-------------------
IMPLEMT_INFERFUNC(Complex, ComplexInfer) {
bool is_dynamic_output = true;
if (!InferShapeAndTypeTwoInOneOutBroadcast(op, 0, 1, 0, is_dynamic_output)) {
return GRAPH_FAILED;
}
TensorDesc x_desc = op.GetInputDescByName("real");
DataType x_type = x_desc.GetDataType();
DataType out_type;
switch (x_type) {
case DT_FLOAT:
out_type = DT_COMPLEX64;
break;
case DT_DOUBLE:
out_type = DT_COMPLEX128;
break;
default:
OP_LOGE("Complex", "Invalid input dtype: %s", DTypeStr(x_type).c_str());
return GRAPH_FAILED;
}
TensorDesc out_desc = op.GetOutputDescByName("out");
out_desc.SetDataType(out_type);
return op.UpdateOutputDesc("out", out_desc);
}
INFER_FUNC_REG(Complex, ComplexInfer);
// ----------------ComplexAbs-------------------
// ----------------IsNan-------------------
IMPLEMT_INFERFUNC(IsNan, IsNanInfer) {
TensorDesc out_desc = op.GetOutputDescByName("y");
out_desc.SetDataType(DT_BOOL);
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update output[y] failed."));
return GRAPH_FAILED;
}
return UnchangedShape(op, "x", "y");
}
INFER_FUNC_REG(IsNan, IsNanInfer);
// ----------------IsNan End-------------------
// ----------------NextAfter-------------------
IMPLEMT_INFERFUNC(NextAfter, NextAfterInfer) {
Shape x_shape = op.GetInputDescByName("x1").GetShape();
Shape y_shape = op.GetInputDescByName("x2").GetShape();
TensorDesc out_desc = op.GetOutputDescByName("output");
DataType x_type = op.GetInputDescByName("x1").GetDataType();
DataType y_type = op.GetInputDescByName("x2").GetDataType();
if (x_type != y_type) {
OP_LOGE(TbeGetName(op).c_str(), "the type of x1 is different from that of x2!");
return GRAPH_FAILED;
}
out_desc.SetDataType(x_type);
if ((!RankKnown(x_shape)) || (!RankKnown(y_shape))) {
Shape out_shape(UNKNOWN_SHAPE);
out_desc.SetShape(out_shape);
if (op.UpdateOutputDesc("output", out_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "update output failed");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
const size_t rank_x = x_shape.GetDimNum();
const size_t rank_y = y_shape.GetDimNum();
const size_t rank_out = std::max(rank_x, rank_y);
// To compute the broadcast dimensions, zip together x_shape and y_shape
// and pad with 1 to make them the same length.
std::vector<int64_t> dims;
int64_t dim_one = 1;
if (rank_x != rank_y) {
OP_LOGI(TbeGetName(op).c_str(), "x1 shape is not equal to x2 shape!");
dim_one = 1;
}
for (size_t i = 0; i < rank_out; ++i) {
int64_t dim_x;
if (i < (rank_out - rank_x)) {
dim_x = dim_one;
} else {
// rank_out = rank_x or i >= rank_y - rank_x.
for (size_t j = 0; j < x_shape.GetDimNum(); ++j) {
if (x_shape.GetDim(j) == UNKNOWN_DIM) {
dim_x = UNKNOWN_DIM;
break;
}
}
if ((i - (rank_out - rank_x)) < 0) {
dim_x = x_shape.GetDim(rank_x + i - (rank_out - rank_x));
} else {
dim_x = x_shape.GetDim(i - (rank_out - rank_x));
}
}
const bool dim_y_is_one = (i < (rank_out - rank_y));
int64_t dim_y;
if (dim_y_is_one) {
dim_y = dim_one;
} else {
// rank_out = rank_y or i >= rank_x - rank_y.
for (size_t j = 0; j < y_shape.GetDimNum(); ++j) {
if (y_shape.GetDim(j) == UNKNOWN_DIM) {
dim_y = UNKNOWN_DIM;
break;
}
}
if ((i - (rank_out - rank_y)) < 0) {
dim_y = y_shape.GetDim(rank_y + i - (rank_out - rank_y));
} else {
dim_y = y_shape.GetDim(i - (rank_out - rank_y));
}
}
if ((dim_x == UNKNOWN_DIM) || (dim_y == UNKNOWN_DIM)) {
/* One or both dimensions is unknown.
* If either dimension is greater than 1, assume that the program is
* correct, and the other dimension will be broadcast to match it.
* For shape inference, if eliminate the shape checks
* in this code, assert that the unknown dim is either 1
* or the same as the known dim.
* If either dimension is 1, the other dimension is the output.
*/
if (dim_x > 1) {
dims.push_back(dim_x);
} else if (dim_y > 1) {
dims.push_back(dim_y);
} else if (dim_x == 1) {
dims.push_back(dim_y);
} else if (dim_y == 1) {
dims.push_back(dim_x);
} else if (dim_x == dim_y) {
dims.push_back(dim_x);
} else {
dims.push_back(UNKNOWN_DIM);
}
} else if ((dim_x == 1) || (dim_y == 1)) {
// dim_x is dim_one or dim_y is dim_one.
if ((dim_x == 1) && (!dim_y_is_one)) {
// broadcast dim_x to dim_y.
dims.push_back(dim_y);
} else {
if (dim_y == 1) {
// broadcast dim_y to dim_x.
dims.push_back(dim_x);
}
}
} else {
int64_t dim;
if (Merge(dim_x, dim_y, dim) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
dims.push_back(dim);
}
}
Shape out_shape(dims);
out_desc.SetShape(out_shape);
if (op.UpdateOutputDesc("output", out_desc) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "update output failed");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(NextAfter, NextAfterInfer);
// ----------------NextAfter End-------------------
// ----------------IsInf------------------------
IMPLEMT_INFERFUNC(IsInf, IsInfInfer) {
TensorDesc out_desc = op.GetOutputDescByName("y");
out_desc.SetDataType(DT_BOOL);
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update output[y] failed."));
return GRAPH_FAILED;
}
return UnchangedShape(op, "x", "y");
}
INFER_FUNC_REG(IsInf, IsInfInfer);
// ----------------IsInf END------------------------
} // namespace ge

View File

@ -0,0 +1,162 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/matrix_calculation_ops.h"
#include "custom_op_proto/cust_math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_COMMON_INFERFUNC(OneInOneOutCommonInferShape) {
static const int64_t input_x_idx = 0;
static const int64_t output_y_idx = 0;
if (OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
return GRAPH_SUCCESS;
}
return GRAPH_FAILED;
}
// ----------------DiagPart-------------------
IMPLEMT_COMMON_INFERFUNC(DiagPartInferShape) {
ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("DiagPart", GetInputInvalidErrMsg("op_desc")),
return GRAPH_FAILED);
ge::ConstGeTensorDescPtr input_x_desc = op_desc->GetInputDescPtr(0);
CHECK(input_x_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("DiagPart", GetInputInvalidErrMsg("x")),
return GRAPH_FAILED);
const GeShape &input_shape = input_x_desc->GetShape();
const size_t input_to_output_dims_times = 2;
size_t output_shape_len = input_shape.GetDimNum() / input_to_output_dims_times;
ge::GeTensorDescPtr output_desc = op_desc->MutableOutputDesc(0);
GeShape &output_shape = output_desc->MutableShape();
DataType input_dtype = input_x_desc->GetDataType();
if (input_shape.IsUnknownDimNum()) {
output_desc->SetShape(input_shape);
} else {
output_shape.SetDimNum(output_shape_len);
for (size_t i = 0; i < output_shape_len; i++) {
output_shape.SetDim(i, input_shape.GetDim(i));
}
}
if (input_shape.IsUnknownShape()) {
std::vector<std::pair<int64_t, int64_t>> shape_range;
input_x_desc->GetShapeRange(shape_range);
for (unsigned i = 0; i < shape_range.size(); i++) {
if (shape_range[i].first > 0) {
shape_range[i].first = shape_range[i].first;
}
if (shape_range[i].second > 0) {
shape_range[i].second = shape_range[i].second;
}
}
output_desc->SetShapeRange(shape_range);
}
output_desc->SetShape(output_shape);
output_desc->SetDataType(input_dtype);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(DiagPart, DiagPartInferShape);
// ----------------DiagPart END-------------------
// ---------------Eye----------------------------
static bool CheckRows(const Operator &op, const string &attr_num_rows) {
int64_t num_rows;
op.GetAttr(attr_num_rows.c_str(), num_rows);
if (num_rows <= 0) {
return false;
}
return true;
}
static bool CheckBatchShape(const Operator &op, const string &attr_batch_shape) {
const std::string opName = TbeGetName(op);
std::vector<int64_t> batch_shape;
op.GetAttr(attr_batch_shape.c_str(), batch_shape);
for (size_t i = 0; i < batch_shape.size(); ++i) {
if (batch_shape[i] <= 0) {
OP_LOGE(opName, "the value of batch_shape less than 0.");
return false;
}
}
return true;
}
IMPLEMT_COMMON_INFERFUNC(EyeInferShape) {
TensorDesc td = op.GetOutputDescByName("y");
int64_t num_rows, num_columns;
std::vector<int64_t> batch_shape;
op.GetAttr("num_rows", num_rows);
op.GetAttr("num_columns", num_columns);
op.GetAttr("batch_shape", batch_shape);
if (!CheckRows(op, "num_rows") || !CheckBatchShape(op, "batch_shape")) {
return GRAPH_FAILED;
}
if (num_columns <= 0) {
num_columns = num_rows;
}
std::vector<int64_t> dim_vec;
for (size_t i = 0; i < batch_shape.size(); ++i) {
dim_vec.push_back(batch_shape[i]);
}
dim_vec.push_back(num_rows);
dim_vec.push_back(num_columns);
td.SetShape(ge::Shape(dim_vec));
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
IMPLEMT_VERIFIER(Eye, EyeVerify) { return GRAPH_SUCCESS; }
COMMON_INFER_FUNC_REG(Eye, EyeInferShape);
VERIFY_FUNC_REG(Eye, EyeVerify);
// --------------Eye END-------------------------------
// ----------------FillDiagonal-------------------
IMPLEMT_COMMON_INFERFUNC(FillDiagonalInferShape) {
Shape x_shape = op.GetInputDescByName("x").GetShape();
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(ge::Shape(x_shape));
td.SetDataType(x_dtype);
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(FillDiagonal, FillDiagonalInferShape);
// ----------------FillDiagonal END-------------------
// ----------------MatrixLogarithm--------------------
IMPLEMT_COMMON_INFERFUNC(MatrixLogarithmInferShaper) {
auto x_shape = op.GetInputDescByName("x").GetShape().GetDims();
Shape input_shape = op.GetInputDescByName("x").GetShape();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
int64_t size_num = op.GetInputDescByName("x").GetShape().GetDimNum();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(ge::Shape(input_shape));
td.SetDataType(input_dtype);
if (op.UpdateOutputDesc("y", td) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (size_num < 2) {
string err_msg = ConcatString("the input[x] should be greater than 2, but get ", size_num, ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// 注册函数
CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixLogarithm END-------------------
// ----------------MatrixExp-------------------
CUST_COMMON_INFER_FUNC_REG(MatirxExp, OneInOneOutCommonInferShape);
// ----------------MatrixExp END-------------------
} // namespace ge

View File

@ -1,44 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/matrix_logarithm.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------MatrixLogarithm--------------------
IMPLEMT_COMMON_INFERFUNC(MatrixLogarithmInferShaper) {
auto x_shape = op.GetInputDescByName("x").GetShape().GetDims();
Shape input_shape = op.GetInputDescByName("x").GetShape();
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
int64_t size_num = op.GetInputDescByName("x").GetShape().GetDimNum();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(ge::Shape(input_shape));
td.SetDataType(input_dtype);
if (op.UpdateOutputDesc("y", td) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (size_num < 2) {
string err_msg = ConcatString("the input[x] should be greater than 2, but get ", size_num, ".");
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// 注册函数
CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixLogarithm END-------------------
} // namespace ge

View File

@ -1,138 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/max_pool_3d_grad_with_argmax_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
CUST_IMPLEMT_VERIFIER(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify) {
const size_t DIM_SIZE1 = 1;
const size_t DIM_SIZE3 = 3;
const size_t DIM_SIZE5 = 5;
std::vector<int32_t> ksizeList;
if (GRAPH_SUCCESS != op.GetAttr("ksize", ksizeList)) {
std::string err_msg = GetInputInvalidErrMsg("ksize");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((ksizeList.size() != DIM_SIZE1) && (ksizeList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("ksizeList", ConcatString(ksizeList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> stridesList;
if (GRAPH_SUCCESS != op.GetAttr("strides", stridesList)) {
std::string err_msg = GetInputInvalidErrMsg("strides");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((stridesList.size() != DIM_SIZE1) && (stridesList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("stridesList", ConcatString(stridesList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> padsList;
if (GRAPH_SUCCESS != op.GetAttr("pads", padsList)) {
std::string err_msg = GetInputInvalidErrMsg("pads");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((padsList.size() != DIM_SIZE1) && (padsList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("padsList", ConcatString(padsList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> dilationList;
if (GRAPH_SUCCESS != op.GetAttr("dilation", dilationList)) {
std::string err_msg = GetInputInvalidErrMsg("dilation");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((dilationList.size() != DIM_SIZE1) && (dilationList.size() != DIM_SIZE3) && (dilationList.size() != DIM_SIZE5)) {
string excepted_value = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3, " or ", DIM_SIZE5);
std::string err_msg = GetAttrSizeErrMsg("dilationList", ConcatString(dilationList.size()), excepted_value);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
bool ceilMode = false;
if (GRAPH_SUCCESS != op.GetAttr("ceil_mode", ceilMode)) {
std::string err_msg = GetInputInvalidErrMsg("ceil_mode");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string data_format;
if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "get attr data_format failed.");
return GRAPH_FAILED;
}
if (data_format != "NCDHW") {
OP_LOGE(TbeGetName(op).c_str(), "Attr data_format(%s) only support NCDHW.", data_format.c_str());
return GRAPH_FAILED;
}
int dtype = 0;
if (GRAPH_SUCCESS != op.GetAttr("dtype", dtype)) {
std::string err_msg = GetInputInvalidErrMsg("dtype");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK_PTR_NULL(op_desc, "op desc", return GRAPH_FAILED);
auto grads_desc = op_desc->MutableInputDesc("grads");
CHECK_PTR_NULL(grads_desc, "grads desc", return GRAPH_FAILED);
vector<int64_t> grads_shape = grads_desc->MutableShape().GetDims();
if (grads_shape.size() != DIM_SIZE5 && !IsUnknownRankShape(grads_shape)) {
OP_LOGE(TbeGetName(op).c_str(), "grads_shape's dim expect: %lu, but real: %lu.", DIM_SIZE5, grads_shape.size());
return GRAPH_FAILED;
}
TensorDesc inputDesc = op.GetInputDescByName("x");
vector<int64_t> inputShape = inputDesc.GetShape().GetDims();
if (inputShape.size() != DIM_SIZE5) {
OP_LOGE(TbeGetName(op).c_str(), "input x's dim expect: %lu, but real: %lu.", DIM_SIZE5, inputShape.size());
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_INFERFUNC(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape) {
auto shape = op.GetInputDescByName("x").GetShape();
auto shape_dims = shape.GetDims();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(shape);
td.SetDataType(op.GetInputDescByName("x").GetDataType());
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
} // namespace ge

View File

@ -1,112 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/multi_margin_loss_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ----------------------MultiMarginLossGrad------------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLossGrad, MultiMarginLossGradVerify) { return GRAPH_SUCCESS; }
CUST_VERIFY_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradVerify);
IMPLEMT_COMMON_INFERFUNC(MultiMarginLossGradInferShape) {
Shape shape_x = op.GetInputDescByName("x").GetShape();
Shape shape_target = op.GetInputDescByName("target").GetShape();
TensorDesc tensordesc_weight;
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
DataType y_grad_dtype = op.GetInputDescByName("y_grad").GetDataType();
DataType target_dtype = op.GetInputDescByName("target").GetDataType();
if (y_grad_dtype != x_dtype) {
string err_msg1 = ConcatString("dtype of input x must be the same as y_grad.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
string err_msg1 = ConcatString("dtype of input x must be double, float or float16");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (target_dtype != DT_INT64) {
string err_msg1 = ConcatString("dtype of input target must be int64.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
Shape shape_w = op.GetInputDescByName("weight").GetShape();
DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
if (weight_dtype != x_dtype) {
string err_msg1 = ConcatString("weight should have the same dtype with x.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_w.GetDimNum() != 1) {
string err_msg1 = ConcatString("rank of weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
}
if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
string err_msg2 =
ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
", shape_target.GetDimNum():", shape_target.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg2);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
string err_msg3 = ConcatString(
"shape[0] of x and shape[0] of target must be "
"the same, shape_x.GetDim(0):",
shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
std::string err_msg = OtherErrMsg(err_msg3);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
string expected_reduction_list = ConcatString("mean, sum, none");
std::string err_msg = GetInputFormatNotSupportErrMsg("reduction", expected_reduction_list, reduction);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
int64_t p;
op.GetAttr("p", p);
if ((p != 1) && (p != 2)) {
string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
std::string err_msg = OtherErrMsg(err_msg4);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
TensorDesc tensordesc_output = op.GetOutputDescByName("x_grad");
Shape x_grad_shape = Shape(shape_x);
tensordesc_output.SetShape(x_grad_shape);
TensorDesc input_desc = op.GetInputDescByName("x");
tensordesc_output.SetDataType(input_desc.GetDataType());
op.UpdateOutputDesc("x_grad", tensordesc_output);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradInferShape);
// ----------------------MultiMarginLossGrad END------------------------
} // namespace ge

View File

@ -1,117 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/multi_margin_loss_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------MultiMarginLoss Begin-------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLoss, MultiMarginLossVerify) {
Shape shape_x = op.GetInputDescByName("x").GetShape();
Shape shape_target = op.GetInputDescByName("target").GetShape();
TensorDesc tensordesc_weight;
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
DataType target_dtype = op.GetInputDescByName("target").GetDataType();
if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
string err_msg1 = ConcatString("dtype of input x must be double, float or float16.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (target_dtype != DT_INT64) {
string err_msg1 = ConcatString("dtype of input target must be int64.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
Shape shape_w = op.GetInputDescByName("weight").GetShape();
DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
if (weight_dtype != x_dtype) {
string err_msg1 = ConcatString("weight should have the same dtype with x.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_w.GetDimNum() != 1) {
string err_msg1 = ConcatString("rank of input weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
}
if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
string err_msg2 =
ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
", shape_target.GetDimNum():", shape_target.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg2);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
string err_msg3 = ConcatString(
"shape[0] of x and shape[0] of target must be "
"the same, shape_x.GetDim(0):",
shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
std::string err_msg = OtherErrMsg(err_msg3);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
OP_LOGE(TbeGetName(op).c_str(), "The val of reduction is invalid.");
return GRAPH_FAILED;
}
int64_t p;
op.GetAttr("p", p);
if ((p != 1) && (p != 2)) {
string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
std::string err_msg = OtherErrMsg(err_msg4);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(MultiMarginLossInferShape) {
auto shape_x = op.GetInputDescByName("x").GetShape().GetDims();
auto shape_target = op.GetInputDescByName("target").GetShape().GetDims();
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
Shape y_shape = Shape(shape_target);
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction == "mean") || (reduction == "sum")) {
Shape scalar_shape;
Scalar(scalar_shape);
tensordesc_output.SetShape(scalar_shape);
}
if (reduction == "none") {
tensordesc_output.SetShape(y_shape);
}
TensorDesc input_desc = op.GetInputDescByName("x");
tensordesc_output.SetDataType(input_desc.GetDataType());
tensordesc_output.SetFormat(FORMAT_ND);
op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(MultiMarginLoss, MultiMarginLossInferShape);
CUST_VERIFY_FUNC_REG(MultiMarginLoss, MultiMarginLossVerify);
// ----------------MultiMarginLoss END---------------------
} // namespace ge

View File

@ -1,45 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/mvlgamma_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ----------------MvlgammaGrad Begin-------------------
CUST_IMPLEMT_INFERFUNC(MvlgammaGrad, MvlgammaGradInferShape) {
const char *op_name = "MvlgammaGrad";
OP_LOGD(op_name, "MvlgammaGradInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("y_grad");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();
TensorDesc tensordesc_output1 = op.GetOutputDescByName("x_grad");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));
(void)op.UpdateOutputDesc("x_grad", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaGradInferShape end.");
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(MvlgammaGrad, MvlgammaGradVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(MvlgammaGrad, MvlgammaGradInferShape);
CUST_VERIFY_FUNC_REG(MvlgammaGrad, MvlgammaGradVerify);
// ----------------MvlgammaGrad END---------------------
} // namespace ge

View File

@ -1,45 +0,0 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/mvlgamma_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
namespace ge {
// ----------------Mvlgamma Begin-------------------
CUST_IMPLEMT_INFERFUNC(Mvlgamma, MvlgammaInferShape) {
const char *op_name = "Mvlgamma";
OP_LOGD(op_name, "MvlgammaInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("x");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();
TensorDesc tensordesc_output1 = op.GetOutputDescByName("y");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));
(void)op.UpdateOutputDesc("y", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaInferShape end.");
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(Mvlgamma, MvlgammaVerify) { return GRAPH_SUCCESS; }
CUST_INFER_FUNC_REG(Mvlgamma, MvlgammaInferShape);
CUST_VERIFY_FUNC_REG(Mvlgamma, MvlgammaVerify);
// ----------------Mvlgamma END---------------------
} // namespace ge

View File

@ -0,0 +1,235 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/nn_norm_ops.h"
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------KlDivLossGrad Begin-------------------
bool InferShapeAndTypeKlDivLossGrad(Operator &op, const string &input_name, const string &output_name) {
TensorDesc output_desc = op.GetOutputDescByName(output_name.c_str());
DataType input_dtype = op.GetInputDescByName(input_name.c_str()).GetDataType();
Format input_format =
static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDescByName(input_name.c_str()).GetFormat()));
ge::Shape input_shape = op.GetInputDescByName(input_name.c_str()).GetShape();
output_desc.SetShape(input_shape);
output_desc.SetDataType(input_dtype);
output_desc.SetFormat(input_format);
op.UpdateOutputDesc(output_name.c_str(), output_desc);
return true;
}
IMPLEMT_COMMON_INFERFUNC(KlDivLossGradInferShape) {
if (InferShapeAndTypeKlDivLossGrad(op, "input", "y")) {
return GRAPH_SUCCESS;
}
OP_LOGE(TbeGetName(op).c_str(), "KL_DIV_LOSS_GRAD Infershape Failed");
return GRAPH_FAILED;
}
IMPLEMT_VERIFIER(KlDivLossGrad, KlDivLossGradVerify) {
if (op.GetInputDescByName("grad").GetDataType() != op.GetInputDescByName("input").GetDataType() ||
op.GetInputDescByName("input").GetDataType() != op.GetInputDescByName("target").GetDataType()) {
OP_LOGE(TbeGetName(op).c_str(), "grad type is not same with input or target");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(KlDivLossGrad, KlDivLossGradInferShape);
VERIFY_FUNC_REG(KlDivLossGrad, KlDivLossGradVerify);
// ----------------KlDivLossGrad END---------------------
// ----------------MultiMarginLoss Begin-------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLoss, MultiMarginLossVerify) {
Shape shape_x = op.GetInputDescByName("x").GetShape();
Shape shape_target = op.GetInputDescByName("target").GetShape();
TensorDesc tensordesc_weight;
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
DataType target_dtype = op.GetInputDescByName("target").GetDataType();
if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
string err_msg1 = ConcatString("dtype of input x must be double, float or float16.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (target_dtype != DT_INT64) {
string err_msg1 = ConcatString("dtype of input target must be int64.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
Shape shape_w = op.GetInputDescByName("weight").GetShape();
DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
if (weight_dtype != x_dtype) {
string err_msg1 = ConcatString("weight should have the same dtype with x.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_w.GetDimNum() != 1) {
string err_msg1 = ConcatString("rank of input weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
}
if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
string err_msg2 =
ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
", shape_target.GetDimNum():", shape_target.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg2);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
string err_msg3 = ConcatString(
"shape[0] of x and shape[0] of target must be "
"the same, shape_x.GetDim(0):",
shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
std::string err_msg = OtherErrMsg(err_msg3);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
OP_LOGE(TbeGetName(op).c_str(), "The val of reduction is invalid.");
return GRAPH_FAILED;
}
int64_t p;
op.GetAttr("p", p);
if ((p != 1) && (p != 2)) {
string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
std::string err_msg = OtherErrMsg(err_msg4);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(MultiMarginLossInferShape) {
auto shape_x = op.GetInputDescByName("x").GetShape().GetDims();
auto shape_target = op.GetInputDescByName("target").GetShape().GetDims();
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
Shape y_shape = Shape(shape_target);
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction == "mean") || (reduction == "sum")) {
Shape scalar_shape;
Scalar(scalar_shape);
tensordesc_output.SetShape(scalar_shape);
}
if (reduction == "none") {
tensordesc_output.SetShape(y_shape);
}
TensorDesc input_desc = op.GetInputDescByName("x");
tensordesc_output.SetDataType(input_desc.GetDataType());
tensordesc_output.SetFormat(FORMAT_ND);
op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(MultiMarginLoss, MultiMarginLossInferShape);
CUST_VERIFY_FUNC_REG(MultiMarginLoss, MultiMarginLossVerify);
// ----------------MultiMarginLoss END---------------------
// ----------------------MultiMarginLossGrad------------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLossGrad, MultiMarginLossGradVerify) { return GRAPH_SUCCESS; }
CUST_VERIFY_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradVerify);
IMPLEMT_COMMON_INFERFUNC(MultiMarginLossGradInferShape) {
Shape shape_x = op.GetInputDescByName("x").GetShape();
Shape shape_target = op.GetInputDescByName("target").GetShape();
TensorDesc tensordesc_weight;
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
DataType y_grad_dtype = op.GetInputDescByName("y_grad").GetDataType();
DataType target_dtype = op.GetInputDescByName("target").GetDataType();
if (y_grad_dtype != x_dtype) {
string err_msg1 = ConcatString("dtype of input x must be the same as y_grad.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
string err_msg1 = ConcatString("dtype of input x must be double, float or float16");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (target_dtype != DT_INT64) {
string err_msg1 = ConcatString("dtype of input target must be int64.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
Shape shape_w = op.GetInputDescByName("weight").GetShape();
DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
if (weight_dtype != x_dtype) {
string err_msg1 = ConcatString("weight should have the same dtype with x.");
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_w.GetDimNum() != 1) {
string err_msg1 = ConcatString("rank of weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg1);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
}
if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
string err_msg2 =
ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
", shape_target.GetDimNum():", shape_target.GetDimNum());
std::string err_msg = OtherErrMsg(err_msg2);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
string err_msg3 = ConcatString(
"shape[0] of x and shape[0] of target must be "
"the same, shape_x.GetDim(0):",
shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
std::string err_msg = OtherErrMsg(err_msg3);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string reduction;
op.GetAttr("reduction", reduction);
if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
string expected_reduction_list = ConcatString("mean, sum, none");
std::string err_msg = GetInputFormatNotSupportErrMsg("reduction", expected_reduction_list, reduction);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
int64_t p;
op.GetAttr("p", p);
if ((p != 1) && (p != 2)) {
string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
std::string err_msg = OtherErrMsg(err_msg4);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
TensorDesc tensordesc_output = op.GetOutputDescByName("x_grad");
Shape x_grad_shape = Shape(shape_x);
tensordesc_output.SetShape(x_grad_shape);
TensorDesc input_desc = op.GetInputDescByName("x");
tensordesc_output.SetDataType(input_desc.GetDataType());
op.UpdateOutputDesc("x_grad", tensordesc_output);
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradInferShape);
// ----------------------MultiMarginLossGrad END------------------------
} // namespace ge

View File

@ -0,0 +1,287 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/nn_pooling_ops.h"
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// -------------------DataFormatVecPermute---------------------
IMPLEMT_INFERFUNC(DataFormatVecPermute, DataFormatVecPermuteInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
std::vector<std::pair<int64_t, int64_t>> range;
if (x_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
DataType y_type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(x_desc->GetShape());
y_desc->SetShapeRange(range);
y_desc->SetDataType(y_type);
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(DataFormatVecPermute, DataFormatVecPermuteInfer);
// -------------------DataFormatVecPermute End---------------------
// -------------------MaxPool3DWithArgmax---------------------
IMPLEMT_INFERFUNC(MaxPool3DWithArgmax, MaxPool3DWithArgmaxInferShape) {
TensorDesc inputDesc = op.GetInputDescByName("x");
auto inputShape = inputDesc.GetShape().GetDims();
DataType inputDtype = inputDesc.GetDataType();
TensorDesc argmaxDesc = op.GetOutputDescByName("argmax");
TensorDesc outputDesc = op.GetOutputDescByName("y");
std::vector<int64_t> stridesList;
op.GetAttr("strides", stridesList);
std::vector<int64_t> kernelList;
op.GetAttr("ksize", kernelList);
int64_t dOut = (inputShape[1] - kernelList[2]) / stridesList[2] + 1;
int64_t hOut = (inputShape[3] - kernelList[3]) / stridesList[3] + 1;
int64_t wOut = (inputShape[4] - kernelList[4]) / stridesList[4] + 1;
int64_t alignedBmLine;
alignedBmLine = (wOut * hOut % 16 == 0) ? (wOut * hOut) : (((int64_t)(wOut * hOut / 16) + 1) * 16);
std::vector<int64_t> argShapeVec;
argShapeVec.push_back(inputShape[0]);
argShapeVec.push_back(dOut);
argShapeVec.push_back(inputShape[2] * kernelList[2] * kernelList[3] * kernelList[4]);
argShapeVec.push_back((int64_t)(alignedBmLine / 16));
argShapeVec.push_back(inputShape[5]);
Shape argmaxShape(argShapeVec);
argmaxDesc.SetShape(argmaxShape);
argmaxDesc.SetDataType(DT_UINT16);
(void)op.UpdateOutputDesc("argmax", argmaxDesc);
std::vector<int64_t> outShapeVec{inputShape[0], dOut, inputShape[2], hOut, wOut, inputShape[5]};
Shape outputShape(outShapeVec);
outputDesc.SetShape(outputShape);
outputDesc.SetDataType(inputDtype);
(void)op.UpdateOutputDesc("y", outputDesc);
return GRAPH_SUCCESS;
}
IMPLEMT_VERIFIER(MaxPool3DWithArgmax, MaxPool3DWithArgmaxVerify) {
// verify in infer func
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(MaxPool3DWithArgmax, MaxPool3DWithArgmaxInferShape);
VERIFY_FUNC_REG(MaxPool3DWithArgmax, MaxPool3DWithArgmaxVerify);
//-------------------MaxPool3DWithArgmax---------------------
//-------------------FractionalMaxPool---------------------
IMPLEMT_INFERFUNC(FractionalMaxPool, FractionalMaxPoolInfer) {
auto tensor = op.get_input_desc_x();
Shape input_value;
if (WithRank(tensor, 4, input_value, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
ConcatString("call WithRank failed, ", GetShapeErrMsg(0, DebugString(tensor.GetShape().GetDims()), "4D")));
return GRAPH_FAILED;
}
std::vector<float> pooling_ratio;
pooling_ratio = op.get_attr_pooling_ratio();
if (pooling_ratio.size() != 4) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), GetAttrSizeErrMsg("pooling_ratio", DebugString(tensor.GetShape().GetDims()), "4D"));
return GRAPH_FAILED;
}
std::vector<int64_t> output_dims;
for (int i = 0; i < 4; ++i) {
int64_t dim = input_value.GetDim(i);
if (dim != UNKNOWN_DIM) {
auto real_dim = static_cast<int64_t>(dim / pooling_ratio[i]);
if (real_dim < 0) {
string err_msg = ConcatString("size computed for ", i, "th dim of output[y] is ", real_dim);
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
output_dims.push_back(real_dim);
} else {
output_dims.push_back(UNKNOWN_DIM);
}
}
TensorDesc y_desc = op.GetOutputDescByName("y");
y_desc.SetShape(Shape(output_dims));
y_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed."));
return GRAPH_FAILED;
}
TensorDesc row_pooling_desc = op.GetOutputDescByName("row_pooling_sequence");
row_pooling_desc.SetShape(Shape({output_dims[1] + 1}));
row_pooling_desc.SetDataType(DT_INT64);
if (op.UpdateOutputDesc("row_pooling_sequence", row_pooling_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[row_pooling_sequence] desc failed."));
return GRAPH_FAILED;
}
TensorDesc col_pooling_desc = op.GetOutputDescByName("col_pooling_sequence");
col_pooling_desc.SetShape(Shape({output_dims[2] + 1}));
col_pooling_desc.SetDataType(DT_INT64);
if (op.UpdateOutputDesc("col_pooling_sequence", col_pooling_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[col_pooling_sequence] desc failed."));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(FractionalMaxPool, FractionalMaxPoolInfer);
//-------------------FractionalMaxPool END---------------------
//-------------------FractionalMaxPoolGrad---------------------
IMPLEMT_INFERFUNC(FractionalMaxPoolGrad, FractionalMaxPoolGradInfer) {
Shape input_shape;
if (WithRank(op.GetInputDesc(0), 4, input_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), ConcatString("call WithRank failed, ",
GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D")));
return GRAPH_FAILED;
}
auto type = op.GetInputDescByName("orig_input").GetDataType();
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetShape(Shape(input_shape));
output_desc.SetDataType(type);
if (op.UpdateOutputDesc("y", output_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(FractionalMaxPoolGrad, FractionalMaxPoolGradInfer);
//-------------------FractionalMaxPoolGrad END---------------------
//-------------------MaxPool3DGradWithArgMax---------------------
CUST_IMPLEMT_VERIFIER(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify) {
const size_t DIM_SIZE1 = 1;
const size_t DIM_SIZE3 = 3;
const size_t DIM_SIZE5 = 5;
std::vector<int32_t> ksizeList;
if (GRAPH_SUCCESS != op.GetAttr("ksize", ksizeList)) {
std::string err_msg = GetInputInvalidErrMsg("ksize");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((ksizeList.size() != DIM_SIZE1) && (ksizeList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("ksizeList", ConcatString(ksizeList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> stridesList;
if (GRAPH_SUCCESS != op.GetAttr("strides", stridesList)) {
std::string err_msg = GetInputInvalidErrMsg("strides");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((stridesList.size() != DIM_SIZE1) && (stridesList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("stridesList", ConcatString(stridesList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> padsList;
if (GRAPH_SUCCESS != op.GetAttr("pads", padsList)) {
std::string err_msg = GetInputInvalidErrMsg("pads");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((padsList.size() != DIM_SIZE1) && (padsList.size() != DIM_SIZE3)) {
string excepted_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
std::string err_msg = GetAttrSizeErrMsg("padsList", ConcatString(padsList.size()), excepted_size);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::vector<int32_t> dilationList;
if (GRAPH_SUCCESS != op.GetAttr("dilation", dilationList)) {
std::string err_msg = GetInputInvalidErrMsg("dilation");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
if ((dilationList.size() != DIM_SIZE1) && (dilationList.size() != DIM_SIZE3) && (dilationList.size() != DIM_SIZE5)) {
string excepted_value = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3, " or ", DIM_SIZE5);
std::string err_msg = GetAttrSizeErrMsg("dilationList", ConcatString(dilationList.size()), excepted_value);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
bool ceilMode = false;
if (GRAPH_SUCCESS != op.GetAttr("ceil_mode", ceilMode)) {
std::string err_msg = GetInputInvalidErrMsg("ceil_mode");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
std::string data_format;
if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "get attr data_format failed.");
return GRAPH_FAILED;
}
if (data_format != "NCDHW") {
OP_LOGE(TbeGetName(op).c_str(), "Attr data_format(%s) only support NCDHW.", data_format.c_str());
return GRAPH_FAILED;
}
int dtype = 0;
if (GRAPH_SUCCESS != op.GetAttr("dtype", dtype)) {
std::string err_msg = GetInputInvalidErrMsg("dtype");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
return GRAPH_FAILED;
}
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK_PTR_NULL(op_desc, "op desc", return GRAPH_FAILED);
auto grads_desc = op_desc->MutableInputDesc("grads");
CHECK_PTR_NULL(grads_desc, "grads desc", return GRAPH_FAILED);
vector<int64_t> grads_shape = grads_desc->MutableShape().GetDims();
if (grads_shape.size() != DIM_SIZE5 && !IsUnknownRankShape(grads_shape)) {
OP_LOGE(TbeGetName(op).c_str(), "grads_shape's dim expect: %lu, but real: %lu.", DIM_SIZE5, grads_shape.size());
return GRAPH_FAILED;
}
TensorDesc inputDesc = op.GetInputDescByName("x");
vector<int64_t> inputShape = inputDesc.GetShape().GetDims();
if (inputShape.size() != DIM_SIZE5) {
OP_LOGE(TbeGetName(op).c_str(), "input x's dim expect: %lu, but real: %lu.", DIM_SIZE5, inputShape.size());
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_IMPLEMT_INFERFUNC(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape) {
auto shape = op.GetInputDescByName("x").GetShape();
auto shape_dims = shape.GetDims();
TensorDesc td = op.GetOutputDescByName("y");
td.SetShape(shape);
td.SetDataType(op.GetInputDescByName("x").GetDataType());
(void)op.UpdateOutputDesc("y", td);
return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
//-------------------MaxPool3DGradWithArgMax---------------------
} // namespace ge

View File

@ -0,0 +1,94 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/random_ops.h"
#include "inc/ops/stateful_random_ops.h"
#include "custom_op_proto/cust_random_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
IMPLEMT_INFERFUNC(NonDeterministicInts, NonDeterministicIntsInfer) {
Shape shape;
Tensor shape_tensor;
if (op.GetInputConstData("shape", shape_tensor) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get shape_tensor error.");
return GRAPH_FAILED;
}
if (MakeShapeFromShapeTensor(shape_tensor, shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get shape error.");
return GRAPH_FAILED;
}
DataType dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get attr dtype error.");
return GRAPH_FAILED;
}
TensorDesc outputDesc = op.GetOutputDescByName("y");
outputDesc.SetDataType(dtype);
outputDesc.SetShape(shape);
return op.UpdateOutputDesc("y", outputDesc);
}
INFER_FUNC_REG(NonDeterministicInts, NonDeterministicIntsInfer);
// ----------------LogNormalReverse-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogNormalReverseInferShape) {
TensorDesc v_output_desc = op.GetOutputDescByName("y");
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();
v_output_desc.SetShape(shape_input);
v_output_desc.SetDataType(input_dtype);
v_output_desc.SetFormat(input_format);
if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LogNormalReverse, LogNormalReverseInferShape);
// ----------------LogNormalReverse END-------------------
// ----------------Dropout2D-------------------
IMPLEMT_COMMON_INFERFUNC(Dropout2DInferShape) {
TensorDesc output_desc = op.GetOutputDescByName("output");
TensorDesc mask_desc = op.GetOutputDescByName("mask");
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();
output_desc.SetShape(shape_input);
output_desc.SetDataType(input_dtype);
mask_desc.SetShape(shape_input);
mask_desc.SetDataType(DT_BOOL);
if (op.UpdateOutputDesc("output", output_desc) != GRAPH_SUCCESS ||
op.UpdateOutputDesc("mask", mask_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(Dropout2D, Dropout2DInferShape);
// ----------------Dropout2D END-------------------
} // namespace ge

View File

@ -0,0 +1,82 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime_util.h"
#include "utils/op_util.h"
#include "utils/op_log.h"
#include "utils/op_const.h"
using namespace ge;
namespace ops {
// -------------------Diag Ops START---------------------
static constexpr size_t DIAG_IN_X_IDX = 0;
static constexpr size_t DIAG_OUT_Y_IDX = 0;
static constexpr size_t INT_DATA_2 = 2;
ge::graphStatus Infershape4Diag(gert::InferShapeContext *context) {
OP_LOGD(context->GetNodeName(), "Begin to do DiagInfershape.");
const gert::Shape *input_x_shape = context->GetInputShape(DIAG_IN_X_IDX);
OPS_CHECK_NULL_WITH_CONTEXT(context, input_x_shape);
gert::Shape *output_y_shape = context->GetOutputShape(DIAG_OUT_Y_IDX);
OPS_CHECK_NULL_WITH_CONTEXT(context, output_y_shape);
size_t x_dim_num = input_x_shape->GetDimNum();
output_y_shape->SetDimNum(0);
for (size_t i = 0; i < INT_DATA_2; i++) {
for (size_t j = 0; j < x_dim_num; j++) {
output_y_shape->AppendDim(input_x_shape->GetDim(j));
}
}
OP_LOGD(context->GetNodeName(), "output_y_shape = %s.", ToString(*output_y_shape).c_str());
OP_LOGD(context->GetNodeName(), "End to do DiagInfershape.");
return ge::GRAPH_SUCCESS;
}
IMPL_OP(Diag).InferShape(Infershape4Diag);
// -------------------Diag Ops END---------------------
// -------------------DiagPart Ops START---------------------
ge::graphStatus Infershape4DiagPart(gert::InferShapeContext *context) {
OP_LOGD(context->GetNodeName(), "Begin to do DiagPartInfershape.");
const gert::Shape *input_x_shape = context->GetInputShape(DIAG_IN_X_IDX);
OPS_CHECK_NULL_WITH_CONTEXT(context, input_x_shape);
int64_t input_to_output_dims_times = 2;
OP_CHECK(input_x_shape->GetDimNum() % 2 != 0,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(context->GetNodeName(),
"input_x_shape->GetDimNum() % 2 != 0 is not supported."),
return ge::GRAPH_FAILED);
int64_t output_shape_len = input_x_shape->GetDimNum() / input_to_output_dims_times;
gert::Shape *output_y_shape = context->GetOutputShape(DIAG_OUT_Y_IDX);
OPS_CHECK_NULL_WITH_CONTEXT(context, output_y_shape);
output_y_shape->SetDimNum(output_shape_len);
for (int64_t i = 0; i < output_shape_len; i++) {
output_y_shape->SetDim(i, input_x_shape->GetDim(i));
}
OP_LOGD(context->GetNodeName(), "output_y_shape = %s.", ToString(*output_y_shape).c_str());
OP_LOGD(context->GetNodeName(), "End to do DiagPartInfershape.");
return ge::GRAPH_SUCCESS;
}
IMPL_OP(DiagPart).InferShape(Infershape4DiagPart);
// -------------------DiagPart Ops END---------------------
} // namespace ops

View File

@ -0,0 +1,165 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime_util.h"
#include "utils/op_util.h"
using namespace ge;
namespace ops {
ge::graphStatus InferShape4Elewise(gert::InferShapeContext *context) {
auto in_shape = context->GetInputShape(0);
OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
auto out_shape = context->GetOutputShape(0);
OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);
if (IsUnknownRank(in_shape)) {
OP_LOGD(context->GetNodeName(), "input shape is UnknownRank, set output shape to (-2, )");
return SetUnknownRank(out_shape);
}
*out_shape = *in_shape;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus CopyShapeInput2OutputWithIdx(gert::InferShapeContext *context, int64_t input_idx, int64_t output_idx) {
auto in_shape = context->GetInputShape(input_idx);
OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
auto out_shape = context->GetOutputShape(output_idx);
OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);
*out_shape = *in_shape;
return ge::GRAPH_SUCCESS;
}
ge::graphStatus InferShape4InIdxAndOutVector(gert::InferShapeContext *context, int64_t input_idx,
const std::vector<int64_t> &output_idxs) {
auto in_shape = context->GetInputShape(input_idx);
OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
for (int64_t idx : output_idxs) {
auto out_shape = context->GetOutputShape(idx);
OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);
*out_shape = *in_shape;
}
return ge::GRAPH_SUCCESS;
}
std::string ShapeCannotBroadcastMsg(const gert::Shape &shape1, const gert::Shape &shape2) {
std::string res = "shape ";
res += ToString(shape1);
res += " and ";
res += ToString(shape2);
res += " cannot broadcast!";
return res;
}
static bool BroadcastDim(int64_t &dim1, const int64_t dim2) {
if (dim1 == dim2) {
return true;
}
/* column is dim1, row is dim2, matrix value is broadcast(dim1, dim2)
dim 0 1 d2
0 0 0 E
1 0 1 d2
d1 E d1 E
*/
if ((dim1 != 1) && (dim2 != 1)) {
string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastDim", msg);
return false;
}
dim1 = (dim1 == 1) ? dim2 : dim1;
return true;
}
/*
* @brief: broadcast new shape to output shape
* @param [in] shape: const gert::Shape*, new shape to broadcast
* @param [in/out] shape_output: gert::Shape*, output shape
* @return succeed or not
*/
static bool BroadcastShapeToOutShape(const gert::Shape *shape, gert::Shape *shape_output) {
OP_LOGD("BroadcastShapeToOutShape", "start broadcast %s to %s!", ToString(*shape).c_str(),
ToString(*shape_output).c_str());
size_t shape_len = shape->GetDimNum();
size_t shape_y_len = shape_output->GetDimNum();
if (shape_len > shape_y_len) {
shape_output->SetDimNum(shape_len);
size_t len_sub = shape_len - shape_y_len;
for (size_t i = shape_y_len; i > 0; i--) {
int64_t dim1 = shape->GetDim(len_sub + i - 1);
int64_t dim2 = shape_output->GetDim(i - 1);
if (!BroadcastDim(dim1, dim2)) {
string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShapeToOutShape", msg);
return false;
}
shape_output->SetDim(len_sub + i - 1, dim1);
}
for (size_t i = 0; i < len_sub; i++) {
shape_output->SetDim(i, shape->GetDim(i));
}
} else {
auto len_sub = shape_y_len - shape_len;
for (size_t i = 0; i < shape_len; i++) {
int64_t dim1 = shape_output->GetDim(len_sub + i);
int64_t dim2 = shape->GetDim(i);
if (!BroadcastDim(dim1, dim2)) {
string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShapeToOutShape", msg);
return false;
}
shape_output->SetDim(len_sub + i, dim1);
}
}
return true;
}
bool BroadcastShape(const gert::Shape *in1_shape, const gert::Shape *in2_shape, gert::Shape *out_shape) {
*out_shape = *in1_shape;
OP_CHECK(!BroadcastShapeToOutShape(in2_shape, out_shape),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in2_shape, *in1_shape)),
return false);
return true;
}
bool BroadcastShape(const std::vector<const gert::Shape *> &in_shapes, gert::Shape *out_shape) {
size_t size = in_shapes.size();
OP_CHECK(size == 0, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", "in_shapes is empty!"), return false);
*out_shape = *in_shapes[0];
for (size_t i = 1; i < size; i++) {
OP_CHECK(!BroadcastShapeToOutShape(in_shapes[i], out_shape),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in_shapes[i], *out_shape)),
return false);
}
return true;
}
bool BroadcastShape(const gert::Shape **in_shapes, size_t size, gert::Shape *out_shape) {
OP_CHECK(size == 0, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", "in_shapes is empty!"), return false);
*out_shape = *in_shapes[0];
for (size_t i = 1; i < size; i++) {
OP_CHECK(!BroadcastShapeToOutShape(in_shapes[i], out_shape),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in_shapes[i], *out_shape)),
return false);
}
return true;
}
} // namespace ops

View File

@ -0,0 +1,114 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file runtime_util.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_
#define CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_
#include "utils/context_util.h"
#include "register/op_impl_registry.h"
#include "runtime/continuous_vector.h"
#include "runtime/infer_shape_context.h"
#include "runtime/storage_shape.h"
#include "error_util.h"
#include "op_util.h"
namespace ops {
using QuickVector = gert::Shape;
constexpr int64_t UNKNOWN_DIM_VALUE_ = -1;
constexpr int64_t UNKNOWN_RANK_DIM_VALUE_ = -2;
// Do infershape for OP which is single-input single-output and in-shape equal out-shape.
ge::graphStatus InferShape4Elewise(gert::InferShapeContext *context);
/*
* @brief: get output shape
* @param [in] context: gert::InferShapeContext
* @param [in] input_idx: constvalue input index
* @param [in] output_idx: constvalue output index
* @return vector<int64_t>: success or failed
*/
ge::graphStatus CopyShapeInput2OutputWithIdx(gert::InferShapeContext *context, int64_t input_idx, int64_t output_idx);
/*
* @brief: get output shape
* @param [in] context: gert::InferShapeContext
* @param [in] input_idx: constvalue input index
* @param [in] output_idxs: constvalue output indexes,vector<int64_t>
* @return graphStatus: success or failed
*/
ge::graphStatus InferShape4InIdxAndOutVector(gert::InferShapeContext *context, int64_t input_idx,
const std::vector<int64_t> &output_idxs);
std::string ShapeCannotBroadcastMsg(const gert::Shape &shape1, const gert::Shape &shape2);
/*
* @brief: broadcast new shape to output shape
* @param [in] shape: const gert::Shape*, new shape to broadcast
* @param [in/out] shape_output: gert::Shape*, output shape
* @return succeed or not
*/
bool BroadcastShape(const gert::Shape *in1_shape, const gert::Shape *in2_shape, gert::Shape *out_shape);
bool BroadcastShape(const std::vector<const gert::Shape *> &in_shapes, gert::Shape *out_shape);
bool BroadcastShape(const gert::Shape **in_shapes, size_t size, gert::Shape *out_shape);
/*
* @brief: set all the output shape to [-1, -1, ....] with input rank
* @param [in] rank: the output input rank
* @param [out] output_shape: the output shape ptr
* @return ge::graphStatus
*/
inline ge::graphStatus SetAllUnknownDim(const int64_t rank, gert::Shape *output_shape) {
OP_CHECK(output_shape == nullptr, OP_LOGD("SetAllUnknownDim", "the output_shape is nullptr, return failed"),
return ge::GRAPH_FAILED);
output_shape->SetDimNum(rank);
for (int64_t i = 0; i < rank; ++i) {
output_shape->SetDim(i, UNKNOWN_DIM_VALUE_);
}
OP_LOGD("SetAllUnknownDim", "set all dim = -1, output = %s", ToString(*output_shape).c_str());
return ge::GRAPH_SUCCESS;
}
/*
* @brief: set output shape to [-2]
* @param [out] output_shape: the output shape ptr
* @return ge::graphStatus
*/
inline ge::graphStatus SetUnknownRank(gert::Shape *output_shape) {
OP_CHECK(output_shape == nullptr, OP_LOGD("SetUnknownRank", "the output_shape is nullptr, return failed"),
return ge::GRAPH_FAILED);
output_shape->SetDimNum(0);
output_shape->AppendDim(UNKNOWN_RANK_DIM_VALUE_);
OP_LOGD("SetUnknownRank", "set unknown rank = -2, output = %s", ToString(*output_shape).c_str());
return ge::GRAPH_SUCCESS;
}
/*
* @brief: check whether the output shape is unknown rank
* @param [out] output_shape: the output shape ptr
* @return ge::graphStatus
*/
inline bool IsUnknownRank(const gert::Shape *check_shape) {
return check_shape->GetDimNum() == 1 && check_shape->GetDim(0) == UNKNOWN_RANK_DIM_VALUE_;
}
} // namespace ops
#endif // CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_

View File

@ -0,0 +1,165 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/selection_ops.h"
#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------CumulativeLogsumexp-------------------
IMPLEMT_COMMON_INFERFUNC(CumulativeLogsumexpInferShape) {
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetShape(op.GetInputDescByName("x").GetShape());
output_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
op.UpdateOutputDesc("y", output_desc);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(CumulativeLogsumexp, CumulativeLogsumexpInferShape);
// ----------------CumulativeLogsumexp END-------------------
// ----------------GatherNd-------------------
bool CheckGatherNdInputIndicesSize(const Operator &op, const string &input_name) {
auto indices_shape = OpDescUtils::GetOpDescFromOperator(op)->MutableInputDesc("indices")->GetShape();
auto indices_shape_size = indices_shape.GetDimNum();
int indices_last_element = indices_shape.GetDim(indices_shape_size - 1);
int64_t indices_part{1};
for (int i = 0; i < indices_last_element - 1; ++i) {
indices_part *= static_cast<int64_t>(indices_shape.GetDim(i));
}
if (indices_part > std::numeric_limits<int>::max()) {
OP_LOGE(TbeGetName(op).c_str(), "Indices has too many elements for int indexing");
return false;
}
return true;
}
bool CheckGatherNdParamsSize(const Operator &op, int last_dim, int shape_size) {
if (last_dim > shape_size) {
OP_LOGE(TbeGetName(op).c_str(), "The last dim(%d) of indices must be <= params.rank(%d).", last_dim, shape_size);
return false;
}
return true;
}
IMPLEMT_VERIFIER(GatherNd, GatherNdVerify) {
if (!CheckGatherNdInputIndicesSize(op, "indices")) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(GatherNdInferShape) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
GeTensorDescPtr output_tensor_desc = op_desc->MutableOutputDesc("y");
std::vector<std::pair<int64_t, int64_t>> shape_range_x;
op_desc->MutableInputDesc("x")->GetShapeRange(shape_range_x);
std::vector<std::pair<int64_t, int64_t>> shape_range_indices;
op_desc->MutableInputDesc("indices")->GetShapeRange(shape_range_indices);
std::vector<std::pair<int64_t, int64_t>> out_range;
auto input_params = op_desc->MutableInputDesc("x");
auto input_indices = op_desc->MutableInputDesc("indices");
auto params_shape = input_params->GetShape();
auto indices_shape = input_indices->GetShape();
auto params_shape_size = params_shape.GetDimNum();
int indices_shape_size = indices_shape.GetDimNum();
vector<int64_t> dim_vec;
vector<int64_t> params_shape_vec = params_shape.GetDims();
vector<int64_t> indices_shape_vec = indices_shape.GetDims();
MakeUpShapeRange(params_shape_vec, shape_range_x);
MakeUpShapeRange(indices_shape_vec, shape_range_indices);
int indices_last_element{-2};
if (!IsUnknownRankShape(indices_shape_vec)) {
indices_last_element = indices_shape.GetDim(indices_shape_size - 1);
}
DataType params_type = input_params->GetDataType();
if (indices_last_element == -1 || indices_last_element == -2 || IsUnknownRankShape(params_shape_vec)) {
dim_vec.push_back(-2);
} else if (!CheckGatherNdParamsSize(op, indices_last_element, (int)params_shape_size)) {
return GRAPH_FAILED;
} else {
for (int i = 0; i < indices_shape_size - 1; ++i) {
dim_vec.push_back(indices_shape.GetDim(i));
if ((size_t)i < shape_range_indices.size()) {
out_range.push_back(shape_range_indices[i]);
}
}
for (size_t i = indices_last_element; i < params_shape_size; ++i) {
dim_vec.push_back(params_shape.GetDim(i));
if (i < shape_range_x.size()) {
out_range.push_back(shape_range_x[i]);
}
}
}
ge::GeShape output_shape = ge::GeShape(dim_vec);
DataType output_dtype = params_type;
output_tensor_desc->SetShape(output_shape);
output_tensor_desc->SetDataType(output_dtype);
TensorUtils::SetRealDimCnt(*output_tensor_desc, dim_vec.size());
if (!IsUnknownRankShape(dim_vec)) {
output_tensor_desc->SetShapeRange(out_range);
}
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(GatherNd, GatherNdInferShape);
VERIFY_FUNC_REG(GatherNd, GatherNdVerify);
// ----------------GatherNd End-------------------
// ----------------MaskedSelect Begin-------------------
bool InferShapeAndTypeMaskedSelect(Operator &op) {
OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op);
GeTensorDescPtr x_input = op_desc->MutableInputDesc(0);
GeShape x_shape = x_input->GetShape();
GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0);
DataType input_dtype = x_input->GetDataType();
y_desc->SetDataType(input_dtype);
std::vector<std::pair<int64_t, int64_t>> range;
y_desc->SetShape(GeShape({UNKNOWN_DIM}));
y_desc->SetOriginShape(GeShape({UNKNOWN_DIM}));
range.emplace_back(std::make_pair(1, x_shape.GetShapeSize()));
y_desc->SetShapeRange(range);
return true;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(MaskedSelectInferShape) {
if (InferShapeAndTypeMaskedSelect(op)) {
return GRAPH_SUCCESS;
}
OP_LOGE(TbeGetName(op).c_str(), "The shape of output y does not match that of x1 x2.");
return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(MaskedSelect, MaskedSelectInferShape);
// ----------------MaskedSelect END---------------------
// ----------------IndexFill-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
TensorDesc v_output_desc = op.GetOutputDescByName("y");
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
// shape of output y is the same as input x
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();
v_output_desc.SetShape(shape_input);
v_output_desc.SetDataType(input_dtype);
v_output_desc.SetFormat(input_format);
if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Registered inferfunction
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------
} // namespace ge

View File

@ -93,7 +93,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyAdagradDA, SparseApplyAdagradDAVerify) {
}
Shape indices_shape;
if (WithRank(op.GetInputDesc(4), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(4), 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}

View File

@ -90,7 +90,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyCenteredRMSProp, SparseApplyCenteredRMSPropVeri
}
Shape indices_shape;
if (WithRank(op.GetInputDesc(9), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(9), 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}

View File

@ -90,7 +90,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyMomentum, SparseApplyMomentumVerify) {
}
Shape indices_shape;
if (WithRank(op.GetInputDesc(4), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(4), 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}

View File

@ -74,7 +74,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyProximalGradientDescent, SparseApplyProximalGra
}
Shape indices_shape;
if (WithRank(op.GetInputDesc(5), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(5), 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}

View File

@ -31,32 +31,32 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentMeanWithNumSegments, SparseSegmentMeanWithNu
GeShape indices_shape;
auto indices_desc = op_desc->MutableInputDesc(1);
if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}
GeShape segment_ids_shape;
auto segment_ids_desc = op_desc->MutableInputDesc(2);
if (WithRank(segment_ids_desc, 1, segment_ids_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(segment_ids_desc, 1, segment_ids_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
return GRAPH_FAILED;
}
GeShape num_segments_shape;
auto num_segments_desc = op_desc->MutableInputDesc(3);
if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input nums_segments should be at most 1-D.");
return GRAPH_FAILED;
}
GeShape unused;
if (Merge(indices_shape, segment_ids_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (Merge(indices_shape, segment_ids_shape, unused, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
GeShape subshape;
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
@ -67,7 +67,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentMeanWithNumSegments, SparseSegmentMeanWithNu
}
if (nums != UNKNOWN_DIM) {
if (MakeDimForScalarInput(tensor, nums, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (MakeDimForScalarInput(tensor, nums, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[num_segments]."));
return GRAPH_FAILED;
}

View File

@ -30,7 +30,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
auto x_desc = op_desc->MutableInputDesc(0);
GeShape x_ge_shape;
if (WithRankAtLeast(x_desc, 1, x_ge_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtLeast(x_desc, 1, x_ge_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(x_desc->GetShape().GetDims()), "at least 1D");
err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@ -39,7 +39,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
auto indices_desc = op_desc->MutableInputDesc(1);
GeShape indices_shape;
if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(indices_desc->GetShape().GetDims()), "1D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@ -48,7 +48,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
GeShape unused;
GeShape segment_ids_shape(op_desc->MutableInputDesc(2)->GetShape());
if (Merge(segment_ids_shape, indices_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (Merge(segment_ids_shape, indices_shape, unused, op) != GRAPH_SUCCESS) {
err_msg = ConcatString("failed to call Merge function to merge input[segment_ids]'s shape",
DebugString(op_desc->MutableInputDesc(2)->GetShape().GetDims()),
" and input[indices]'s shape", DebugString(indices_shape.GetDims()));
@ -56,7 +56,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
return GRAPH_FAILED;
}
auto unused_desc = op_desc->MutableInputDesc(3);
if (WithRank(unused_desc, 0, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(unused_desc, 0, unused, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(3, DebugString(unused_desc->GetShape().GetDims()), "scalar");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@ -66,7 +66,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
auto x_shape_dims = x_ge_shape.GetDims();
Shape x_shape(x_shape_dims);
Shape subshape;
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
err_msg = ConcatString("failed to call SubShape function to get subshape from ", x_shape.GetDimNum(),
" to 1 in input[x] shape", DebugString(x_shape.GetDims()));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

View File

@ -24,32 +24,32 @@ graphStatus SparseSegmentReductionShapeFn(Operator &op) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
GeShape x_shapes;
auto x_desc = op_desc->MutableInputDesc(0);
if (WithRankAtLeast(x_desc, 1, x_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtLeast(x_desc, 1, x_shapes, op) != GRAPH_SUCCESS) {
AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input x should be at least 1-D.");
return GRAPH_FAILED;
}
GeShape indices_shapes;
auto indices_desc = op_desc->MutableInputDesc(1);
if (WithRank(indices_desc, 1, indices_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(indices_desc, 1, indices_shapes, op) != GRAPH_SUCCESS) {
AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}
GeShape segment_ids_shapes;
auto segment_ids_desc = op_desc->MutableInputDesc(2);
if (WithRank(segment_ids_desc, 1, segment_ids_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(segment_ids_desc, 1, segment_ids_shapes, op) != GRAPH_SUCCESS) {
AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
return GRAPH_FAILED;
}
GeShape unuse;
if (Merge(indices_shapes, segment_ids_shapes, unuse, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (Merge(indices_shapes, segment_ids_shapes, unuse, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
GeShape subshapes;
if (SubShape(x_shapes, 1, x_shapes.GetDimNum(), 1, subshapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (SubShape(x_shapes, 1, x_shapes.GetDimNum(), 1, subshapes, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}

View File

@ -24,39 +24,39 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNWithNumSegments, SparseSegmentSqrtNWith
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
GeShape x_shape;
auto x_desc = op_desc->MutableInputDesc(0);
if (WithRankAtLeast(x_desc, 1, x_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtLeast(x_desc, 1, x_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x should be at least 1-D.");
return GRAPH_FAILED;
}
GeShape indices_shape;
auto indices_desc = op_desc->MutableInputDesc(1);
if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
return GRAPH_FAILED;
}
GeShape segment_ids_shape;
auto segment_ids_desc = op_desc->MutableInputDesc(2);
if (WithRank(segment_ids_desc, 1, segment_ids_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(segment_ids_desc, 1, segment_ids_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
return GRAPH_FAILED;
}
GeShape num_segments_shape;
auto num_segments_desc = op_desc->MutableInputDesc(3);
if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input nums_segments should be at most 1-D.");
return GRAPH_FAILED;
}
GeShape unused;
if (Merge(indices_shape, segment_ids_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (Merge(indices_shape, segment_ids_shape, unused, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
GeShape subshape;
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
@ -67,7 +67,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNWithNumSegments, SparseSegmentSqrtNWith
}
if (nums != UNKNOWN_DIM) {
if (MakeDimForScalarInput(tensor, nums, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (MakeDimForScalarInput(tensor, nums, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[num_segments]."));
return GRAPH_FAILED;
}

View File

@ -25,21 +25,21 @@ CUST_IMPLEMT_INFERFUNC(SparseTensorToCSRSparseMatrix, SparseTensorToCSRSparseMat
GeShape x_indices_shape;
auto x_indices_desc = op_desc->MutableInputDesc(0);
if (WithRank(x_indices_desc, 2, x_indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_indices_desc, 2, x_indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_indices_desc must be 2-D.");
return GRAPH_FAILED;
}
GeShape x_values_shape;
auto x_values_desc = op_desc->MutableInputDesc(1);
if (WithRank(x_values_desc, 1, x_values_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_values_desc, 1, x_values_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_values must be 1-D.");
return GRAPH_FAILED;
}
GeShape x_dense_shape_shape;
auto x_dense_shape_desc = op_desc->MutableInputDesc(2);
if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_dense_shape must be 1-D.");
return GRAPH_FAILED;
}

View File

@ -30,32 +30,32 @@ CUST_IMPLEMT_INFERFUNC(SparseAddmm, SparseAddmmInfer) {
auto x1_shape_tensor = op.get_input_desc_x1_shape();
auto x2_tensor = op.get_input_desc_x2();
std::string err_msg;
if (WithRank(x1_indices_tensor, 2, unused_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x1_indices_tensor, 2, unused_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(x1_indices_tensor.GetShape().GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(x1_values_tensor, 1, unused_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x1_values_tensor, 1, unused_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(x1_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (MakeShapeFromShapeTensor(op, "x1_shape", x1_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (MakeShapeFromShapeTensor(op, "x1_shape", x1_shape) != GRAPH_SUCCESS) {
err_msg = ConcatString(
"failed to call MakeShapeFromShapeTensor function to make shape from "
"input[x1_shape]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRankShape(x1_shape, 2, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankShape(x1_shape, 2, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(2, DebugString(x1_shape.GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(x2_tensor, 2, x2_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(x2_tensor, 2, x2_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(3, DebugString(x2_tensor.GetShape().GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

View File

@ -111,19 +111,19 @@ CUST_IMPLEMT_INFERFUNC(Sspaddmm, SspaddmmInfer) {
}
// check dimension
if (WithRank(mat1_values_tensor, 1, unused_shape1, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(mat1_values_tensor, 1, unused_shape1, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(mat1_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("MAT1 Values failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(input_indices_tensor, 2, unused_shape2, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(input_indices_tensor, 2, unused_shape2, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(input_indices_tensor.GetShape().GetDims()), "2D");
err_msg = string("input indices failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(input_values_tensor, 1, unused_shape1, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(input_values_tensor, 1, unused_shape1, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(input_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("input values failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@ -134,7 +134,7 @@ CUST_IMPLEMT_INFERFUNC(Sspaddmm, SspaddmmInfer) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (WithRank(mat2_tensor, 2, mat2_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(mat2_tensor, 2, mat2_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(3, DebugString(mat2_tensor.GetShape().GetDims()), "2D");
err_msg = string("mat2 tensor failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

View File

@ -0,0 +1,188 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ------------------DepthToSpace------------------
static bool VerifyDepthToSpaceInputShape(const Operator &op, const int64_t &block_size,
const std::vector<int64_t> &input_dims, const std::string &data_format) {
bool check_format = (data_format == "NCHW" || data_format == "NHWC");
if (check_format && !IsUnknown(input_dims)) {
int64_t c_dim = 3;
c_dim = data_format == "NHWC" ? 3 : 1;
auto mod_res = input_dims[c_dim] % (block_size * block_size);
if (mod_res != 0) {
OP_LOGE(TbeGetName(op),
"Depth size must be divisible by block_size * block_size,"
"but got depth[%ld], block_size[%ld], data_format[%s]",
input_dims[c_dim], block_size, data_format.c_str());
return false;
}
}
return true;
}
IMPLEMT_VERIFIER(DepthToSpace, DepthToSpaceVerify) {
// verify input shape size
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_dims = input_desc->MutableShape().GetDims();
if (!IsUnknownRankShape(input_dims) && (input_dims.size() < 4)) {
std::string err_msg = GetAttrValueErrMsg("input_dims", std::to_string(input_dims.size()), ConcatString(">=4"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify block size
int64_t block_size;
if (op.GetAttr("block_size", block_size) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("block_size");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (block_size < 2) {
std::string err_msg = GetAttrValueErrMsg("block_size", std::to_string(block_size), ConcatString("=<2"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify mode
std::string mode;
if (op.GetAttr("mode", mode) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("mode");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (mode != "DCR" && mode != "CRD") {
string expected_format_list = ConcatString("DCR, CRD");
std::string err_msg = GetAttrValueErrMsg("mode", mode, expected_format_list);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify data_format
std::string data_format;
if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("data_format");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (data_format != "NHWC" && data_format != "NCHW" && data_format != "NC1HWC0") {
string expected_format_list = ConcatString("NHWC, NCHW, NC1HWC0");
std::string err_msg = GetInputFormatNotSupportErrMsg("data_format", expected_format_list, data_format);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify input shape
bool check_input_shape = true;
check_input_shape = VerifyDepthToSpaceInputShape(op, block_size, input_dims, data_format);
if (!check_input_shape) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(DepthToSpaceInfer) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_dims = input_desc->MutableShape().GetDims();
auto input_dtype = input_desc->GetDataType();
auto input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(input_desc->GetFormat()));
auto output_desc = op_info->MutableOutputDesc("y");
output_desc->SetDataType(input_dtype);
// get attr block_size
int64_t block_size;
if (GRAPH_SUCCESS != op.GetAttr("block_size", block_size)) {
std::string err_msg = GetInputInvalidErrMsg("block_size");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// not dynamic case, only set shape
if (!IsUnknown(input_dims)) {
std::vector<int64_t> output_dims;
output_dims.push_back(input_dims[0]);
if (input_format == FORMAT_NCHW) {
output_dims.push_back(input_dims[1] / block_size / block_size);
output_dims.push_back(input_dims[2] * block_size);
output_dims.push_back(input_dims[3] * block_size);
} else { // without NCHW all other format set as NHWC
output_dims.push_back(input_dims[1] * block_size);
output_dims.push_back(input_dims[2] * block_size);
output_dims.push_back(input_dims[3] / block_size / block_size);
}
output_desc->SetShape(GeShape(output_dims));
return GRAPH_SUCCESS;
}
// dynamic case, input shape is -2, output is -2
if (IsUnknownRankShape(input_dims)) {
output_desc->SetShape(GeShape(input_dims));
OP_LOGW(TbeGetName(op).c_str(), "input shape is UnknownRank, set output is UnknownRank.");
return GRAPH_SUCCESS;
}
// dynamic case, input shape is -1, output is -1
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_dims, input_range);
// infer output shape and range
std::vector<int64_t> output_dims;
std::vector<std::pair<int64_t, int64_t>> output_range;
output_dims.push_back(input_dims[0]);
output_range.push_back(input_range[0]);
int64_t dim;
int64_t range_min;
int64_t range_max;
if (input_format == FORMAT_NCHW) {
dim = input_dims[1] == -1 ? -1 : input_dims[1] / block_size / block_size;
range_min = input_range[1].first / block_size / block_size;
range_min = std::max(int64_t(range_min), int64_t(1));
range_max = input_range[1].second == -1 ? -1 : input_range[1].second / block_size / block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[2] == -1 ? -1 : input_dims[2] * block_size;
range_min = input_range[2].first * block_size;
range_max = input_range[2].second == -1 ? -1 : input_range[2].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[3] == -1 ? -1 : input_dims[3] * block_size;
range_min = input_range[3].first * block_size;
range_max = input_range[3].second == -1 ? -1 : input_range[3].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
} else {
dim = input_dims[1] == -1 ? -1 : input_dims[1] * block_size;
range_min = input_range[1].first * block_size;
range_max = input_range[1].second == -1 ? -1 : input_range[1].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[2] == -1 ? -1 : input_dims[2] * block_size;
range_min = input_range[2].first * block_size;
range_max = input_range[2].second == -1 ? -1 : input_range[2].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[3] == -1 ? -1 : input_dims[3] / block_size / block_size;
range_min = input_range[3].first / block_size / block_size;
range_min = std::max(int64_t(range_min), int64_t(1));
range_max = input_range[3].second == -1 ? -1 : input_range[3].second / block_size / block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
}
output_desc->SetShape(GeShape(output_dims));
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(DepthToSpace, DepthToSpaceInfer);
VERIFY_FUNC_REG(DepthToSpace, DepthToSpaceVerify);
// -------------------DepthToSpace END-----------------
} // namespace ge

View File

@ -28,7 +28,7 @@ CUST_IMPLEMT_INFERFUNC(TridiagonalMatMul, TridiagonalMatMulInfer) {
Shape rhs;
TensorDesc rhs_desc = op.GetInputDesc(3);
auto rhs_shape = rhs_desc.GetShape().GetDims();
if (WithRankAtLeast(rhs_desc, 2, rhs, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtLeast(rhs_desc, 2, rhs, op) != GRAPH_SUCCESS) {
error_msg =
ConcatString("failed to call WithRankatleast function, ", "the rank of input[rhs] must at least be 2, but get ",
rhs_desc.GetShape().GetDimNum(), ".");

View File

@ -18,8 +18,8 @@
* \file axis_util.h
* \brief get the axis value
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_
#include <memory>
#include <functional>
@ -142,4 +142,4 @@ class AxisUtil {
};
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
/*
* Copyright (c) Huawei Technologies Co., Ltd 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,14 +21,10 @@
#include "common_shape_fns.h"
#include <vector>
#include <limits>
#include <algorithm>
#include <utility>
#include <map>
#include "op_log.h"
#include "error_util.h"
#include "graph/utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "util.h"
namespace ge {
const std::map<std::string, DataType> dtype_maps{{"DT_FLOAT", DT_FLOAT},
@ -58,6 +54,42 @@ const std::map<std::string, DataType> dtype_maps{{"DT_FLOAT", DT_FLOAT},
{"DT_DUAL", DT_DUAL},
{"DT_UNDEFINED", DT_UNDEFINED}};
graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
Shape s = tensor.GetShape();
std::vector<int64_t> dims = s.GetDims();
// dim.size() convert to be type int64_t can't overflow
int64_t size = static_cast<int64_t>(dims.size());
if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]"));
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}
graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
GeShape s = tensorDesc->GetShape();
std::vector<int64_t> dims = s.GetDims();
// dim.size() convert to be type int64_t can't overflow
int64_t size = static_cast<int64_t>(dims.size());
if ((dims != UNKNOWN_RANK) && (size < rank)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]"));
return GRAPH_FAILED;
}
out_shape = s;
return GRAPH_SUCCESS;
}
graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
@ -95,9 +127,9 @@ graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeS
return GRAPH_SUCCESS;
}
graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
graphStatus WithRankShape(GeShape &shape, int64_t rank, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
@ -109,7 +141,7 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
return GRAPH_SUCCESS;
}
if (existing != rank) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
@ -118,9 +150,9 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
return GRAPH_SUCCESS;
}
graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
Shape s = tensor.GetShape();
@ -133,16 +165,16 @@ graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const c
}
if (existing != rank) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name) {
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
@ -155,16 +187,16 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &o
}
if (existing != rank) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out_shape = s;
return GRAPH_SUCCESS;
}
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const char *op_name) {
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
@ -177,27 +209,27 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out
}
if (existing != rank) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out_shape = Shape(s.GetDims());
return GRAPH_SUCCESS;
}
graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_name) {
graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op) {
out = value;
if (dim == UNKNOWN_DIM) {
return GRAPH_SUCCESS;
}
if (dim != value) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("dim[", dim, "] should be ", value));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[", dim, "] should be ", value));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const char *op_name) {
graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op) {
// Same shape and unknown rank
if (!RankKnown(s) || !RankKnown(prefix)) {
s_out = s;
@ -208,7 +240,7 @@ graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape
std::vector<int64_t> dims1 = s.GetDims();
if ((dims1 != UNKNOWN_RANK) && (dims1.size() < rank)) {
std::string err_msg = ConcatString("first shape rank[", dims1.size(), "] must be at least rank[", rank, "]");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
@ -220,7 +252,7 @@ graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape
if (Merge(s.GetDim(i), prefix.GetDim(i), dims[i]) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString(i, "th dim of first shape", DebugString(s.GetDims()),
" is not same as that of prefix shape", DebugString(prefix.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@ -246,7 +278,7 @@ graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out) {
return GRAPH_FAILED;
}
graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_name) {
graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op) {
// Same shape and unknown rank
if (s0.GetDims() == s1.GetDims()) {
out = s0;
@ -263,7 +295,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
if (s1.GetDimNum() != rank) {
std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape",
DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
@ -282,7 +314,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
} else if (d0 != d1) {
std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@ -299,7 +331,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) {
std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@ -308,7 +340,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
return GRAPH_SUCCESS;
}
graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char *op_name) {
graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const ge::Operator &op) {
// Same shape and unknown rank
if (s0.GetDims() == s1.GetDims()) {
out = s0;
@ -325,7 +357,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
if (s1.GetDimNum() != rank) {
std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape",
DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
@ -344,7 +376,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
} else if (d0 != d1) {
std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@ -361,7 +393,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) {
std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@ -403,7 +435,7 @@ void MergeRange(const std::vector<std::pair<int64_t, int64_t>> &shared_shape_ran
}
graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range,
ShapeAndRange &out, bool &shape_changed, const char *op_name) {
ShapeAndRange &out, bool &shape_changed, const ge::Operator &op) {
if (!RankKnown(shared_shape_and_range.shape_)) {
out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_};
return GRAPH_SUCCESS;
@ -445,7 +477,7 @@ graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, cons
return GRAPH_SUCCESS;
}
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const char *op_name) {
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op) {
if (!RankKnown(s)) {
out = Shape(ge::UNKNOWN_SHAPE);
return GRAPH_SUCCESS;
@ -456,8 +488,8 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
}
if (!FastBoundsCheck(dim_index, s.GetDimNum())) {
out = Shape();
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("out of range: replace dim[", dim_index_in,
"] for shape with rank[", s.GetDimNum(), "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims = s.GetDims();
@ -466,7 +498,7 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
return GRAPH_SUCCESS;
}
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const char *op_name) {
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const ge::Operator &op) {
if (!RankKnown(s)) {
out = GeShape(UNKNOWN_RANK);
return GRAPH_SUCCESS;
@ -477,8 +509,8 @@ graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim,
}
if (!FastBoundsCheck(dim_index, s.GetDimNum())) {
out = GeShape();
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("out of range: replace dim[", dim_index_in,
"] for shape with rank[", s.GetDimNum(), "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims = s.GetDims();
@ -491,7 +523,7 @@ template <typename Ta, typename Tb>
bool FastBoundsCheck(const Ta index, const Tb limit) {
static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
"FastBoundsCheck can only be used on integer types.");
using UIndex = typename std::make_unsigned<decltype(index + limit)>::type;
typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;
return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}
@ -512,7 +544,7 @@ graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out) {
return GRAPH_SUCCESS;
}
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_name) {
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op) {
if (dim2 == 0) {
out = dim1;
} else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) {
@ -520,7 +552,7 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
} else {
if (dim1 < dim2) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]"));
op, ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]"));
return GRAPH_FAILED;
}
out = dim1 - dim2;
@ -528,10 +560,9 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
return GRAPH_SUCCESS;
}
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const char *op_name) {
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op) {
if (s.GetDimNum() > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
const int64_t rank = static_cast<int64_t>(s.GetDimNum());
@ -557,7 +588,7 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
start = rank + start;
if (start < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]"));
op, ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]"));
return GRAPH_FAILED;
}
}
@ -566,21 +597,21 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
end = rank + end;
if (end < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]"));
op, ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]"));
return GRAPH_FAILED;
}
}
// stride > 0 and start > end
if (!((stride <= 0 || start <= end))) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be less than end[",
end, "] at positive stride[", stride, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]"));
return GRAPH_FAILED;
}
// stride < 0 and start < end
if (!(stride >= 0 || start >= end)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be greater than end[",
end, "] at negative stride[", stride, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims;
@ -593,10 +624,10 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
}
graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64_t stride, GeShape &out_shape,
const char *op_name) {
const ge::Operator &op) {
int64_t src_rank = src_shape.GetDimNum();
if (src_rank > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", src_rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", src_rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
@ -622,8 +653,7 @@ graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64
start += src_rank;
if (start < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name),
ConcatString("invalid start[", start - src_rank, "] to get sub shape with rank[", src_rank, "]"));
op, ConcatString("invalid start[", start - src_rank, "] to get sub shape with rank[", src_rank, "]"));
return GRAPH_FAILED;
}
}
@ -632,18 +662,18 @@ graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64
end += src_rank;
if (end < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid end[", end - src_rank, "] to get sub shape with rank[", src_rank, "]"));
op, ConcatString("invalid end[", end - src_rank, "] to get sub shape with rank[", src_rank, "]"));
return GRAPH_FAILED;
}
}
if (stride > 0 && start > end) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be less than end[",
end, "] at positive stride[", stride, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]"));
return GRAPH_FAILED;
} else if (stride < 0 && start < end) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be greater than end[",
end, "] at negative stride[", stride, "]"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]"));
return GRAPH_FAILED;
}
@ -698,7 +728,7 @@ graphStatus Concatenate(const GeShape &s1, const GeShape &s2, GeShape &out) {
graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out) {
std::vector<int64_t> dims;
dims.reserve(DIM_VALUE2); // The number of dims is 2.
dims.reserve(2); // The number of dims is 2.
dims.push_back(dim1);
dims.push_back(dim2);
Shape s(dims);
@ -716,7 +746,7 @@ graphStatus Vector(int64_t dim, Shape &out) {
}
static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_name, int64_t rank,
std::vector<int64_t> &data, const char *op_name) {
std::vector<int64_t> &data) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto shape_data_desc = op_desc->MutableInputDesc(dst_name);
@ -728,13 +758,13 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
DataType data_type = shape_data_desc->GetDataType();
if (dims.size() != static_cast<size_t>(rank)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
return GRAPH_FAILED;
}
int64_t dim_value = ((rank > 0) && (dims[0] > 0)) ? dims[0] : 1;
data.clear();
if (dims[0] < 0) {
OP_LOGI(op_name, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]);
OP_LOGI(op, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]);
data.push_back(UNKNOWN_DIM_NUM);
return GRAPH_SUCCESS;
}
@ -747,7 +777,7 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else {
OP_LOGI(TbeGetName(op).c_str(), "Input [%s] is not a const tensor.", dst_name.c_str());
OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str());
for (int64_t i = 0; i < dim_value; i++) {
data.push_back(UNKNOWN_DIM);
}
@ -759,14 +789,14 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else {
OP_LOGI(TbeGetName(op).c_str(), "Input [%s] is not a const tensor.", dst_name.c_str());
OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str());
for (int64_t i = 0; i < dim_value; i++) {
data.push_back(UNKNOWN_DIM);
}
}
} else {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
return GRAPH_FAILED;
}
@ -774,7 +804,7 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
}
static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank, std::vector<int64_t> &data,
const char *op_name) {
const ge::Operator &op) {
TensorDesc shape_data_desc = tensor.GetTensorDesc();
Shape shape_data_shape = shape_data_desc.GetShape();
std::vector<int64_t> dims = shape_data_shape.GetDims();
@ -782,84 +812,83 @@ static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank,
if (dims.size() != static_cast<size_t>(rank)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
return GRAPH_FAILED;
}
int64_t dim_value = rank > 0 ? dims[0] : 1;
OP_LOGI(op_name, "data_type = %d, dim_value = %ld", data_type, dim_value);
OP_LOGI(op, "data_type = %d, dim_value = %ld", data_type, dim_value);
data.clear();
data.reserve(dim_value);
if (data_type == DT_INT32) {
const int32_t *shape_data = reinterpret_cast<const int32_t *>(tensor.GetData());
for (int64_t i = 0; i < dim_value; i++) {
OP_LOGI(op_name, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast<int64_t>(shape_data[i]));
OP_LOGI(op, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast<int64_t>(shape_data[i]));
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else if (data_type == DT_INT64) {
const int64_t *shape_data = reinterpret_cast<const int64_t *>(tensor.GetData());
for (int64_t i = 0; i < dim_value; i++) {
OP_LOGI(op_name, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]);
OP_LOGI(op, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]);
data.push_back(shape_data[i]);
}
} else {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const char *op_name) {
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op) {
std::vector<int64_t> shape_data;
GetShapeDataFromConstData(tensor, 1, shape_data, op_name);
GetShapeDataFromConstData(tensor, 1, shape_data, op);
out = Shape(shape_data);
return GRAPH_SUCCESS;
}
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out, const char *op_name) {
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out) {
std::vector<int64_t> shape_data;
if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data, op_name) != GRAPH_SUCCESS) {
if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
out = GeShape(shape_data);
return GRAPH_SUCCESS;
}
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char *op_name) {
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op) {
std::vector<int64_t> shape_data;
GetShapeDataFromConstData(tensor, 0, shape_data, op_name);
GetShapeDataFromConstData(tensor, 0, shape_data, op);
out = shape_data[0];
return GRAPH_SUCCESS;
}
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
Shape s = tensor.GetShape();
std::vector<int64_t> dims = s.GetDims();
if (!((dims.size() <= static_cast<size_t>(rank)) || (dims == ge::UNKNOWN_SHAPE))) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name) {
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
const ge::Operator &op) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
GeShape s = tensorDesc->GetShape();
std::vector<int64_t> dims = s.GetDims();
if ((dims != ge::UNKNOWN_RANK) && (dims.size() > static_cast<size_t>(rank))) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
return GRAPH_FAILED;
}
@ -881,20 +910,19 @@ graphStatus UnchangedShape(Operator &op, const string input_name, const string o
}
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out,
const char *op_name) {
const ge::Operator &op) {
if (divisor == 1) {
out = dividend;
} else if ((dividend == ge::UNKNOWN_DIM) || (divisor == ge::UNKNOWN_DIM)) {
out = ge::UNKNOWN_DIM;
} else {
if (divisor <= 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
ConcatString("invalid divisor[", divisor, "], should be positive"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid divisor[", divisor, "], should be positive"));
return GRAPH_FAILED;
}
if (!((!evenlyDivisible) || (dividend % divisor) == 0)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]"));
op, ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]"));
return GRAPH_FAILED;
}
out = dividend / divisor;
@ -970,25 +998,25 @@ bool ValueKnown(const Shape &shape, const size_t &dim_index) {
}
graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape,
const char *op_name) {
const ge::Operator &op) {
// Validate ranks
Shape unused_shape;
if (WithRank(indices, NUM_VALUE2, unused_shape, op_name) != GRAPH_SUCCESS) { // The rank is 2.
if (WithRank(indices, 2, unused_shape, op) != GRAPH_SUCCESS) { // The rank is 2.
std::string err_msg = ConcatString("failed to call WithRank function, indices has wrong shape",
DebugString(indices.GetShape().GetDims()), ", it should be 2D");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
if (WithRank(values, 1, unused_shape, op_name) != GRAPH_SUCCESS) {
if (WithRank(values, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call WithRank function, values has wrong shape",
DebugString(values.GetShape().GetDims()), ", it should be 1D");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
if (WithRank(shape, 1, unused_shape, op_name) != GRAPH_SUCCESS) {
if (WithRank(shape, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call WithRank function, shape has wrong shape",
DebugString(shape.GetShape().GetDims()), ", it should be 1D");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
@ -998,9 +1026,8 @@ graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &va
if (ValueKnown(indices_shape, 0)) {
if (ValueKnown(values_shape, 0)) {
if (indices_shape.GetDim(0) != values_shape.GetDim(0)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
string(op_name), ConcatString("dim[0] of indices and dim[0] of value do not match, ", indices_shape.GetDim(0),
" and ", values_shape.GetDim(0)));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[0] of indices and dim[0] of value do not match, ",
indices_shape.GetDim(0), " and ", values_shape.GetDim(0)));
return GRAPH_FAILED;
}
}
@ -1011,9 +1038,8 @@ graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &va
if (ValueKnown(indices_shape, 1)) {
if (ValueKnown(sparse_shape, 0)) {
if (indices_shape.GetDim(1) != sparse_shape.GetDim(0)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
ConcatString("dim[1] of indices and dim[0] of sparse do not match, ",
indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0)));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[1] of indices and dim[0] of sparse do not match, ",
indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0)));
return GRAPH_FAILED;
}
}
@ -1086,45 +1112,38 @@ std::string DTypeStr(DataType dtype) {
}
graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range) {
AscendString op_name;
op.GetName(op_name);
auto context = op.GetInferenceContext();
std::vector<AscendString> marks;
context->GetMarks(marks);
if (!marks.empty()) {
OP_LOGI(TbeGetName(op).c_str(), "Set marks[0] = %s", marks[0].GetString());
OP_LOGI(op, "Set marks[0] = %s", marks[0].GetString());
bool shape_changed = false;
auto aicpu_resource_context = reinterpret_cast<AicpuResourceContext *>(context->GetResourceContext(marks[0]));
if (aicpu_resource_context == nullptr) {
aicpu_resource_context = new (std::nothrow) AicpuResourceContext();
if (aicpu_resource_context == nullptr) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
std::string("new AicpuResourceContext failed."));
AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("new AicpuResourceContext failed."));
return GRAPH_FAILED;
}
aicpu_resource_context->shape_and_range_.push_back(feed_shape_and_range);
if (context->SetResourceContext(marks[0], aicpu_resource_context) != GRAPH_SUCCESS) {
delete aicpu_resource_context;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
std::string("set resource context failed."));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("set resource context failed."));
return GRAPH_FAILED;
}
shape_changed = true;
} else {
auto &shape_and_range = aicpu_resource_context->shape_and_range_;
if (shape_and_range.empty()) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
std::string("get resource context shape and ranges failed."));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("get resource context shape and ranges failed."));
return GRAPH_FAILED;
}
MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed,
op_name.GetString());
MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed, op);
}
if (shape_changed) {
if (context->AddChangedResourceKey(marks[0]) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
std::string("add change resource key failed."));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("add change resource key failed."));
return GRAPH_FAILED;
}
}
@ -1133,23 +1152,19 @@ graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_r
}
graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context) {
AscendString op_name;
op.GetName(op_name);
std::vector<AscendString> marks;
infer_context->GetMarks(marks);
if (!marks.empty()) {
OP_LOGI(TbeGetName(op).c_str(), "Get marks[0] = %s", marks[0].GetString());
OP_LOGI(op, "Get marks[0] = %s", marks[0].GetString());
if (infer_context->RegisterReliedOnResourceKey(marks[0]) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
std::string("register relied on resource key failed."));
AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("register relied on resource key failed."));
return GRAPH_FAILED;
}
auto aicpu_resource_context = reinterpret_cast<AicpuResourceContext *>(infer_context->GetResourceContext(marks[0]));
if (aicpu_resource_context != nullptr) {
auto &shape_and_range = aicpu_resource_context->shape_and_range_;
if (shape_and_range.empty()) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
std::string("get resource context shape and ranges failed."));
AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("get resource context shape and ranges failed."));
return GRAPH_FAILED;
}
out.shape_ = shape_and_range[0].shape_;

View File

@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,12 +18,11 @@
* \file common_shape_fns.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#include <string>
#include <vector>
#include <utility>
#include "graph/tensor.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
@ -32,6 +31,7 @@
#include "error_code.h"
namespace ge {
struct ShapeAndRange {
Shape shape_;
std::vector<std::pair<int64_t, int64_t>> shape_range_;
@ -42,6 +42,25 @@ struct AicpuResourceContext : public ResourceContext {
std::vector<ShapeAndRange> shape_and_range_;
};
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
const ge::Operator &op);
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
@ -67,7 +86,7 @@ graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeS
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name);
graphStatus WithRankShape(GeShape &shape, int64_t rank, const ge::Operator &op);
/**
* Check whether Shape's rank is equal to rank
@ -76,7 +95,7 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name);
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name);
graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);
/**
* Check whether Shape's rank is equal to rank
@ -85,7 +104,7 @@ graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const c
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name);
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op);
/**
* Check whether Shape's rank is equal to rank
@ -94,7 +113,7 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &o
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const char *op_name);
graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op);
/**
* Check whether dim is equal to value
@ -103,7 +122,7 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out
* @param out Output dim
* @return status whether Dim is equal to value
*/
graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_name);
graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op);
/**
* Merge two shapes
@ -113,7 +132,7 @@ graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_n
* @param prefix_out prefix out shape val
* @return status whether this operation success
*/
graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const char *op_name);
graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op);
/**
* Merge two dims of Shape
@ -131,7 +150,7 @@ graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out);
* @param out merged shape val
* @return status whether this operation success
*/
graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_name);
graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op);
/**
* Merge two shapes
@ -140,7 +159,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
* @param out merged Geshape val
* @return status whether this operation success
*/
graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char *op_name);
graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const ge::Operator &op);
/**
* Merge two shapes
@ -171,7 +190,7 @@ void MergeRange(const std::vector<std::pair<int64_t, int64_t>> &shared_shape_ran
* @return status whether this operation success
*/
graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range,
ShapeAndRange &out, bool &shape_changed, const char *op_name);
ShapeAndRange &out, bool &shape_changed, const ge::Operator &op);
/**
* Replace one dim in a given shape
@ -181,7 +200,7 @@ graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, cons
* @param out new shape
* @return status whether this operation success
*/
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const char *op_name);
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op);
/**
* Replace one dim in a given shape
@ -191,7 +210,7 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
* @param out new shape
* @return status whether this operation success
*/
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const char *op_name);
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const ge::Operator &op);
/**
* Check if it satisfies 0 <= index < limit
@ -218,7 +237,7 @@ graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out);
* @param out Subtract dim val
* @return status whether this operation success
*/
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_name);
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op);
/**
* Get SubShape according to start end index and step size stride
@ -229,7 +248,7 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
* @param out sub shape output
* @return status whether this operation success
*/
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const char *op_name);
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op);
/**
* Get SubShape according to start end index and step size stride
@ -251,7 +270,8 @@ graphStatus SubShape(const GeShape &s, size_t start, size_t end, size_t stride,
* @param out sub shape output
* @return status whether this operation success
*/
graphStatus SubShape(const GeShape &s, int64_t start, int64_t end, int64_t stride, GeShape &out, const char *op_name);
graphStatus SubShape(const GeShape &s, int64_t start, int64_t end, int64_t stride, GeShape &out,
const ge::Operator &op);
/**
* Concatenate two shape
@ -294,17 +314,16 @@ graphStatus Vector(int64_t dim, Shape &out);
* @param out shape
* @return status whether this operation success
*/
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const char *op_name);
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op);
/**
* Make shape from shape tensor
* @param op Operator
* @param dst_name const string &
* @param out GeShape
* @param op_name const char *
* @return status whether this operation success
*/
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out, const char *op_name);
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out);
/**
* Make dim from scalar tensor
@ -312,7 +331,7 @@ graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeSha
* @param out shape
* @return status whether this operation success
*/
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char *op_name);
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op);
/**
* Check whether Shape's rank is at most rank
@ -321,7 +340,7 @@ graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char
* @param out output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name);
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);
/**
* Check whether Shape's rank is at most rank
@ -330,7 +349,7 @@ graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, c
* @param out output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name);
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op);
/**
* make a empty dim shape
@ -357,7 +376,7 @@ graphStatus UnchangedShape(Operator &op, const string input_name, const string o
* @return status whether this operation success
*/
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out,
const char *op_name);
const ge::Operator &op);
/**
* check shape fully defined or not
@ -410,7 +429,7 @@ bool ValueKnown(const Shape &shape, const size_t &dim_index);
* @return status whether this operation success
*/
graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape,
const char *op_name);
const ge::Operator &op);
/**
* Fill op_desc with input shape
@ -448,6 +467,7 @@ std::string DTypeStr(DataType dtype);
graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range);
graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context);
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_

View File

@ -19,8 +19,8 @@
* \brief
*/
#ifndef CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#define CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#ifndef CANN_CUSTOMIZE_CONTEXT_UTIL_H_
#define CANN_CUSTOMIZE_CONTEXT_UTIL_H_
#include "runtime/infer_shape_context.h"
#include "runtime/tiling_context.h"
@ -43,4 +43,4 @@ namespace ops {
return ret; \
}
} // namespace ops
#endif // CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#endif // CANN_CUSTOMIZE_CONTEXT_UTIL_H_

View File

@ -18,8 +18,8 @@
* \file error_code.h
* \brief
*/
#ifndef OPS_COMMON_INC_ERROR_CODE_H_
#define OPS_COMMON_INC_ERROR_CODE_H_
#ifndef CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_
#define CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_
namespace ge {
// error code for report purpose.
@ -59,4 +59,4 @@ enum ViewErrorCode {
};
} // namespace ge
#endif // OPS_COMMON_INC_ERROR_CODE_H_
#endif // CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_

View File

@ -18,13 +18,12 @@
* \file error_util.h
* \brief
*/
#ifndef OPS_COMMON_INC_ERROR_UTIL_H_
#define OPS_COMMON_INC_ERROR_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_
#include <sstream>
#include <string>
#include <vector>
#include <utility>
#include "common/util/error_manager/error_manager.h"
#include "error_code.h"
#include "graph/operator.h"
@ -227,4 +226,4 @@ void GeInfershapeErrReport(const std::string &op_name, const std::string &op_typ
void CommonRuntimeErrLog(const std::string &opname, const std::string &description);
} // namespace ge
#endif // OPS_COMMON_INC_ERROR_UTIL_H_
#endif // CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_

View File

@ -0,0 +1,225 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file images_ops_shape_fns.cpp
* \brief
*/
#include "image_ops_shape_fns.h"
#include "op_log.h"
#include "error_util.h"
#include "graph/utils/op_desc_utils.h"
namespace ge {
graphStatus ColorspaceShapeFn(Operator &op, const std::string &output_name) {
Shape shape;
graphStatus status = WithRankAtLeast(op.GetInputDesc(0), 1, shape, op);
if (status != GRAPH_SUCCESS) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input[images] must 1-D or higher rank.");
return GRAPH_PARAM_INVALID;
}
int64_t dim = op.GetInputDesc(0).GetShape().GetDims().back();
if (dim != 3) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input[images] last dimension must be size 3.");
return GRAPH_PARAM_INVALID;
}
TensorDesc desc = op.GetOutputDescByName(output_name.c_str());
desc.SetShape(Shape(shape));
return op.UpdateOutputDesc(output_name.c_str(), desc);
}
graphStatus ResizeShapeFn(Operator &op, const std::string &input_name, const std::string &size_input_name,
const std::string &output_name) {
if (op.GetInputDesc(0).GetShape().GetDims() == UNKNOWN_RANK) {
std::vector<int64_t> output_shape(4, UNKNOWN_DIM);
TensorDesc td = op.GetOutputDescByName(output_name.c_str());
td.SetShape(Shape(output_shape));
op.UpdateOutputDesc(output_name.c_str(), td);
return GRAPH_SUCCESS;
}
Shape shape;
graphStatus status = WithRank(op.GetInputDesc(0), 4, shape, op);
if (status != GRAPH_SUCCESS) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input[images] must 4-D.");
return GRAPH_PARAM_INVALID;
}
auto dims = op.GetInputDesc(0).GetShape().GetDims();
auto channel_dim = dims[3];
TensorDesc input_td = op.GetInputDesc(0);
if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
channel_dim = dims[1];
}
return SetOutputToSizedImage(op, dims[0], size_input_name, channel_dim, output_name);
}
graphStatus SetOutputToSizedImage(Operator &op, const int64_t batch_dim, const std::string &size_input_name,
const int64_t channel_dim, const std::string &output_name) {
Shape size_shape;
graphStatus status = WithRank(op.GetInputDescByName(size_input_name.c_str()), 1, size_shape, op);
if (status != GRAPH_SUCCESS) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input size must be 1-D.");
return GRAPH_PARAM_INVALID;
}
auto size_dims = op.GetInputDescByName(size_input_name.c_str()).GetShape().GetDims();
if (size_dims[0] != 2) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input size must be 1-D tensor of 2 elements.");
return GRAPH_PARAM_INVALID;
}
std::vector<std::string> input_infer_depends = {size_input_name};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
DataType data_type = DT_FLOAT;
// Update DataType when Attr "dtype" is set, used for op ResizeBicubic
if (op.GetAttr("dtype", data_type) == GRAPH_SUCCESS) {
if ((data_type != DT_FLOAT) && (data_type != DT_UINT8)) {
OP_LOGW(op_desc->GetName().c_str(), "Attr dtype should only be DT_FLOAT or DT_UNIT8");
} else {
OP_LOGI(op_desc->GetName().c_str(), "Update DataType from attr, which is %d", data_type);
}
}
Tensor size_tensor;
TensorDesc td = op.GetOutputDescByName(output_name.c_str());
status = op.GetInputConstData(size_input_name.c_str(), size_tensor);
if (status != GRAPH_SUCCESS) {
td.SetDataType(data_type);
std::vector<int64_t> out_shape;
TensorDesc input_td = op.GetInputDesc(0);
if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
out_shape.push_back(batch_dim);
out_shape.push_back(channel_dim);
out_shape.push_back(-1);
out_shape.push_back(-1);
} else if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NHWC) {
out_shape.push_back(batch_dim);
out_shape.push_back(-1);
out_shape.push_back(-1);
out_shape.push_back(channel_dim);
} else {
std::string error_msg = "Not supported this format";
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), error_msg);
}
td.SetShape(Shape(out_shape));
op.UpdateOutputDesc(output_name.c_str(), td);
return GRAPH_SUCCESS;
}
const int32_t *size_data = reinterpret_cast<const int32_t *>(size_tensor.GetData());
int64_t size_width = static_cast<int64_t>(size_data[1]);
int64_t size_height = static_cast<int64_t>(size_data[0]);
std::vector<int64_t> output_shape;
TensorDesc input_td = op.GetInputDesc(0);
if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
output_shape.push_back(batch_dim);
output_shape.push_back(channel_dim);
output_shape.push_back(size_height);
output_shape.push_back(size_width);
} else if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NHWC) {
output_shape.push_back(batch_dim);
output_shape.push_back(size_height);
output_shape.push_back(size_width);
output_shape.push_back(channel_dim);
} else {
OP_LOGE(TbeGetName(op).c_str(), "Not supported this format");
}
td.SetShape(Shape(output_shape));
return op.UpdateOutputDesc(output_name.c_str(), td);
}
graphStatus EncodeImageShapeFn(Operator &op) {
Shape unused_shape;
if (WithRank(op.GetInputDesc(0), 3, unused_shape, op) != GRAPH_SUCCESS) {
AscendString op_name;
op.GetName(op_name);
OP_LOGE(op_name.GetString(), "input rank must be 3 .");
return GRAPH_FAILED;
}
Shape output_shape;
(void)Scalar(output_shape);
TensorDesc output_tensor = op.GetOutputDescByName("contents");
output_tensor.SetDataType(DT_STRING);
output_tensor.SetShape(output_shape);
return op.UpdateOutputDesc("contents", output_tensor);
}
graphStatus DecodeImageShapeFn(Operator &op) {
int channels;
if (op.GetAttr("channels", channels) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Get attr[channels] failed"));
return GRAPH_FAILED;
}
if (channels != 0 && channels != 1 && channels != 3 && channels != 4) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("attr[Channels] must be 0,1,3,or 4"));
return GRAPH_FAILED;
}
DataType dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Get attr[dtype] failed"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims;
if (channels == 0) {
dims = {ge::UNKNOWN_DIM, ge::UNKNOWN_DIM, ge::UNKNOWN_DIM};
} else {
dims = {ge::UNKNOWN_DIM, ge::UNKNOWN_DIM, channels};
}
Shape output_shape(dims);
TensorDesc output_tensor = op.GetOutputDesc(0);
output_tensor.SetDataType(dtype);
output_tensor.SetShape(output_shape);
if (op.UpdateOutputDesc("image", output_tensor) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Update OutputDesc[image] failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
bool DimsAllEqualOrUnknown(std::initializer_list<int64_t> &&inputs, int64_t unknown_dim_val) {
auto it = inputs.begin();
for (; it != inputs.end() && (*it == unknown_dim_val); ++it) {
}
if (it == inputs.end()) {
return true;
}
for (auto default_dim_val = *(it++); it != inputs.end(); ++it) {
if (*it != default_dim_val && *it != unknown_dim_val) {
return false;
}
}
return true;
}
} // namespace ge

View File

@ -0,0 +1,84 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file images_ops_shape_fns.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_
#include <string>
#include "common_shape_fns.h"
namespace ge {
/**
* ColorspaceShapeFn, infereshape function of colorspace op
* @param op, Operators that need to reason about shape
* @param output_name, the name of output
* @return status whether infer shape success
*/
graphStatus ColorspaceShapeFn(Operator &op, const std::string &output_name);
/**
* ResizeShapeFn, infereshape function of image resize op
* @param op, Operators that need to reason about shape
* @param input_name, the name of input
* @param size_input_name, the name of size input name
* @param output_name, the name of output
* @return status whether infer shape success
*/
graphStatus ResizeShapeFn(Operator &op, const std::string &input_name, const std::string &size_input_name,
const std::string &output_name);
/**
* SetOutputToSizedImage, set output shape of size image op
* @param op, Operators that need to set output shape
* @param batch_dim, the dim of batch
* @param size_input_name, the name of size input
* @param channel_dim, the dim of channel
* @param output_name, the name of output
* @return status whether set output shape success
*/
graphStatus SetOutputToSizedImage(Operator &op, const int64_t batch_dim, const std::string &size_input_name,
const int64_t channel_dim, const std::string &output_name);
/**
* EncodeImageShapeFn, infereshape function of EncodeImage op
* @param op, Operators that need to reason about shape
* @return status whether infer shape success
*/
graphStatus EncodeImageShapeFn(Operator &op);
/**
* DecodeImageShapeFn, infereshape function of DecodeImage op
* @param op, Operators that need to reason about shape
* @return status whether infer shape success
*/
graphStatus DecodeImageShapeFn(Operator &op);
/**
* EncodeImageShapeFn, infereshape function of EncodeImage op
* @param inputs, the list of impu dims
* @param unknown_dim_val, the definithion of UNKNOWN_DIM
* @return status whether infer shape success
*/
bool DimsAllEqualOrUnknown(std::initializer_list<int64_t> &&inputs, int64_t unknown_dim_val = UNKNOWN_DIM);
} // namespace ge
#endif // CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_

View File

@ -0,0 +1,162 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file linalg_ops_shape_fns.cpp
* \brief
*/
#include "linalg_ops_shape_fns.h"
#include "op_log.h"
#include "common_shape_fns.h"
namespace ge {
constexpr int64_t kRnak = 2;
constexpr int64_t kEnd = -2;
graphStatus MakeBatchSquareMatrix(const TensorDesc &tensor, Shape &out, const ge::Operator &op) {
Shape s;
if (WithRankAtLeast(tensor, kRnak, s, op) == GRAPH_FAILED) {
OP_LOGE(op, "input tensor's rank at least 2.");
return GRAPH_FAILED;
}
size_t existing = s.GetDimNum();
int64_t dim1 = s.GetDim(existing - 2);
int64_t dim2 = s.GetDim(existing - 1);
int64_t out_dim = 0;
if (Merge(dim1, dim2, out_dim) == GRAPH_FAILED) {
OP_LOGE(op, "Merge two dimension failed.");
return GRAPH_FAILED;
}
Shape batch_shape;
if (SubShape(s, 0, kEnd, 1, batch_shape, op) == GRAPH_FAILED) {
OP_LOGE(op, "Get SubShape batch_shape Failed.");
return GRAPH_FAILED;
}
if (Concatenate(batch_shape, Shape({out_dim, out_dim}), out) == GRAPH_FAILED) {
OP_LOGE(op, "Concatenate batch_shape and out_dim Failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
graphStatus MakeBatchSquareMatrix(const GeTensorDescPtr &tensor_desc, GeShape &out, const ge::Operator &op) {
GeShape ge_shape;
if (WithRankAtLeast(tensor_desc, kRnak, ge_shape, op) == GRAPH_FAILED) {
OP_LOGE(op, "Input tensor's rank at least 2.");
return GRAPH_FAILED;
}
Shape s(ge_shape.GetDims());
size_t existing = s.GetDimNum();
int64_t dim1 = s.GetDim(existing - 2);
int64_t dim2 = s.GetDim(existing - 1);
int64_t out_dim = 0;
if (Merge(dim1, dim2, out_dim) == GRAPH_FAILED) {
OP_LOGE(op, "Merge two dimension failed.");
return GRAPH_FAILED;
}
if (RankKnown(ge_shape)) {
GeShape batch_shape;
if (SubShape(ge_shape, 0, kEnd, 1, batch_shape, op) == GRAPH_FAILED) {
OP_LOGE(op, "Get subShape batch_shape failed.");
return GRAPH_FAILED;
}
if (Concatenate(batch_shape, GeShape({out_dim, out_dim}), out) == GRAPH_FAILED) {
OP_LOGE(op, "Concatenate batch_shape and out_dim failed.");
return GRAPH_FAILED;
}
} else {
GeShape unknown_shape(ge::UNKNOWN_SHAPE);
out = unknown_shape;
}
return GRAPH_SUCCESS;
}
graphStatus MatrixSolve(const TensorDesc &tensor1, const TensorDesc &tensor2, bool square, Shape &out,
const ge::Operator &op) {
Shape lhs;
Shape rhs;
if (square) {
if (MakeBatchSquareMatrix(tensor1, lhs, op) == GRAPH_FAILED) {
OP_LOGE(op, "MatrixSolve first input tensor Make Batch Square Matrix failed.");
return GRAPH_FAILED;
}
} else {
if (WithRankAtLeast(tensor1, kRnak, lhs, op) == GRAPH_FAILED) {
OP_LOGE(op, "MatrixSolve func first input tensor must be at least 2.");
return GRAPH_FAILED;
}
}
if (WithRankAtLeast(tensor2, kRnak, rhs, op) == GRAPH_FAILED) {
OP_LOGE(op, "MatrixSolve func second input tensor must be at least 2.");
return GRAPH_FAILED;
}
Shape lhs_batch;
Shape rhs_batch;
// Make the common batch subshape.
if (SubShape(lhs, 0, kEnd, 1, lhs_batch, op) == GRAPH_FAILED) {
OP_LOGE(op, "SubShape lhs_batch in MatrixSolve func failed.");
return GRAPH_FAILED;
}
if (SubShape(rhs, 0, kEnd, 1, rhs_batch, op) == GRAPH_FAILED) {
OP_LOGE(op, "SubShape rhs_batch in MatrixSolve func failed.");
return GRAPH_FAILED;
}
int64_t lhs_batch_dim;
// Make sure the batch dimensions match between lhs and rhs.
if (Merge(lhs_batch.GetDimNum(), rhs_batch.GetDimNum(), lhs_batch_dim) == GRAPH_FAILED) {
OP_LOGE(op, "Merge dimension lhs_batch and rhs_batch failed.");
return GRAPH_FAILED;
}
int64_t dim_val = 0;
int64_t lhs_rank = lhs.GetDimNum();
int64_t rhs_rank = rhs.GetDimNum();
int64_t dim_lhs = lhs.GetDim(lhs_rank - 2);
int64_t dim_rhs = rhs.GetDim(rhs_rank - 2);
// lhs and rhs have the same value for m to be compatible.
if (Merge(dim_lhs, dim_rhs, dim_val) == GRAPH_FAILED) {
OP_LOGE(op, "Merge dimension dim_lhs and dim_rhs failed.");
return GRAPH_FAILED;
}
int64_t dim_ret = lhs.GetDim(lhs_rank - 1);
if (square) {
if (Merge(dim_val, dim_ret, dim_ret) == GRAPH_FAILED) {
OP_LOGE(op, "Merge dimension dim_val and dim_ret failed.");
return GRAPH_FAILED;
}
}
Shape s;
// Build final shape (batch_shape + n + k) in <out>.
if (Concatenate(lhs_batch, Shape({dim_ret}), s) == GRAPH_FAILED) {
OP_LOGE(op, "Concatenate Two Shape failed.");
return GRAPH_FAILED;
}
int64_t dims = rhs.GetDim(rhs_rank - 1);
if (Concatenate(s, Shape({dims}), s) == GRAPH_FAILED) {
OP_LOGE(op, "Concatenate Shape s and dims failed.");
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}
} // namespace ge

View File

@ -0,0 +1,57 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file linalg_ops_shape_fns.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_
#include "graph/tensor.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
namespace ge {
/**
* Generate a square matrix's Shape
* @param tensor Input tensor
* @param out Output Shape
* @return status whether this operation success
*/
graphStatus MakeBatchSquareMatrix(const TensorDesc &tensor, Shape &out, const ge::Operator &op);
/**
* Generate a square matrix's Shape
* @param tensor Input ge tensor desc ptr
* @param out Output GeShape
* @return status whether this operation success
*/
graphStatus MakeBatchSquareMatrix(const GeTensorDescPtr &tensor_desc, GeShape &out, const ge::Operator &op);
/**
* Solving linear equations from matrices common shape func
* @param tensor1 first input tensor
* @param tensor2 second input tensor
* @param square whether matrix is square
* @param out Output Shape
* @return status whether this operation success
*/
graphStatus MatrixSolve(const TensorDesc &tensor1, const TensorDesc &tensor2, bool square, Shape &out,
const ge::Operator &op);
} // namespace ge
#endif // CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_

View File

@ -0,0 +1,176 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file lookup_ops_shape_fns.cpp
* \brief
*/
#include "lookup_ops_shape_fns.h"
#include "common_shape_fns.h"
#include "graph/utils/op_desc_utils.h"
#include "error_util.h"
#include <vector>
#include <limits>
#include "op_log.h"
namespace ge {
graphStatus ValidateTableResourceHandle(Shape keys, std::vector<ShapeAndType> handleData,
ShapeAndType &output_shape_and_type, bool is_lookup, const ge::Operator &op) {
Shape unknown_shape(ge::UNKNOWN_SHAPE);
if (handleData.size() != 2) {
output_shape_and_type.SetShape(unknown_shape);
output_shape_and_type.SetType(DT_UNDEFINED);
} else {
const ShapeAndType &key_shape_and_type = handleData[0];
const ShapeAndType &value_shape_and_type = handleData[1];
// here need to check key_dtype and value_dtype
// but can not get the attr type for key and value
output_shape_and_type.SetType(value_shape_and_type.GetDataType());
if (is_lookup) {
if ((RankKnown(key_shape_and_type.GetShape()) == GRAPH_SUCCESS) && (RankKnown(keys) == GRAPH_SUCCESS)) {
int keys_rank = keys.GetDims().size();
int keys_suffix_rank = key_shape_and_type.GetShape().GetDims().size();
if (keys_rank < keys_suffix_rank) {
std::string err_msg = OtherErrMsg("Expected keys to have suffix");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
for (int d = 0; d < keys_suffix_rank; ++d) {
int new_dim = key_shape_and_type.GetShape().GetDim(d);
if (ReplaceDim(keys, keys_rank - keys_suffix_rank + d, new_dim, keys, op) == GRAPH_FAILED) {
return GRAPH_FAILED;
}
}
std::vector<int64_t> keys_prefix_vec;
keys_prefix_vec.reserve(keys_rank - keys_suffix_rank);
for (int d = 0; d < keys_rank - keys_suffix_rank; ++d) {
keys_prefix_vec.push_back(keys.GetDim(d));
}
Shape keys_prefix = Shape(keys_prefix_vec);
Shape temp_shape = output_shape_and_type.GetShape();
if (Concatenate(keys_prefix, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
return GRAPH_FAILED;
}
output_shape_and_type.SetShape(temp_shape);
} else {
output_shape_and_type.SetShape(unknown_shape);
}
} else {
Shape temp_shape = output_shape_and_type.GetShape();
if (Concatenate(keys, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
return GRAPH_FAILED;
}
output_shape_and_type.SetShape(temp_shape);
}
}
return GRAPH_SUCCESS;
}
graphStatus ValidateTableResourceHandle(const Operator &op, Shape &keys, const DataType &key_dtype,
const DataType &value_dtype, const bool &is_lookup,
ShapeAndType &output_shape_and_type) {
if (op.GetInferenceContext() == nullptr) {
OP_LOGI(op, "Op inference context is null, return unknown shape");
output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
output_shape_and_type.SetType(DT_UNDEFINED);
return GRAPH_SUCCESS;
}
const auto &shapes_and_types = op.GetInferenceContext()->GetInputHandleShapesAndTypes();
if (shapes_and_types.empty()) {
OP_LOGI(op, "Context GetInputHandleShapesAndTypes result is empty, return unknown shape");
output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
output_shape_and_type.SetType(DT_UNDEFINED);
return GRAPH_SUCCESS;
}
auto handle_data = shapes_and_types[0];
if (handle_data.size() != 2) {
OP_LOGI(op, "handle data(shapes_and_types[0]) size is not 2, return unknown shape");
output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
output_shape_and_type.SetType(DT_UNDEFINED);
return GRAPH_SUCCESS;
}
const ShapeAndType &key_shape_and_type = handle_data[0];
const ShapeAndType &value_shape_and_type = handle_data[1];
if (key_shape_and_type.GetDataType() != key_dtype) {
std::string err_msg =
GetInputDTypeErrMsg("key_dtype", ConcatString(key_shape_and_type.GetDataType()), ConcatString(key_dtype));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
if (value_shape_and_type.GetDataType() != value_dtype) {
OP_LOGW(op, "trying to read value with wrong dtype, expected %d, got %d", value_shape_and_type.GetDataType(),
value_dtype);
return GRAPH_FAILED;
}
output_shape_and_type.SetType(value_shape_and_type.GetDataType());
if (is_lookup) {
if (RankKnown(key_shape_and_type.GetShape()) && RankKnown(keys)) {
int64_t keys_rank = keys.GetDimNum();
int64_t key_suffix_rank = key_shape_and_type.GetShape().GetDimNum();
if (keys_rank < key_suffix_rank) {
std::string err_msg =
OtherErrMsg(ConcatString("Expected keys to have suffix ", key_suffix_rank, ", but saw shape ", keys_rank));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
for (int64_t d = 0; d < key_suffix_rank; ++d) {
// Ensure the suffix of keys match what's in the Table.
int64_t dim = key_shape_and_type.GetShape().GetDim(d);
if (ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, keys, op) == GRAPH_FAILED) {
std::string err_msg =
OtherErrMsg(ConcatString("replace dim ", keys_rank - key_suffix_rank + d, " in keys failed"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
std::vector<int64_t> keys_prefix_vec;
keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
keys_prefix_vec.push_back(keys.GetDim(d));
}
Shape keys_prefix(keys_prefix_vec);
auto temp_shape = output_shape_and_type.GetShape();
if (Concatenate(keys_prefix, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
std::string err_msg = OtherErrMsg("concatenate keys_prefix and value shape failed");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
output_shape_and_type.SetShape(temp_shape);
} else {
output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
}
} else {
auto temp_shape = output_shape_and_type.GetShape();
if (Concatenate(keys, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
std::string err_msg = OtherErrMsg("concatenate keys and value shape failed");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
output_shape_and_type.SetShape(temp_shape);
}
return GRAPH_SUCCESS;
}
} // namespace ge

View File

@ -0,0 +1,57 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file lookup_ops_shape_fns.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_
#include <vector>
#include "graph/tensor.h"
#include "graph/inference_context.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"
namespace ge {
/**
* Validate table resource handle
* @param keys keys of the shape
* @param handleData vector of handle data
* @param output_shape_and_type shape and type that created
* @param is_lookup if is lookup
* @return status whether this operation success
*/
graphStatus ValidateTableResourceHandle(Shape keys, std::vector<ShapeAndType> handleData,
ShapeAndType &output_shape_and_type, bool is_lookup, const ge::Operator &op);
/**
* Validate table resource handle
* @param op op context
* @param keys keys of the shape
* @param handleData vector of handle data
* @param output_shape_and_type shape and type that created
* @param is_lookup if is lookup
* @return status whether this operation success
*/
graphStatus ValidateTableResourceHandle(const Operator &op, Shape &keys, const DataType &key_dtype,
const DataType &value_dtype, const bool &is_lookup,
ShapeAndType &output_shape_and_type);
} // namespace ge
#endif // CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_

View File

@ -0,0 +1,317 @@
/**
* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file op_attr.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_
#define CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_
#include <vector>
#include "op_log.h"
#include "external/graph/operator.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/attr_utils.h"
namespace ops {
using namespace ge;
// attr base struct
struct AttrBase {
const int32_t attr_idx;
const std::string attr_name;
AttrBase(const int attr_idx, const std::string &attr_name) : attr_idx(attr_idx), attr_name(attr_name) {}
};
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, int32_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, int64_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @param [in] default_value: default_value
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, int64_t &value, const int64_t default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.attr_name.c_str());
value = default_value;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, uint64_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, int32_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @param [in] default_value: default_value
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, int32_t &value,
const int32_t default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
value = default_value;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, int64_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %lld", attr_info.second.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @param [in] default_value: default_value
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, int64_t &value,
const int64_t default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
value = default_value;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %lld", attr_info.second.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @param [in] default_value: default_value
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, uint32_t &value,
const uint32_t default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
value = default_value;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @param [in] default_value: default_value
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, bool &value,
const bool default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetBool(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
value = default_value;
}
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, bool &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetBool(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
return false;
}
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, vector<int64_t> &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetListInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
return false;
}
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, vector<int32_t> &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetListInt(op_desc, attr_info.second, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
return false;
}
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, int32_t &value, const int32_t default_value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Fail to get attr %s automatically. use default value", attr_info.attr_name.c_str());
value = default_value;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
return true;
}
/*
* @brief: read constvalue from paras store into values
* @param [in] paras: ge::Operator
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
* @param [out] value: store value.
* @return bool: flag of success or not
*/
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, float32_t &value) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
if (!AttrUtils::GetFloat(op_desc, attr_info.attr_name, value)) {
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
return false;
}
OP_LOGD("GetAttrValue", "Get the attr of %s is %f", attr_info.attr_name.c_str(), value);
return true;
}
} // namespace ops
#endif // CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_

View File

@ -19,8 +19,8 @@
* \brief common util for op, in this file only original type or class in C++ allowed
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#include <set>
#include <string>
@ -69,4 +69,4 @@ std::string to_string(const std::set<T> &items) {
}
} // namespace ops
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_

View File

@ -19,8 +19,8 @@
* \brief
*/
#ifndef CANN_OPS_BUILT_IN_OPS_CONST_H_
#define CANN_OPS_BUILT_IN_OPS_CONST_H_
#ifndef CANN_CUSTOMIZE_OPS_CONST_H_
#define CANN_CUSTOMIZE_OPS_CONST_H_
#include <vector>
#include "external/graph/operator.h"
@ -283,4 +283,4 @@ bool GetConstIntToShape(T *context, const int64_t const_idx, gert::Shape &const_
return true;
}
} // namespace ops
#endif // CANN_OPS_BUILT_IN_OPS_CONST_H_
#endif // CANN_CUSTOMIZE_OPS_CONST_H_

View File

@ -18,8 +18,8 @@
* \file op_log.h
* \brief
*/
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#ifndef CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H
#define CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H
#include <string>
#include <type_traits>
@ -71,6 +71,12 @@ typename std::enable_if<std::is_same<ge::OpDescPtr, typename std::decay<T>::type
return op_desc != nullptr ? op_desc->GetName().c_str() : "nil";
}
template <class T>
typename std::enable_if<std::is_same<ge::Operator, typename std::decay<T>::type>::value, const char *>::type get_cstr(
const T &op) {
return op.GetName().c_str();
}
template <typename T>
std::string TbeGetName(const T &op) {
ge::AscendString op_ascend_name;
@ -222,30 +228,35 @@ std::string TbeGetOpType(const T &op) {
#define D_FUSION_PASS_LOGD(fmt, ...)
#endif
#define OP_LOGE_IF(condition, return_value, op_name, fmt, ...) \
do { \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
if (condition) { \
OP_LOGE(get_cstr(op_name), fmt, ##__VA_ARGS__); \
return return_value; \
} \
#define OP_LOGE_IF(condition, return_value, op_name, fmt, ...) \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
do { \
if (condition) { \
OP_LOGE(get_cstr(op_name), fmt, ##__VA_ARGS__); \
return return_value; \
} \
} while (0)
#define OP_LOGW_IF(condition, op_name, fmt, ...) \
do { \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
if (condition) { \
OP_LOGW(get_cstr(op_name), fmt, ##__VA_ARGS__); \
} \
#define OP_LOGW_IF(condition, op_name, fmt, ...) \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
do { \
if (condition) { \
OP_LOGW(get_cstr(op_name), fmt, ##__VA_ARGS__); \
} \
} while (0)
#define OP_LOGI_IF_RETURN(condition, return_value, op_name, fmt, ...) \
do { \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
if (condition) { \
OP_LOGI(get_cstr(op_name), fmt, ##__VA_ARGS__); \
return return_value; \
} \
#define OP_LOGI_IF_RETURN(condition, return_value, op_name, fmt, ...) \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
do { \
if (condition) { \
OP_LOGI(get_cstr(op_name), fmt, ##__VA_ARGS__); \
return return_value; \
} \
} while (0)
#endif // GE_OP_LOG_H
inline std::ostream &operator<<(std::ostream &os, const ge::Operator &op) {
os << op.GetName();
return os;
}
#endif // CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H

View File

@ -19,8 +19,8 @@
* \brief
*/
#ifndef CANN_OPS_BUILT_IN_OP_UTIL_H_
#define CANN_OPS_BUILT_IN_OP_UTIL_H_
#ifndef CANN_CUSTOMIZE_OP_UTIL_H_
#define CANN_CUSTOMIZE_OP_UTIL_H_
#include <memory>
#include <utility>
@ -218,4 +218,4 @@ inline bool IsConstTensor(const gert::Tensor *input_tensor) {
return (input_tensor != nullptr) && (input_tensor->GetAddr() != nullptr);
}
} // namespace ops
#endif // CANN_OPS_BUILT_IN_OP_UTIL_H_
#endif // CANN_CUSTOMIZE_OP_UTIL_H_

View File

@ -0,0 +1,370 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file reduce_infer_util.cc
* \brief
*/
#include <string>
#include <vector>
#include "util.h"
#include "reduce_infer_util.h"
#include "vector_proto_profiling.h"
namespace reduce_ops {
using namespace std;
using namespace ge;
constexpr int64_t UNKNOWN_DIM_VALUE = -2;
static bool ConvertAxis(std::vector<int64_t> &axis, const int64_t input_len) {
const int64_t input_length = input_len == 0 ? 1 : input_len;
// Convert reduce axis
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < -input_length || axis[i] > (input_length - 1)) {
OP_LOGE("ReduceOps", "reduce verify failed, axis: %ld, input_length: %ld", axis[i], input_length);
return false;
}
if (axis[i] < 0) {
axis[i] = input_length + axis[i];
}
}
return true;
}
bool DoReduceInfershapeWithAxesKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
GeShape &output_shape) {
// case0: input is {-2} set the output {-2}
if (input_shape.IsUnknownDimNum()) {
OP_LOGD("ReduceOps", "do unknownrank infershape for Reduce, output is {-2}");
output_shape.SetIsUnknownDimNum();
return true;
}
auto input_shape_len = input_shape.GetDimNum();
OP_LOGD("ReduceOps", "input shape = %s, axes = %s", to_string(input_shape).c_str(), to_string(reduce_axes).c_str());
if (!ConvertAxis(reduce_axes, static_cast<int64_t>(input_shape_len))) {
OP_LOGE("ReduceOps", "do ConvertAxis failed, will return false");
return false;
}
// case1: will reduce all shape, if reduce_axes is empty
if (reduce_axes.empty()) {
// return the shape(all 1) when reduce_axes is empty and keep_dims = true
OP_LOGD("ReduceOps", "do all reduce infershape for Reduce, output is {}");
output_shape.SetDimNum(input_shape_len);
for (size_t i = 0; i < input_shape_len; ++i) {
output_shape.SetDim(i, 1);
}
return true;
}
// case2: shape is [x, y, z] axis is [0] --> [1, y, z] when keep_dims is true
output_shape.SetDimNum(input_shape_len);
OP_LOGD("ReduceOps", "do norm infershape for Reduce");
output_shape = input_shape;
for (size_t i = 0; i < reduce_axes.size(); ++i) {
int64_t axis = reduce_axes[i];
output_shape.SetDim(axis, 1);
}
OP_LOGD("ReduceOps", "after reduce output shape = %s", to_string(output_shape).c_str());
return true;
}
bool DoReduceInfershapeWithAxesNoKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
GeShape &output_shape) {
// case0: will reduce all shape, if reduce_axes is empty
if (reduce_axes.empty()) {
// return a scalar shape when reduce_axes is empty and keep_dims = false
OP_LOGD("ReduceOps", "reduce_axes is empty, output a scalar");
output_shape.SetDimNum(0);
return true;
}
// case1: input is {-2} set the output {-2}
if (input_shape.IsUnknownDimNum()) {
OP_LOGD("ReduceOps", "input is {-2}, set the output is {-2}");
output_shape.SetIsUnknownDimNum();
return true;
}
auto input_shape_len = input_shape.GetDimNum();
if (!ConvertAxis(reduce_axes, static_cast<int64_t>(input_shape_len))) {
OP_LOGE("ReduceOps", "do ConvertAxis failed, will return false");
return false;
}
// case2: shape is [x, y, z] axis is [0] --> [y, z] when keep_dims is false
output_shape.SetDimNum(input_shape_len);
int64_t output_dim = 0;
for (int64_t i = 0; i < static_cast<int64_t>(input_shape_len); ++i) {
if (std::find(reduce_axes.begin(), reduce_axes.end(), i) == reduce_axes.end()) {
auto input_dim = input_shape.GetDim(i);
output_shape.SetDim(output_dim, input_dim);
output_dim++;
}
}
output_shape.SetDimNum(output_dim);
OP_LOGD("ReduceOps", "after reduce output shape = %s", to_string(output_shape).c_str());
return true;
}
bool DoReduceInfershapeWithAxes(const GeShape &input_shape, const bool keep_dims, std::vector<int64_t> &reduce_axes,
GeShape &output_shape) {
if (keep_dims) {
return DoReduceInfershapeWithAxesKeepdims(input_shape, reduce_axes, output_shape);
}
return DoReduceInfershapeWithAxesNoKeepdims(input_shape, reduce_axes, output_shape);
}
bool DoReduceInferRangeWithAxes(GeTensorDescPtr &tensordesc_input_x, GeTensorDescPtr &tensordesc_output,
std::vector<int64_t> &reduce_axes, bool keep_dims) {
std::vector<std::pair<int64_t, int64_t>> output_shape_range;
std::vector<std::pair<int64_t, int64_t>> input_shape_range;
tensordesc_input_x->GetShapeRange(input_shape_range);
std::vector<int64_t> input_shape_vec = tensordesc_input_x->MutableShape().GetDims();
// If InputShapeRange is None, MakeUpShapeRange will set range.
MakeUpShapeRange(input_shape_vec, input_shape_range);
if (keep_dims) {
output_shape_range = input_shape_range;
for (auto item : reduce_axes) {
output_shape_range[item] = std::make_pair<int64_t, int64_t>(1, 1);
}
} else {
for (int64_t i = 0; i < static_cast<int64_t>(input_shape_range.size()); ++i) {
if (std::find(reduce_axes.begin(), reduce_axes.end(), i) == reduce_axes.end()) {
output_shape_range.push_back(input_shape_range[i]);
}
}
}
tensordesc_output->SetShapeRange(output_shape_range);
return true;
}
bool GetConstData(const Operator &op, const int64_t const_input_idx, std::vector<int64_t> &const_values) {
ge::Tensor const_tensor;
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto input_name = op_desc->GetInputNameByIndex(const_input_idx);
if (op.GetInputConstData(input_name.c_str(), const_tensor) != ge::GRAPH_SUCCESS) {
OP_LOGW(TbeGetName(op).c_str(), "constvalue [%s] not exists.", input_name.c_str());
return false;
}
DataType const_dtype = op_desc->MutableInputDesc(const_input_idx)->GetDataType();
auto size = const_tensor.GetSize();
auto data = const_tensor.GetData();
const_values.reserve(size);
switch (const_dtype) {
case ge::DT_INT64: {
size_t count = size / sizeof(int64_t);
const int64_t *data_addr = (const int64_t *)(data);
for (size_t i = 0; i < count; i++) {
const_values.push_back(*(data_addr + i));
}
} break;
case ge::DT_INT32: {
size_t count = size / sizeof(int32_t);
const int32_t *data_addr = (const int32_t *)(data);
for (size_t i = 0; i < count; i++) {
const_values.push_back(*(data_addr + i));
}
} break;
default: {
OP_LOGW(TbeGetName(op).c_str(), "GetConstData of dtype[%s] has not implement.", to_string(const_dtype).c_str());
return false;
} break;
}
OP_LOGD(TbeGetName(op).c_str(), "get const value = %s", to_string(const_values).c_str());
return true;
}
bool DoReduceInferShapeWithoutAxes(const Operator &op, GeTensorDescPtr &tensordesc_input_x,
GeTensorDescPtr &tensordesc_output, const GeShape &axes_shape, bool keep_dims) {
OP_LOGD(TbeGetName(op).c_str(), "the axes is not const, the output will be dynamic shape");
const GeShape &input_shape = tensordesc_input_x->MutableShape();
// case0: input is {} set the output {}
if (input_shape.IsScalar()) {
OP_LOGD(TbeGetName(op).c_str(), "input is scalar, so output is scalar");
std::vector<int64_t> output_shape;
tensordesc_output->SetShape(GeShape(output_shape));
return true;
}
// case1: input is {-2} set the output {-2}
if (input_shape.IsUnknownDimNum()) {
OP_LOGD(TbeGetName(op).c_str(), "input is {-2}, so output {-2}");
std::vector<int64_t> output_shape(1, -2);
tensordesc_output->SetShape(GeShape(output_shape));
return true;
}
std::vector<int64_t> input_shape_vec = input_shape.GetDims();
std::vector<int64_t> output_shape;
std::vector<std::pair<int64_t, int64_t>> output_shape_range;
std::vector<std::pair<int64_t, int64_t>> input_shape_range;
tensordesc_input_x->GetShapeRange(input_shape_range);
// If InputShapeRange is None, MakeUpShapeRange will set range.
MakeUpShapeRange(input_shape_vec, input_shape_range);
size_t input_length = input_shape_vec.size();
if (keep_dims) {
// case2: all output shape dim is -1, range [1, xxx] when keep_dims is true
for (size_t item = 0; item < input_length; ++item) {
int64_t range_min_value = 1;
int64_t range_max_value = input_shape_range[item].second;
output_shape_range.push_back(std::make_pair(range_min_value, range_max_value));
if (range_max_value == 1) {
output_shape.push_back(1);
} else {
output_shape.push_back(-1);
}
}
} else {
// keep_dims is false
// case3: all output shape dim is -1, range is (min_range, max_range)
// output dim num = (input dim num - 1), when axes_shape = {} or {1}
// case4: all output shape dim is -2
int64_t output_dim_num = UNKNOWN_DIM_VALUE;
if (!axes_shape.IsUnknownDimNum() && axes_shape.GetDimNum() == 0) {
OP_LOGD(TbeGetName(op).c_str(), "the axes is scalar, will reduce one dim for input shape");
output_dim_num = input_length - 1;
}
if (axes_shape.GetDimNum() == 1 && axes_shape.GetDim(0) == 1) {
output_dim_num = input_length - 1;
OP_LOGD(TbeGetName(op).c_str(), "the shape of axes is [1], will reduce one dim for input shape");
}
int64_t range_min_value = input_shape_range[0].first;
int64_t range_max_value = input_shape_range[0].second;
for (size_t item = 0; item < input_length; ++item) {
if (input_shape_range[item].first < range_min_value) {
range_min_value = input_shape_range[item].first;
}
if (input_shape_range[item].second == -1) {
range_max_value = -1;
}
if (range_max_value != -1 && input_shape_range[item].second > range_max_value) {
range_max_value = input_shape_range[item].second;
}
}
if (output_dim_num == UNKNOWN_DIM_VALUE) {
output_shape.push_back(UNKNOWN_DIM_VALUE);
} else {
for (int64_t item = 0; item < output_dim_num; ++item) {
output_shape.push_back(-1);
output_shape_range.push_back(std::make_pair(range_min_value, range_max_value));
}
}
}
tensordesc_output->SetShape(GeShape(output_shape));
tensordesc_output->SetShapeRange(output_shape_range);
return true;
}
bool CommonReduceInferWithInputAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
const int64_t input_axes_idx, bool keep_dims) {
PROFILING_PROTO_INIT(TbeGetName(op).c_str());
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto axes_name = op_desc->GetInputNameByIndex(input_axes_idx);
op_desc->SetOpInferDepends({axes_name});
CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("invalid OpDesc.")),
return false);
auto tensordesc_input_x = op_desc->MutableInputDesc(input_x_idx);
auto tensordesc_input_axes = op_desc->MutableInputDesc(input_axes_idx);
auto tensordesc_output = op_desc->MutableOutputDesc(output_idx);
CHECK(tensordesc_input_x == nullptr || tensordesc_output == nullptr || tensordesc_input_axes == nullptr,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("get mutable desc failed.")), return false);
auto input_type = tensordesc_input_x->GetDataType();
const GeShape &input_shape = tensordesc_input_x->MutableShape();
const GeShape &axes_shape = tensordesc_input_axes->MutableShape();
tensordesc_output->SetDataType(input_type);
if (axes_shape.GetDimNum() == 1 && axes_shape.GetDim(0) == 0) {
OP_LOGD(TbeGetName(op).c_str(), "axes_shape is [0], set output shape = input shape");
tensordesc_output->SetShape(input_shape);
std::vector<std::pair<int64_t, int64_t>> input_shape_range;
tensordesc_input_x->GetShapeRange(input_shape_range);
tensordesc_output->SetShapeRange(input_shape_range);
return true;
}
// get const value from input_axes_idx
std::vector<int64_t> reduce_axes;
if (GetConstData(op, input_axes_idx, reduce_axes)) {
PROFILING_PROTO_AFTER_GET_SHAPE_REG();
// do infershape with const axes for static op
GeShape &output_shape = tensordesc_output->MutableShape();
CHECK(!DoReduceInfershapeWithAxes(input_shape, keep_dims, reduce_axes, output_shape),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infershape failed.")),
return false);
// when output is dynamic shape, will infer range
if (output_shape.IsUnknownShape()) {
if (!output_shape.IsUnknownDimNum()) {
CHECK(!DoReduceInferRangeWithAxes(tensordesc_input_x, tensordesc_output, reduce_axes, keep_dims),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infer range failed.")),
return false);
}
OP_LOGD(TbeGetName(op).c_str(), "infer output range end for dynamic output");
return true;
}
OP_LOGD(TbeGetName(op).c_str(), "the output is not dynamic");
PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
PROFILING_PROTO_END();
return true;
}
CHECK(!DoReduceInferShapeWithoutAxes(op, tensordesc_input_x, tensordesc_output, axes_shape, keep_dims),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("infer reduce range failed.")), return false);
return true;
}
bool CommonReduceInferWithAttrAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
vector<int64_t> attr_axes, bool keep_dims) {
PROFILING_PROTO_INIT(TbeGetName(op).c_str());
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("invalid OpDesc.")),
return false);
auto tensordesc_input_x = op_desc->MutableInputDesc(input_x_idx);
auto tensordesc_output = op_desc->MutableOutputDesc(output_idx);
CHECK(tensordesc_input_x == nullptr || tensordesc_output == nullptr,
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("get mutable desc failed.")), return false);
auto input_type = tensordesc_input_x->GetDataType();
const GeShape &input_shape = tensordesc_input_x->MutableShape();
tensordesc_output->SetDataType(input_type);
PROFILING_PROTO_AFTER_GET_SHAPE_REG();
// do infershape with const axes for static op
GeShape &output_shape = tensordesc_output->MutableShape();
CHECK(!DoReduceInfershapeWithAxes(input_shape, keep_dims, attr_axes, output_shape),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infershape failed.")), return false);
// when output is dynamic shape, will infer range
if (output_shape.IsUnknownShape()) {
if (!output_shape.IsUnknownDimNum()) {
CHECK(!DoReduceInferRangeWithAxes(tensordesc_input_x, tensordesc_output, attr_axes, keep_dims),
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infer range failed.")),
return false);
}
OP_LOGD(TbeGetName(op), "infer output range end for dynamic output");
return true;
}
OP_LOGD(TbeGetName(op), "the output is not dynamic");
PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
PROFILING_PROTO_END();
return true;
}
} // namespace reduce_ops

View File

@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file reduce_infer_util.h
* \brief
*/
#ifndef CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_
#include <memory.h>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
using namespace ge;
namespace reduce_ops {
/*
* only do infer shape for reduce with input_shape/axes, when keepdims = true
* param[in] input_shape: GeShape input shape
* param[in] reduce_axes: reduce axes list
* param[in] output_shape: GeShape output shape
* return bool:
* true:infer success false:infer failed
*/
bool DoReduceInfershapeWithAxesKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
GeShape &output_shape);
/*
* only do infer shape for reduce with input_shape/axes, when keepdims = false
* param[in] input_shape: GeShape input shape
* param[in] reduce_axes: reduce axes list
* param[in] output_shape: GeShape output shape
* return bool:
* true:infer success false:infer failed
*/
bool DoReduceInfershapeWithAxesNoKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
GeShape &output_shape);
/*
* only do infer shape for reduce with input_shape, axes and keepdims
* param[in] input_shape: GeShape input shape
* param[in] keep_dims: bool
* param[in] reduce_axes: reduce axes list
* param[in] output_shape: GeShape output shape
* return bool:
* true:infer success false:infer failed
*/
bool DoReduceInfershapeWithAxes(const GeShape &input_shape, const bool keep_dims, std::vector<int64_t> &reduce_axes,
GeShape &output_shape);
/*
* only do infer range for reduce
* param[in] tensordesc_input_x: GeTensorDescPtr of input tensor
* param[in] tensordesc_output: GeTensorDescPtr of output tensor
* param[in] reduce_axes: reduce axes list
* param[in] keep_dims: bool
* return bool:
* true:infer success false:infer failed
*/
bool DoReduceInferRangeWithAxes(GeTensorDescPtr &tensordesc_input_x, GeTensorDescPtr &tensordesc_output,
std::vector<int64_t> &reduce_axes, bool keep_dims);
/*
* get const value from const node to vector const_values
* param[in] op: op desc get from by ge
* param[in] const_input_idx: the input idx for const node
* param[in] const_values: the const value
* return bool:
* true:infer success false:infer failed
*/
bool GetConstData(const Operator &op, const int64_t const_input_idx, std::vector<int64_t> &const_values);
/*
* infer shape and range for reduce, when the axes is not const
* param[in] op: op desc get from by ge
* param[in] tensordesc_input_x: GeTensorDescPtr of input tensor
* param[in] tensordesc_output: GeTensorDescPtr of output tensor
* param[in] axes_shape: the axes shape
* param[in] keep_dims: bool
* return bool:
* true:get value success false:no not get the const value
*/
bool DoReduceInferShapeWithoutAxes(const Operator &op, GeTensorDescPtr &tensordesc_input_x,
GeTensorDescPtr &tensordesc_output, const GeShape &axes_shape, bool keep_dims);
/*
* reduce infershape function, when axes is input
* param[in] op: op desc get from by ge
* param[in] input_x_idx: the input tensor idx int64
* param[in] output_idx: the output tensor idx int64
* param[in] input_axes_idx: the input const idx
* param[in] keep_dims: bool
* return bool:
* true:infer success false:infer failed
*/
bool CommonReduceInferWithInputAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
const int64_t input_axes_idx, bool keep_dims);
bool CommonReduceInferWithAttrAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
vector<int64_t> attr_axes, bool keep_dims);
} // namespace reduce_ops
#endif // CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_

View File

@ -18,8 +18,8 @@
* \file transfer_shape_according_to_format.h
* \brief set shape according to original format and current format
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#include <memory>
#include <functional>
@ -121,4 +121,4 @@ class ShapeTransferAccordingToFormat {
};
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_

View File

@ -18,8 +18,8 @@
* \file util.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_
#include <memory.h>
#include <string>
@ -641,4 +641,16 @@ const int32_t INDEX_VALUE5 = 5;
const int32_t INDEX_VALUE6 = 6;
const int32_t INDEX_VALUE7 = 7;
const int32_t INDEX_VALUE8 = 8;
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
template <typename T>
inline std::string VectorToString(const std::vector<T> &values) {
std::stringstream ss;
for (auto iter = values.begin(); iter != values.end(); ++iter) {
ss << *iter;
if (iter != values.end() - 1) {
ss << ", ";
}
}
return ss.str();
}
#endif // CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_

View File

@ -18,8 +18,8 @@
* \file vector_proto_profiling.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#define CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#include <memory>
#include <string>
@ -66,4 +66,4 @@ const bool vector_prof_switch = std::getenv("VECTOR_PROF") != nullptr;
} \
} \
}
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_

View File

@ -20,12 +20,110 @@ import stat
import sys
cust_op_lists = [
"Cast",
"Conj",
"IsNan",
"SliceGrad",
"MaskedSelectGrad",
"GatherDGradV2"
"acosgrad",
"acoshgrad",
"adaptiveavgpool2d",
"adaptiveavgpool2dgrad",
"adaptiveavgpool3d",
"adaptiveavgpool3dgrad",
"adaptivemaxpool2dgrad",
"addn",
"adjusthue",
"adjustsaturation",
"affinegridgrad",
"argmax",
"argmaxwithvalue",
"argmin",
"argminwithvalue",
"asingrad",
"asinhgrad",
"bartlettwindow",
"betainc",
"biasadd",
"biasaddgrad",
"bincount",
"blackmanwindow",
"broadcastto",
"bucketize",
"cauchy",
"checknumerics",
"cholesky",
"choleskygrad",
"choleskyinverse",
"choleskysolve",
"combinednonmaxsuppression",
"complex",
"complexabs",
"concat",
"conj",
"cos",
"cumprod",
"cumulativelogsumexp",
"dataformatvecpermute",
"depthtospace",
"diag",
"diagpart",
"div",
"divnonan",
"eig",
"exp",
"expand",
"expm1",
"extractglimpse",
"eye",
"filldiagonal",
"floordiv",
"fractionalavgpool",
"fractionalavgpoolgrad",
"fractionalmaxpool",
"fractionalmaxpoolgrad",
"gathernd",
"gcd",
"geqrf",
"hammingwindow",
"heaviside",
"histogram",
"hypot",
"identityn",
"im2col",
"indexfill",
"isinf",
"isnan",
"kldivloss",
"kldivlossgrad",
"lcm",
"leftshift",
"lessequal",
"listdiff",
"log",
"log1p",
"lognormalreverse",
"logspace",
"lowerbound",
"lusolve",
"luunpackgrad",
"maskedselect",
"maskedselectgrad",
"matrixdeterminant",
"matrixexp",
"matrixlogarithm",
"matrixsolve",
"matrixtriangularsolve",
"maxpool3dgradwithargmax",
"maxpool3dwithargmax",
"mul",
"mulnonan",
"multimarginloss",
"multimarginlossgrad",
"multinomial",
"mvlgamma",
"mvlgammagrad",
"nextafter",
"nondeterministicints",
"gatherdgradv2",
"isnan",
"maskedselectgrad",
"slicegrad"
]
@ -56,7 +154,7 @@ def parse_ini_to_obj(ini_file, aicpu_ops_info):
info = {}
op_name = line[1:-1]
info = {}
if op_name not in cust_op_lists:
if op_name.lower() not in cust_op_lists:
op_name = None
continue
aicpu_ops_info[op_name] = info

View File

@ -57,5 +57,52 @@ REG_CUST_OP(GatherDGradV2)
.OUTPUT(output, TensorType({DT_BOOL, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.CUST_OP_END_FACTORY_REG(GatherDGradV2)
REG_CUST_OP(AffineGridGrad)
.INPUT(y_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
.INPUT(x_size, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(x_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(align_corners, Bool)
.CUST_OP_END_FACTORY_REG(AffineGridGrad)
REG_CUST_OP(HammingWindow)
.INPUT(length, TensorType({DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(periodic, Bool)
.REQUIRED_ATTR(alpha, Float)
.REQUIRED_ATTR(beta, Float)
.REQUIRED_ATTR(dtype, Int)
.CUST_OP_END_FACTORY_REG(HammingWindow)
REG_CUST_OP(IndexFill)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(dim, TensorType({DT_INT32}))
.INPUT(indices, TensorType({DT_INT32}))
.INPUT(value, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.CUST_OP_END_FACTORY_REG(IndexFill)
REG_CUST_OP(Mvlgamma)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
.REQUIRED_ATTR(p, Int)
.CUST_OP_END_FACTORY_REG(Mvlgamma)
REG_CUST_OP(MvlgammaGrad)
.INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT}))
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT}))
.REQUIRED_ATTR(p, Int)
.CUST_OP_END_FACTORY_REG(MvlgammaGrad)
REG_CUST_OP(LogSpace)
.INPUT(start, TensorType({DT_FLOAT, DT_FLOAT16}))
.INPUT(end, TensorType({DT_FLOAT, DT_FLOAT16}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(steps, Int)
.REQUIRED_ATTR(base, Int)
.CUST_OP_END_FACTORY_REG(LogSpace)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ARRAY_OPS_H_

View File

@ -0,0 +1,35 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(ArgMax)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(dimension, TensorType({DT_INT32, DT_INT64}))
.ATTR(dtype, Type, DT_INT64)
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.CUST_OP_END_FACTORY_REG(ArgMax)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(Geqrf)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(tau, TensorType({DT_DOUBLE, DT_FLOAT}))
.CUST_OP_END_FACTORY_REG(Geqrf)
REG_CUST_OP(LuUnpack)
.INPUT(LU_data, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.INPUT(LU_pivots, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64}))
.OUTPUT(pivots, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.OUTPUT(L, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.OUTPUT(U, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.ATTR(unpack_data, Bool, true)
.ATTR(unpack_pivots, Bool, true)
.CUST_OP_END_FACTORY_REG(LuUnpack)
REG_CUST_OP(LuUnpackGrad)
.INPUT(L_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.INPUT(U_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.INPUT(LU_data, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.OUTPUT(L_data_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.OUTPUT(U_data_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.REQUIRED_ATTR(L_grad_flag, Bool)
.CUST_OP_END_FACTORY_REG(LuUnpackGrad)
REG_CUST_OP(LuSolve)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
.INPUT(lu_data, TensorType({DT_FLOAT, DT_FLOAT16}))
.INPUT(lu_pivots, TensorType({DT_INT32}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(LuSolve)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_

View File

@ -0,0 +1,92 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(CholeskySolve)
.INPUT(x1, TensorType({DT_DOUBLE, DT_FLOAT}))
.INPUT(x2, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
.REQUIRED_ATTR(upper, Bool)
.CUST_OP_END_FACTORY_REG(CholeskySolve)
REG_CUST_OP(Cauchy)
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(size, ListInt)
.REQUIRED_ATTR(sigma, Float)
.REQUIRED_ATTR(median, Float)
.CUST_OP_END_FACTORY_REG(Cauchy)
REG_CUST_OP(CholeskyInverse)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
.REQUIRED_ATTR(upper, Bool)
.CUST_OP_END_FACTORY_REG(CholeskyInverse)
REG_CUST_OP(Eig)
.INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
.OUTPUT(eigen_values, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
.OUTPUT(eigen_vectors, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
.REQUIRED_ATTR(compute_v, Bool)
.CUST_OP_END_FACTORY_REG(Eig)
REG_CUST_OP(Hypot)
.INPUT(x1, TensorType({DT_DOUBLE, DT_FLOAT}))
.INPUT(x2, TensorType({DT_DOUBLE, DT_FLOAT}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
.CUST_OP_END_FACTORY_REG(Hypot)
REG_CUST_OP(MatrixLogarithm)
.INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
.OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
.CUST_OP_END_FACTORY_REG(MatrixLogarithm)
REG_CUST_OP(Lcm)
.INPUT(x1, TensorType({DT_INT32, DT_INT64}))
.INPUT(x2, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.CUST_OP_END_FACTORY_REG(Lcm)
REG_CUST_OP(MatrixExp)
.INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(MatrixExp)
REG_CUST_OP(Heaviside)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.CUST_OP_END_FACTORY_REG(Heaviside)
REG_CUST_OP(Gcd)
.INPUT(x1, TensorType({DT_INT32, DT_INT64}))
.INPUT(x2, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.CUST_OP_END_FACTORY_REG(Gcd)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_

View File

@ -0,0 +1,97 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(AdaptiveAvgPool3dGrad)
.INPUT(input_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.INPUT(orig_input_shape, TensorType({DT_INT32}))
.OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.CUST_OP_END_FACTORY_REG(AdaptiveAvgPool3dGrad)
REG_CUST_OP(AdaptiveMaxPool2dGrad)
.INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(argmax, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(x_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(AdaptiveMaxPool2dGrad)
REG_CUST_OP(AdaptiveAvgPool2D)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(output_size, ListInt)
.CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2D)
REG_CUST_OP(AdaptiveAvgPool3d)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
.REQUIRED_ATTR(output_size, ListInt)
.CUST_OP_END_FACTORY_REG(AdaptiveAvgPool3d)
REG_CUST_OP(AdaptiveAvgPool2DGrad)
.INPUT(input_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(orig_input_shape, TensorType({DT_INT64}))
.OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2DGrad)
REG_CUST_OP(MultiMarginLossGrad)
.INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(target, TensorType({DT_INT64}))
.INPUT(weight, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(p, Int)
.REQUIRED_ATTR(margin, Float)
.REQUIRED_ATTR(reduction, String)
.CUST_OP_END_FACTORY_REG(MultiMarginLossGrad)
REG_CUST_OP(MaxPool3DGradWithArgmax)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(grads, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.INPUT(argmax, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_UINT8}))
.REQUIRED_ATTR(ksize, ListInt)
.REQUIRED_ATTR(strides, ListInt)
.REQUIRED_ATTR(pads, ListInt)
.REQUIRED_ATTR(dilation, ListInt)
.REQUIRED_ATTR(ceil_mode, Bool)
.REQUIRED_ATTR(data_format, String)
.REQUIRED_ATTR(argmax_type, String)
.CUST_OP_END_FACTORY_REG(MaxPool3DGradWithArgmax)
REG_CUST_OP(MultiMarginLoss)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.INPUT(target, TensorType({DT_INT64}))
.INPUT(weight, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(p, Int)
.REQUIRED_ATTR(margin, Float)
.REQUIRED_ATTR(reduction, String)
.CUST_OP_END_FACTORY_REG(MultiMarginLoss)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_

View File

@ -0,0 +1,43 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_
#include "graph/operator.h"
#include "graph/operator_reg.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(LogNormalReverse)
.INPUT(input, TensorType({DT_FLOAT, DT_FLOAT16}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(mean, Float)
.REQUIRED_ATTR(std, Float)
.CUST_OP_END_FACTORY_REG(LogNormalReverse)
REG_CUST_OP(Dropout2D)
.INPUT(x, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8,
DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
.OUTPUT(y, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8,
DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
.OUTPUT(mask, TensorType({DT_BOOL}))
.REQUIRED_ATTR(keep_prob, Float)
.CUST_OP_END_FACTORY_REG(Dropout2D)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"
/* clang-format off */
namespace ge {
REG_CUST_OP(BlackmanWindow)
.INPUT(window_length, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(periodic, Bool)
.REQUIRED_ATTR(dtype, Type)
.CUST_OP_END_FACTORY_REG(BlackmanWindow)
REG_CUST_OP(BartlettWindow)
.INPUT(window_length, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
.REQUIRED_ATTR(periodic, Bool)
.REQUIRED_ATTR(dtype, Type)
.CUST_OP_END_FACTORY_REG(BartlettWindow)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_

View File

@ -17,8 +17,9 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_OP_ADAPTER_MAP_H_
#define MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_OP_ADAPTER_MAP_H_
#include <string>
#include <memory>
#include <string>
#include "utils/hash_map.h"
namespace mindspore {
@ -28,12 +29,15 @@ constexpr const char kNameConst[] = "Const";
constexpr const char kNameParam[] = "parameter";
constexpr const char kNameRandomUniform[] = "RandomUniform";
constexpr const char kNameUniformReal[] = "UniformReal";
constexpr const char kNameLogNormalReverse[] = "LogNormalReverse";
constexpr const char kNameSimpleMean[] = "SimpleMean";
constexpr const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
constexpr const char kNameAllReduce[] = "AllReduce";
constexpr const char kNameBroadcast[] = "Broadcast";
constexpr const char kNameBroadcastTo[] = "BroadcastTo";
constexpr const char kNameBroadcastToD[] = "BroadcastToD";
constexpr const char kNameBlackmanWindow[] = "BlackmanWindow";
constexpr const char kNameBartlettWindow[] = "BartlettWindow";
constexpr const char kNameAllgather[] = "AllGather";
constexpr const char kNameAllToAllv[] = "AllToAllv";
constexpr const char kNameReduceScatter[] = "ReduceScatter";
@ -47,6 +51,7 @@ constexpr const char kNameSquaredDifference[] = "SquaredDifference";
constexpr const char kNamePow[] = "Pow";
constexpr const char kNameBatchMatMul[] = "BatchMatMul";
constexpr const char kNameBatchMatMulV2[] = "BatchMatMulV2";
constexpr const char kNameBincount[] = "Bincount";
constexpr const char kNameStridedSlice[] = "StridedSlice";
constexpr const char kNameStridedSliceGrad[] = "StridedSliceGrad";
constexpr const char kNameExpandDims[] = "ExpandDims";
@ -54,6 +59,7 @@ constexpr const char kNameLog[] = "Log";
constexpr const char kNameLogicalAnd[] = "LogicalAnd";
constexpr const char kNameLogicalNot[] = "LogicalNot";
constexpr const char kNameLogicalOr[] = "LogicalOr";
constexpr const char kNameListDiff[] = "ListDiff";
constexpr const char kNameExp[] = "Exp";
constexpr const char kNameLessEqual[] = "LessEqual";
constexpr const char kNameGreaterEqual[] = "GreaterEqual";
@ -61,6 +67,8 @@ constexpr const char kNameApproximateEqual[] = "ApproximateEqual";
constexpr const char kNameEqual[] = "Equal";
constexpr const char kNameNotEqual[] = "NotEqual";
constexpr const char kNameFlattenGrad[] = "FlattenGrad";
constexpr const char kNameFillDiagonal[] = "FillDiagonal";
constexpr const char kNameEye[] = "Eye";
constexpr const char kNameConvolution[] = "Convolution";
constexpr const char kNameMaxPool3D[] = "MaxPool3D";
constexpr const char kNameMaxPool3DGrad[] = "MaxPool3DGrad";
@ -80,6 +88,7 @@ constexpr const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
constexpr const char kNameMaxPoolGradWithArgmaxV2[] = "MaxPoolGradWithArgmaxV2";
constexpr const char kNameApplyMomentum[] = "ApplyMomentum";
constexpr const char kNameDropoutDoMask[] = "DropoutDoMask";
constexpr const char kNameDropout2D[] = "Dropout2D";
constexpr const char kNameDropOutDoMaskV3[] = "DropOutDoMaskV3";
constexpr const char kNameDropOutDoMaskV3D[] = "DropOutDoMaskV3D";
constexpr const char kNameDropOutGenMaskV4[] = "DropOutGenMaskV4";
@ -441,6 +450,7 @@ constexpr const char kNameViewCopy[] = "ViewCopy";
constexpr const char kNameSend[] = "Send";
constexpr const char kNameReceive[] = "Receive";
constexpr const char kNameIndexAdd[] = "IndexAdd";
constexpr const char kNameIndexFill[] = "IndexFill";
constexpr const char kNameUnique[] = "Unique";
constexpr const char kNameDynamicBroadcastGradientArgs[] = "DynamicBroadcastGradientArgs";
constexpr const char kNameDynamicStitch[] = "DynamicStitch";

View File

@ -253,4 +253,57 @@ ATTR_INPUT_MAP(ViewCopy) = {
ATTR_MAP(ViewCopy) = EMPTY_ATTR_MAP;
OUTPUT_MAP(ViewCopy) = {{0, OUTPUT_DESC(dst)}};
REG_ADPT_DESC(ViewCopy, kNameViewCopy, ADPT_DESC(ViewCopy))
// CheckNumerics
INPUT_MAP(CheckNumerics) = {{1, INPUT_DESC(x)}};
ATTR_MAP(CheckNumerics) = {{"message", ATTR_DESC(message, AnyTraits<std::string>())}};
OUTPUT_MAP(CheckNumerics) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(CheckNumerics, prim::kPrimCheckNumerics->name(), ADPT_DESC(CheckNumerics));
// HammingWindow
CUST_INPUT_MAP(HammingWindow) = {{1, INPUT_DESC(length)}};
CUST_ATTR_MAP(HammingWindow) = {{"periodic", ATTR_DESC(periodic, AnyTraits<bool>())},
{"alpha", ATTR_DESC(alpha, AnyTraits<float>())},
{"beta", ATTR_DESC(beta, AnyTraits<float>())},
{"dtype", ATTR_DESC(dtype, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(HammingWindow) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(HammingWindow, prim::kPrimHammingWindow->name(), CUST_ADPT_DESC(HammingWindow));
// LowerBound
INPUT_MAP(LowerBound) = {{1, INPUT_DESC(sorted_x)}, {2, INPUT_DESC(values)}};
ATTR_MAP(LowerBound) = EMPTY_ATTR_MAP;
OUTPUT_MAP(LowerBound) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LowerBound, prim::kPrimLowerBound->name(), ADPT_DESC(LowerBound));
// ListDiff
INPUT_MAP(ListDiff) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}};
ATTR_MAP(ListDiff) = {{"out_idx", ATTR_DESC(out_idx, AnyTraits<GEType>())}};
OUTPUT_MAP(ListDiff) = {{0, OUTPUT_DESC(out)}, {1, OUTPUT_DESC(idx)}};
REG_ADPT_DESC(ListDiff, kNameListDiff, ADPT_DESC(ListDiff));
// IndexFill
CUST_INPUT_MAP(IndexFill) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(dim)}, {3, INPUT_DESC(indices)}, {4, INPUT_DESC(value)}};
CUST_ATTR_MAP(IndexFill) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(IndexFill) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(IndexFill, kNameIndexFill, CUST_ADPT_DESC(IndexFill));
// Mvlgamma
CUST_INPUT_MAP(Mvlgamma) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Mvlgamma) = {{"p", ATTR_DESC(p, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(Mvlgamma) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Mvlgamma, prim::kPrimMvlgamma->name(), CUST_ADPT_DESC(Mvlgamma));
// MvlgammaGrad
CUST_INPUT_MAP(MvlgammaGrad) = {{1, INPUT_DESC(y_grad)}, {2, INPUT_DESC(x)}};
CUST_ATTR_MAP(MvlgammaGrad) = {{"p", ATTR_DESC(p, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(MvlgammaGrad) = {{0, OUTPUT_DESC(x_grad)}};
REG_ADPT_DESC(MvlgammaGrad, prim::kPrimMvlgammaGrad->name(), CUST_ADPT_DESC(MvlgammaGrad));
// LogSpace
CUST_INPUT_MAP(LogSpace) = {{1, INPUT_DESC(start)}, {2, INPUT_DESC(end)}};
CUST_ATTR_MAP(LogSpace) = {{"steps", ATTR_DESC(steps, AnyTraits<int64_t>())},
{"base", ATTR_DESC(base, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(LogSpace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LogSpace, prim::kPrimLogSpace->name(), CUST_ADPT_DESC(LogSpace));
} // namespace mindspore::transform

View File

@ -125,4 +125,28 @@ DECLARE_OP_USE_OUTPUT(AsStrided)
DECLARE_OP_ADAPTER(ViewCopy)
DECLARE_OP_USE_OUTPUT(ViewCopy)
DECLARE_OP_ADAPTER(CheckNumerics)
DECLARE_OP_USE_OUTPUT(CheckNumerics)
DECLARE_CUST_OP_ADAPTER(HammingWindow)
DECLARE_CUST_OP_USE_OUTPUT(HammingWindow)
DECLARE_OP_ADAPTER(LowerBound)
DECLARE_OP_USE_OUTPUT(LowerBound)
DECLARE_OP_ADAPTER(ListDiff)
DECLARE_OP_USE_OUTPUT(ListDiff)
DECLARE_CUST_OP_ADAPTER(IndexFill)
DECLARE_CUST_OP_USE_OUTPUT(IndexFill)
DECLARE_CUST_OP_ADAPTER(Mvlgamma)
DECLARE_CUST_OP_USE_OUTPUT(Mvlgamma)
DECLARE_CUST_OP_ADAPTER(MvlgammaGrad)
DECLARE_CUST_OP_USE_OUTPUT(MvlgammaGrad)
DECLARE_CUST_OP_ADAPTER(LogSpace)
DECLARE_CUST_OP_USE_OUTPUT(LogSpace)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "transform/graph_ir/custom_op_proto/cust_elewise_calculation_ops.h"
#include "ops/ascend_op_name.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
@ -279,6 +280,13 @@ ATTR_MAP(OnesLike) = EMPTY_ATTR_MAP;
OUTPUT_MAP(OnesLike) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(OnesLike, kNameOnesLike, ADPT_DESC(OnesLike))
// ArgMax
CUST_INPUT_MAP(ArgMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
CUST_ATTR_INPUT_MAP(ArgMax) = {{"axis", "dimension"}};
CUST_ATTR_MAP(ArgMax) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
CUST_OUTPUT_MAP(ArgMax) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ArgMax, kNameArgmax, CUST_ADPT_DESC(ArgMax));
// ArgMaxD
INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
@ -291,7 +299,6 @@ INPUT_MAP(ArgMaxV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
ATTR_INPUT_MAP(ArgMaxV2) = {{"axis", "dimension"}};
ATTR_MAP(ArgMaxV2) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
OUTPUT_MAP(ArgMaxV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ArgMax, kNameArgmax, ADPT_DESC(ArgMaxV2))
REG_ADPT_DESC(ArgMaxV2, kNameArgMaxV2, ADPT_DESC(ArgMaxV2))
// ArgMaxWithValue

View File

@ -19,6 +19,7 @@
#include "inc/ops/bitwise_ops.h"
#include "inc/ops/elewise_calculation_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_elewise_calculation_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"
@ -47,6 +48,9 @@ DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_CUST_OP_ADAPTER(ArgMax)
DECLARE_CUST_OP_USE_OUTPUT(ArgMax)
DECLARE_OP_ADAPTER(ArgMaxD)
DECLARE_OP_USE_OUTPUT(ArgMaxD)

Some files were not shown because too many files have changed in this diff Show More