Add regop and adapters for custom aicpu
parent 0faf4bd9da
commit a2adaa3917

@@ -183,6 +183,8 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/casting"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/namespace"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/braces"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" "readability/braces"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/" ""
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "whitespace/ending_newline"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "build/include_subdir"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/inc/" "runtime/references"

@@ -415,3 +415,12 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
mindspore/mindspore/lite/src/litert/kernel/opencl/kernel/conv2d.cc:mindspore::kernel::UseWinograd4x4To6x6
mindspore/mindspore/lite/src/litert/kernel/opencl/kernel/fullconnection.cc:mindspore::kernel::FullConnectionOpenCLKernel::CheckSpecs

mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/concat_proto.cc:ge::ConcatInferShapeCommon
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/elewise_calculation_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_pooling_ops_proto.cc:ge::CUST_IMPLEMT_VERIFIER
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/concat_proto.cc:ge::ConcatInferShapeCommon
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/combined_non_max_suppression_proto.cc:ge::IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/im2col_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/math_ops_proto.cc:ge::IMPLEMT_INFERFUNC

@@ -36,7 +36,7 @@ input1.name=dy
input1.type=DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output0.name=z
output0.type=DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
[AdaptiveAvgPool2d]
[AdaptiveAvgPool2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU

@@ -807,25 +807,6 @@ output1.type=DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output1.name=values
output2.type=DT_INT32,DT_INT64
output2.name=dense_shape
[Cumprod]
opInfo.engine=DNN_VM_AICPU
opInfo.flagPartial=False
opInfo.computeCost=100
opInfo.flagAsync=False
opInfo.opKernelLib=CUSTAICPUKernel
opInfo.kernelSo=libcust_cpu_kernels.so
opInfo.functionName=RunCpuKernel
opInfo.workspaceSize=1024
opInfo.opsFlag=OPS_FLAG_CLOSE
opInfo.userDefined=True
opInfo.subTypeOfInferShape=1
opInfo.formatAgnostic=False
input0.type=DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_UINT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
input0.name=x
input1.type=DT_INT32,DT_INT64
input1.name=axis
output0.type=DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_UINT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_COMPLEX64,DT_COMPLEX128
output0.name=y
[CumulativeLogsumexp]
opInfo.engine=DNN_VM_AICPU
opInfo.flagPartial=False

@@ -4076,3 +4057,709 @@ opInfo.opsFlag=OPS_FLAG_CLOSE
opInfo.userDefined=True
opInfo.subTypeOfInferShape=1
opInfo.formatAgnostic=False
[ArgMin]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=dimension
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[Diag]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
[Betainc]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=a
input0.type=DT_DOUBLE,DT_FLOAT
input1.name=b
input1.type=DT_DOUBLE,DT_FLOAT
input2.name=x
input2.type=DT_DOUBLE,DT_FLOAT
output0.name=z
output0.type=DT_DOUBLE,DT_FLOAT
[DivNoNan]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=x2
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Expm1]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[ArgMaxWithValue]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=indice
output0.type=DT_INT32
output1.name=values
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Bucketize]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input
input0.type=DT_DOUBLE,DT_FLOAT,DT_INT32,DT_INT64
output0.name=output
output0.type=DT_INT32
[DiagPart]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT32,DT_INT64
[CheckNumerics]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[CumProd]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=axis
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Cos]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AdaptiveAvgPool2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AdjustSaturation]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=image
input0.type=DT_FLOAT,DT_FLOAT16
input1.name=scale
input1.type=DT_FLOAT
output0.name=y
output0.type=DT_FLOAT,DT_FLOAT16
[ACosGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=y
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=dy
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=z
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[AffineGridGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=y_grad
input0.type=DT_FLOAT,DT_FLOAT16
input1.name=x_size
input1.type=DT_INT32,DT_INT64
output0.name=x_grad
output0.type=DT_FLOAT,DT_FLOAT16
[AddN]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[Div]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
input1.name=y
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
output0.name=output
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
[BiasAddGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[ArgMinWithValue]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=index
output0.type=DT_INT32
output1.name=values
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[AdaptiveAvgPool3D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[AdaptiveAvgPool2DGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input_grad
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=orig_input_shape
input1.type=DT_INT64
output0.name=output_grad
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Hypot]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_DOUBLE,DT_FLOAT
input1.name=x2
input1.type=DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT
[Lcm]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_INT32,DT_INT64
input1.name=x2
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[MatrixExp]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[Log1p]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT,DT_FLOAT16
[LessEqual]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=x2
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_BOOL
[Heaviside]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
input1.name=values
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
[MaskedSelect]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_FLOAT,DT_INT32
input1.name=mask
input1.type=DT_BOOL
output0.name=y
output0.type=DT_FLOAT,DT_INT32
[Multinomial]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
input1.name=num_sample
input1.type=DT_INT32
input2.name=count
input2.type=DT_UINT64
input3.name=state
input3.type=DT_UINT64
output0.name=output
output0.type=DT_INT32,DT_INT64
[MatrixDeterminant]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT
[FillDiagonal]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=input_x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=y
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[Gcd]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x1
input0.type=DT_INT32,DT_INT64
input1.name=x2
input1.type=DT_INT32,DT_INT64
output0.name=y
output0.type=DT_INT32,DT_INT64
[MatrixTriangularSolve]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=matrix
input0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
input1.name=rhs
input1.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
output0.name=y
output0.type=DT_COMPLEX128,DT_COMPLEX64,DT_DOUBLE,DT_FLOAT
[FloorDiv]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
input1.name=y
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
output0.name=output
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT8
[LuUnpack]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=LU_data
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input1.name=LU_pivots
input1.type=DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_UINT8
output0.name=pivots
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output1.name=L
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output2.name=U
output2.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[LuUnpackGrad]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=L_grad
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input1.name=U_grad
input1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
input2.name=LU_data
input2.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output0.name=L_data_grad
output0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
output1.name=U_data_grad
output1.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT8
[IsInf]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_DOUBLE,DT_FLOAT,DT_FLOAT16
output0.name=y
output0.type=DT_BOOL
[Dropout2D]
opInfo.subTypeOfInferShape = 1
opInfo.opsFlag = OPS_FLAG_CLOSE
opInfo.engine = DNN_VM_AICPU
opInfo.flagPartial = False
opInfo.computeCost = 100
opInfo.flagAsync = False
opInfo.opKernelLib = CUSTAICPUKernel
opInfo.formatAgnostic = False
opInfo.userDefined = True
opInfo.workspaceSize = 1024
opInfo.kernelSo = libcust_cpu_kernels.so
opInfo.functionName = RunCpuKernel
input0.name=x
input0.type=DT_BOOL,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output0.name=y
output0.type=DT_BOOL,DT_DOUBLE,DT_FLOAT,DT_FLOAT16,DT_INT16,DT_INT32,DT_INT64,DT_INT8,DT_UINT16,DT_UINT32,DT_UINT64,DT_UINT8
output1.name=mask
output1.type=DT_BOOL

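Every `[OpName]` section added above carries the same `opInfo.*` boilerplate routing the op to the user-defined AICPU kernel library (`opInfo.opKernelLib = CUSTAICPUKernel`, `opInfo.kernelSo = libcust_cpu_kernels.so`, entry point `opInfo.functionName = RunCpuKernel`), followed by typed `inputN`/`outputN` declarations. As a rough illustration of how one of these sections can be consumed, here is a minimal, hypothetical C++ parser sketch; none of these helper names come from the patch:

#include <istream>
#include <map>
#include <string>

// Hypothetical sketch: reads one "[OpName]" section of the op-info file
// into a key -> value map, e.g. "opInfo.kernelSo" -> "libcust_cpu_kernels.so".
std::map<std::string, std::string> ParseOpSection(std::istream &in, const std::string &op) {
  std::map<std::string, std::string> kv;
  std::string line;
  bool in_section = false;
  auto trim = [](std::string s) {
    const char *ws = " \t";
    s.erase(0, s.find_first_not_of(ws));
    s.erase(s.find_last_not_of(ws) + 1);
    return s;
  };
  while (std::getline(in, line)) {
    if (!line.empty() && line.front() == '[') {  // section header starts a new op
      in_section = (line == "[" + op + "]");
      continue;
    }
    if (!in_section) continue;
    auto eq = line.find('=');
    if (eq == std::string::npos) continue;
    // keys appear both as "key = value" and "key=value"; trim around '='
    kv[trim(line.substr(0, eq))] = trim(line.substr(eq + 1));
  }
  return kv;
}
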
@@ -0,0 +1,44 @@
/**
 * Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape) {
  std::vector<std::string> input_infer_depends = {"orig_input_shape"};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);
  DataType input_dtype = op.GetInputDescByName("input_grad").GetDataType();
  Shape output_shape;
  Tensor orig_input_shape_tensor;
  if (op.GetInputConstData("orig_input_shape", orig_input_shape_tensor) != GRAPH_SUCCESS) {
    auto output_desc = op.GetOutputDescByName("output_grad");
    output_desc.SetDataType(input_dtype);
    output_desc.SetShape(Shape(ge::UNKNOWN_RANK));
    return op.UpdateOutputDesc("output_grad", output_desc);
  }
  MakeShapeFromShapeTensor(orig_input_shape_tensor, output_shape, op);
  TensorDesc output_grad = op.GetOutputDescByName("output_grad");
  output_grad.SetShape(output_shape);
  output_grad.SetDataType(input_dtype);
  return op.UpdateOutputDesc("output_grad", output_grad);
}

CUST_INFER_FUNC_REG(AdaptiveAvgPool2DGrad, AdaptiveAvgPool2dGradInferShape);
}  // namespace ge

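AdaptiveAvgPool2DGrad above is a value-dependent infer function: when the `orig_input_shape` input is not a compile-time constant, the output falls back to unknown rank; when it is, the constant fully determines the static output shape. A minimal standalone sketch of that decision, using `std::optional` to stand in for GE's unknown-rank shape (hypothetical helper, not part of the patch):

#include <cstdint>
#include <optional>
#include <vector>

using ShapeVec = std::vector<int64_t>;  // -1 mirrors an unknown dimension

// std::nullopt plays the role of ge::UNKNOWN_RANK in the sketch.
std::optional<ShapeVec> InferPool2dGradShape(const std::optional<ShapeVec> &const_orig_shape) {
  if (!const_orig_shape.has_value()) {
    // orig_input_shape is produced by another op at runtime: rank unknown at build time.
    return std::nullopt;
  }
  // The gradient w.r.t. the input has exactly the original input's shape.
  return *const_orig_shape;
}
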
@@ -0,0 +1,65 @@
/**
 * Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ---------------AdaptiveAvgPool2D-------------------
CUST_IMPLEMT_INFERFUNC(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape) {
  OP_LOGI(TbeGetName(op).c_str(), "AdaptiveAvgPool2d inferShape begin!");
  const size_t DIM_SIZE2 = 2;
  auto input_tensor_desc = op.GetInputDescByName("x");
  auto shape = input_tensor_desc.GetShape();
  // get output_size
  std::vector<int64_t> output_size_list;
  if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size_list)) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
    return GRAPH_FAILED;
  }
  // check output size
  if (output_size_list.size() != DIM_SIZE2) {
    OP_LOGE(TbeGetName(op).c_str(), "length of output_size must be 2");
    return GRAPH_FAILED;
  }
  std::vector<int64_t> dims_input = shape.GetDims();
  // set output shape
  std::vector<int64_t> dim_vector;
  for (size_t i = 0; i < dims_input.size(); i++) {
    int64_t dims = dims_input[i];
    dim_vector.push_back(dims);
  }
  size_t index0 = dims_input.size() - 2;
  size_t index1 = dims_input.size() - 1;
  dim_vector[index0] = output_size_list[0];
  dim_vector[index1] = output_size_list[1];
  TensorDesc td = op.GetOutputDescByName("y");
  DataType input_dtype = input_tensor_desc.GetDataType();
  Shape output_shape(dim_vector);
  td.SetShape(output_shape);
  td.SetDataType(input_dtype);
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dInferShape);
CUST_VERIFY_FUNC_REG(AdaptiveAvgPool2D, AdaptiveAvgPool2dVerify);
// ---------------AdaptiveAvgPool2D End---------------
}  // namespace ge

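The shape rule the infer function implements is simple: keep every leading dimension of `x` and replace the last two with `output_size`. A minimal standalone sketch with a worked example (hypothetical helper names, not from the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Output keeps all leading dims of x; the trailing two become output_size.
std::vector<int64_t> AdaptiveAvgPool2dOutShape(std::vector<int64_t> x_dims,
                                               const std::vector<int64_t> &output_size) {
  assert(output_size.size() == 2 && x_dims.size() >= 2);
  x_dims[x_dims.size() - 2] = output_size[0];
  x_dims[x_dims.size() - 1] = output_size[1];
  return x_dims;
}

int main() {
  // NCHW input 1x3x32x32 pooled to output_size (7, 7) -> 1x3x7x7.
  auto y = AdaptiveAvgPool2dOutShape({1, 3, 32, 32}, {7, 7});
  assert((y == std::vector<int64_t>{1, 3, 7, 7}));
}
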
@@ -26,15 +26,13 @@ IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {
  {ge::FORMAT_DHWCN, "DHWCN"}, {ge::FORMAT_NDHWC, "NDHWC"}, {ge::FORMAT_NCDHW, "NCDHW"}};

  // verify the dim of output_size
  auto output_size_desc = op.GetInputDescByName("output_size");
  auto output_size_dim = output_size_desc.GetShape().GetDimNum();
  ge::AscendString op_name;
  (void)op.GetName(op_name);
  if (output_size_dim != 1) {
    OP_LOGE("AdaptiveAvgPool3d", "Num Dim of output_szie is invalid");
  std::vector<int64_t> output_size;
  if (GRAPH_SUCCESS != op.GetAttr("output_size", output_size)) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr output_size failed!");
    return GRAPH_PARAM_INVALID;
  }

  ge::AscendString op_name;
  (void)op.GetName(op_name);
  auto input_desc = op.GetInputDescByName("x");
  TensorDesc out_desc = op.GetOutputDescByName("y");

@@ -56,38 +54,20 @@ IMPLEMT_COMMON_INFERFUNC(AdaptiveAvgPool3dInferShape) {

  std::vector<int64_t> input_size_shape = input_desc.GetShape().GetDims();
  auto input_size_dim_num = input_size_shape.size();
  std::vector<int64_t> output_shape(input_size_dim_num);
  for (uint64_t i = 0; i < input_size_dim_num - 3; ++i) {
    output_shape[i] = input_size_shape[i];
  }

  Tensor output_size_tensor;
  if (op.GetInputConstData("output_size", output_size_tensor) != GRAPH_SUCCESS) {
    OP_LOGE("AdaptiveAvgPool3d", "failed to get tensor from output_size");
    return GRAPH_FAILED;
  }

  int32_t *output_size_data = reinterpret_cast<int32_t *>(output_size_tensor.GetData());
  if (output_size_data == nullptr) {
    OP_LOGE("AdaptiveAvgPool3d", "output_size data is invalid");
    return GRAPH_PARAM_INVALID;
  }

  auto output_size_num = output_size_desc.GetShape().GetShapeSize();
  std::vector<int64_t> output_shape(input_size_shape.begin(), input_size_shape.end());
  auto output_size_num = output_size.size();
  if (output_size_num == 1) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      if (output_size_data[0] < 0) {
        OP_LOGE("AdaptiveAvgPool3d", "Value of output_size can\'t be negative");
        return GRAPH_PARAM_INVALID;
      if (output_size[0] < 0) {
        continue;
      }
      output_shape[i] = output_size_data[0];
      output_shape[i] = output_size[0];
    }
  } else if (output_size_num == 3) {
    for (uint64_t i = input_size_dim_num - 3; i < input_size_dim_num; ++i) {
      auto data = output_size_data[i - input_size_dim_num + 3];
      auto data = output_size[i - input_size_dim_num + 3];
      if (data < 0) {
        OP_LOGE("AdaptiveAvgPool3d", "Value of output_size can\'t be negative");
        return GRAPH_PARAM_INVALID;
        continue;
      }
      output_shape[i] = data;
    }

@@ -26,14 +26,14 @@ CUST_IMPLEMT_INFERFUNC(AdjustContrast, AdjustContrastInfer) {
  GeShape shape;
  std::string err_msg;
  auto contrast_factor_desc = op_desc->MutableInputDesc(1);
  if (WithRank(contrast_factor_desc, 0, shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(contrast_factor_desc, 0, shape, op) != GRAPH_SUCCESS) {
    err_msg = GetShapeErrMsg(1, DebugString(contrast_factor_desc->GetShape().GetDims()), "scalar");
    err_msg = string("failed to call WithRank function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  auto images_desc = op_desc->MutableInputDesc(0);
  if (WithRankAtLeast(images_desc, 3, shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtLeast(images_desc, 3, shape, op) != GRAPH_SUCCESS) {
    err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
    err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

@@ -14,7 +14,7 @@
 * limitations under the License.
 */

#include "inc/arg_max_op.h"
#include "custom_op_proto/cust_elewise_calculation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"

@@ -98,5 +98,5 @@ IMPLEMT_COMMON_INFERFUNC(ArgMaxInferShape) {
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(ArgMaxV2, ArgMaxInferShape);
CUST_COMMON_INFER_FUNC_REG(ArgMax, ArgMaxInferShape);
}  // namespace ge

@@ -0,0 +1,108 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/elewise_calculation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"

namespace ge {
IMPLEMT_COMMON_INFERFUNC(ArgMinInferShape) {
  // get all input desc
  const vector<string> depend_names = {"dimension"};
  PREPARE_DYNAMIC_SHAPE(depend_names);
  auto op_info_arg = OpDescUtils::GetOpDescFromOperator(op);
  auto input_desc = op_info_arg->MutableInputDesc("x");
  auto const_desc = op_info_arg->MutableInputDesc("dimension");
  auto y_desc = op_info_arg->MutableOutputDesc("y");

  // get and set output dtype
  ge::DataType dtype;
  if (op.GetAttr("dtype", dtype) == GRAPH_SUCCESS) {
    y_desc->SetDataType(dtype);
  } else {
    OP_LOGW(TbeGetName(op).c_str(), "get attr dtype failed.");
    y_desc->SetDataType(DT_INT32);
  }

  // get x shape
  auto x_shape = input_desc->MutableShape().GetDims();
  // if x_shape == -2, set output -2
  if (IsUnknownRankShape(x_shape)) {
    y_desc->SetShape(GeShape(x_shape));
    return GRAPH_SUCCESS;
  }

  // if x_shape.size() < 2, set output scalar
  if (x_shape.size() < 2) {
    vector<int64_t> output_shape;
    y_desc->SetShape(GeShape(output_shape));
    return GRAPH_SUCCESS;
  }

  // read dimension const value
  vector<int64_t> dimension_value;
  auto dimension_idx = static_cast<uint32_t>(op_info_arg->GetInputIndexByName("dimension"));
  const GeTensor *dimension_tensor = OpDescUtils::GetInputConstData(op, dimension_idx);
  if (dimension_tensor != nullptr) {
    auto const_dtype = const_desc->GetDataType();
    GetConstValue(op, dimension_tensor, const_dtype, dimension_value);
    // verify dimension_value
    if (dimension_value.size() != 1) {
      string error_msg = ConcatString("the element size of input[dimension] should be equal to 1, but get ",
                                      dimension_value.size(), ".");
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), error_msg);
      return GRAPH_FAILED;
    }
    int64_t dimension = dimension_value[0] < 0 ? dimension_value[0] + x_shape.size() : dimension_value[0];
    if (dimension >= static_cast<int64_t>(x_shape.size())) {
      string error_msg = ConcatString("the value of input[dimension] must be range at input shape size,",
                                      " but get input[dimension] value ", dimension_value[0], ", input[x] shape size ",
                                      x_shape.size(), ".");
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), error_msg);
      return GRAPH_FAILED;
    }

    vector<int64_t> output_shape(x_shape);
    output_shape.erase(output_shape.begin() + dimension);
    y_desc->SetShape(GeShape(output_shape));

    // when output is dynamic will update range
    if (IsUnknown(output_shape)) {
      std::vector<std::pair<int64_t, int64_t>> input_range;
      input_desc->GetShapeRange(input_range);
      MakeUpShapeRange(x_shape, input_range);
      input_range.erase(input_range.begin() + dimension);
      y_desc->SetShapeRange(input_range);
    }
    return GRAPH_SUCCESS;
  }

  // dimension is not const, set all output is -1, range is [1, -1]
  std::vector<std::pair<int64_t, int64_t>> output_range;
  vector<int64_t> output_shape;
  for (size_t item = 0; item < (x_shape.size() - 1); ++item) {
    output_shape.push_back(-1);
  }
  MakeUpShapeRange(output_shape, output_range);
  y_desc->SetShape(GeShape(output_shape));
  y_desc->SetShapeRange(output_range);

  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(ArgMin, ArgMinInferShape);
}  // namespace ge

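With a constant `dimension`, ArgMin's output shape is the input shape with that axis removed, after wrapping negative axes, which is exactly what the `erase` call above does. A compact standalone sketch (hypothetical helper, not from the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of the const-dimension branch: drop the reduced axis.
std::vector<int64_t> ArgMinOutShape(std::vector<int64_t> x_shape, int64_t dimension) {
  if (dimension < 0) dimension += static_cast<int64_t>(x_shape.size());
  assert(dimension >= 0 && dimension < static_cast<int64_t>(x_shape.size()));
  x_shape.erase(x_shape.begin() + dimension);
  return x_shape;
}

int main() {
  // x: (2, 3, 4), dimension = 1  ->  y: (2, 4)
  assert((ArgMinOutShape({2, 3, 4}, 1) == std::vector<int64_t>{2, 4}));
  // dimension = -1 behaves like dimension = 2  ->  y: (2, 3)
  assert((ArgMinOutShape({2, 3, 4}, -1) == std::vector<int64_t>{2, 3}));
}
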
@@ -21,6 +21,161 @@
#include "utils/util.h"

namespace ge {
// ----------------Expand Begin-------------------
template <typename T>
static bool ExpandCalDim(const Tensor &data, std::vector<int64_t> &vec_dim, std::vector<int64_t> &x_dims,
                         std::vector<std::pair<int64_t, int64_t>> &range_vector) {
  int64_t len_x = x_dims.size();
  int64_t len_shape = data.GetSize() / sizeof(T);
  int64_t diff = abs(len_x - len_shape);
  const char *op_name = "Expand";

  std::string xShape = to_string(x_dims);
  OP_LOGD(op_name, "Get shape of [expand's x] %s", xShape.c_str());

  const T *pdata = reinterpret_cast<const T *>(data.GetData());
  std::vector<int64_t> shape_dims;
  for (int64_t i = 0; i < len_shape; i++) {
    T dim = pdata[i];
    shape_dims.push_back(dim);
  }
  std::string shapeVal = to_string(shape_dims);
  OP_LOGD(op_name, "Get constValue val of [expand's shape] %s", shapeVal.c_str());

  const bool is_shape_less = (len_shape < len_x);

  for (int64_t i = 0; i < diff; i++) {
    T dim = 0;
    if (is_shape_less) {
      dim = x_dims[i];
    } else {
      dim = pdata[i];
    }
    if (dim == -1) {
      range_vector.push_back(std::make_pair(1, -1));
    } else {
      range_vector.push_back(std::make_pair(dim, dim));
    }
    vec_dim.push_back(dim);
  }

  int64_t upb = len_shape;
  if (is_shape_less) {
    upb = len_x;
  }
  for (int64_t i = diff; i < upb; i++) {
    int64_t idx = i - diff;
    T dim = 0;
    if (is_shape_less) {
      idx = i;
      dim = pdata[i - diff];
    } else {
      dim = pdata[i];
    }
    if (dim == -1 || x_dims[idx] == -1) {
      vec_dim.push_back(-1);
      range_vector.push_back(std::make_pair(1, -1));
      continue;
    }
    if (dim == 0) {
      vec_dim.push_back(0);
      range_vector.push_back(std::make_pair(0, 0));
      continue;
    }
    if ((x_dims[idx] != dim) && (x_dims[idx] != 1) && (dim != 1)) {
      return false;
    }
    if (x_dims[idx] > dim) {
      vec_dim.push_back(x_dims[idx]);
      range_vector.push_back(std::make_pair(x_dims[idx], x_dims[idx]));
    } else {
      vec_dim.push_back(dim);
      range_vector.push_back(std::make_pair(dim, dim));
    }
  }

  return true;
}

IMPLEMT_INFERFUNC(Expand, ExpandInferShape) {
  const char *op_name = "Expand";
  OP_LOGD(op_name, "ExpandInferShape start.");
  const vector<string> const_names = {"shape"};
  PREPARE_DYNAMIC_SHAPE(const_names);
  TensorDesc tensordesc_input = op.GetInputDescByName("x");
  Shape x_shape = tensordesc_input.GetShape();
  std::vector<int64_t> x_dims = x_shape.GetDims();
  DataType x_dtype = tensordesc_input.GetDataType();

  Tensor data;
  std::vector<int64_t> vec_dim;

  TensorDesc tensordesc_output = op.GetOutputDescByName("y");
  tensordesc_output.SetDataType(x_dtype);

  TensorDesc tensordesc_shape = op.GetInputDescByName("shape");
  size_t dim_num = tensordesc_shape.GetShape().GetDimNum();
  std::vector<int64_t> empty_dim_vec = tensordesc_shape.GetShape().GetDims();
  for (size_t i = 0; i < dim_num; i++) {
    if (empty_dim_vec[i] == 0) {
      tensordesc_output.SetShape(ge::Shape(empty_dim_vec));
      return op.UpdateOutputDesc("y", tensordesc_output);
    }
  }

  std::vector<std::pair<int64_t, int64_t>> range_vector;

  if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) {
    OP_LOGD(op_name, "Get constValue failed of [shape]");
    vector<int64_t> shape_dims = tensordesc_shape.GetShape().GetDims();
    size_t dim_num = shape_dims.size();

    if (dim_num > 1) {
      OP_LOGE(op_name, "The dim numbers of shape [%zu] are more than one.", dim_num);
      return GRAPH_FAILED;
    }
    int64_t max_len = x_dims.size();
    if (shape_dims[0] > max_len) {
      max_len = shape_dims[0];
    }
    for (int64_t item = 0; item < max_len; ++item) {
      vec_dim.push_back(-1);
      range_vector.push_back(std::make_pair(1, -1));
    }
  } else {
    OP_LOGD(op_name, "Get constValue succeeded of [shape]");
    vector<int64_t> shape_dims = tensordesc_shape.GetShape().GetDims();
    if (shape_dims.size() > 1) {
      OP_LOGE(op_name, "The dim numbers of shape [%zu] are more than one.", shape_dims.size());
      return GRAPH_FAILED;
    }
    DataType data_type = tensordesc_shape.GetDataType();
    if (data_type == DT_INT32) {
      if (!ExpandCalDim<int32_t>(data, vec_dim, x_dims, range_vector)) {
        OP_LOGE(op_name, "Data shape are not compatible!");
        return GRAPH_FAILED;
      }
    } else if (data_type == DT_INT64) {
      if (!ExpandCalDim<int64_t>(data, vec_dim, x_dims, range_vector)) {
        OP_LOGE(op_name, "Data shape are not compatible!");
        return GRAPH_FAILED;
      }
    } else {
      OP_LOGE(op_name, "Data type not supported!");
      return GRAPH_PARAM_INVALID;
    }
  }
  tensordesc_output.SetShape(ge::Shape(vec_dim));
  tensordesc_output.SetShapeRange(range_vector);
  (void)op.UpdateOutputDesc("y", tensordesc_output);
  OP_LOGD(op_name, "ExpandInferShape finish.");

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Expand, ExpandInferShape);
// ----------------Expand END---------------------
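`ExpandCalDim` above implements two-sided, NumPy-style broadcasting between the current shape of `x` and the requested `shape`: the shorter list is effectively left-padded, then each aligned pair must be equal or contain a 1, with -1 propagating as a dynamic dimension. A condensed sketch of that rule (hypothetical helper; zero-size dims and shape ranges are omitted for brevity):

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns false on a true mismatch, mirroring ExpandCalDim's failure case.
bool BroadcastShape(std::vector<int64_t> x, std::vector<int64_t> s, std::vector<int64_t> &out) {
  // Left-pad the shorter shape with 1s so both ranks match.
  while (x.size() < s.size()) x.insert(x.begin(), 1);
  while (s.size() < x.size()) s.insert(s.begin(), 1);
  out.clear();
  for (size_t i = 0; i < x.size(); ++i) {
    if (x[i] == -1 || s[i] == -1) { out.push_back(-1); continue; }  // dynamic dim
    if (x[i] != s[i] && x[i] != 1 && s[i] != 1) return false;       // incompatible
    out.push_back(std::max(x[i], s[i]));
  }
  return true;
}
// e.g. x = (3, 1), shape = (2, 3, 4)  ->  out = (2, 3, 4)
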

// ---------------SliceGrad-------------------
CUST_IMPLEMT_INFERFUNC(SliceGrad, SliceGradInfer) {
  TensorDesc x_desc = op.GetInputDescByName("x");

@ -33,4 +188,278 @@ CUST_IMPLEMT_INFERFUNC(SliceGrad, SliceGradInfer) {
|
|||
|
||||
CUST_INFER_FUNC_REG(SliceGrad, SliceGradInfer);
|
||||
// ---------------SliceGrad End---------------
|
||||
|
||||
// ---------------MaskedSelectGrad-------------------
|
||||
CUST_IMPLEMT_INFERFUNC(MaskedSelectGrad, MaskedSelectGradInfer) {
|
||||
TensorDesc x_desc = op.GetInputDescByName("x");
|
||||
if (op.UpdateOutputDesc("dx", x_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE("MaskedSelectGrad", "Update output desc failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_INFER_FUNC_REG(MaskedSelectGrad, MaskedSelectGradInfer);
|
||||
// ---------------MaskedSelectGrad End---------------
|
||||
|
||||
// -------------------------------IdentityN Begin-------------------------------
|
||||
// //
|
||||
IMPLEMT_INFERFUNC(IdentityN, IdentityNInfer) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
|
||||
for (size_t i = 0; i < op.GetInputsSize(); i++) {
|
||||
auto input_desc = op_desc->MutableInputDesc(i);
|
||||
auto input_dims = input_desc->MutableShape().GetDims();
|
||||
auto output_desc = op_desc->MutableOutputDesc(i);
|
||||
auto intput_dtype = input_desc->GetDataType();
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> input_range;
|
||||
input_desc->GetShapeRange(input_range);
|
||||
output_desc->SetShape(GeShape(input_dims));
|
||||
output_desc->SetOriginShape(GeShape(input_dims));
|
||||
output_desc->SetDataType(intput_dtype);
|
||||
output_desc->SetShapeRange(input_range);
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(IdentityN, IdentityNInfer);
|
||||
// -------------------------------IdentityN End-------------------------------
|
||||
// //
|
||||
|
||||
// -------------------------------LowerBound------------------------------- //
|
||||
IMPLEMT_INFERFUNC(LowerBound, LowerBoundInfer) {
|
||||
TensorDesc sorted_x_desc = op.GetInputDescByName("sorted_x");
|
||||
TensorDesc values_desc = op.GetInputDescByName("values");
|
||||
Shape unused_shape;
|
||||
if (WithRank(sorted_x_desc, 2, unused_shape, op) != GRAPH_SUCCESS) {
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
|
||||
TbeGetName(op),
|
||||
ConcatString("call WithRank failed, ", GetShapeErrMsg(0, DebugString(sorted_x_desc.GetShape().GetDims()), "2D")));
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (WithRank(values_desc, 2, unused_shape, op) != GRAPH_SUCCESS) {
|
||||
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
|
||||
TbeGetName(op),
|
||||
ConcatString("call WithRank failed, ", GetShapeErrMsg(1, DebugString(values_desc.GetShape().GetDims()), "2D")));
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
DataType out_type;
|
||||
if (op.GetAttr("out_type", out_type) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Get attr [out_type] failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
TensorDesc y_desc = op.GetOutputDescByName("y");
|
||||
y_desc.SetDataType(out_type);
|
||||
y_desc.SetShape(values_desc.GetShape());
|
||||
if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Update [y] desc failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(LowerBound, LowerBoundInfer);
|
||||
// -------------------------------LowerBound END-------------------------------
|
||||
// //
|
||||
|
// -------------------------------ListDiff------------------------------- //
IMPLEMT_INFERFUNC(ListDiff, ListDiffInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);
auto y_desc = op_desc->MutableInputDesc(1);

Shape unused_shape;
if (WithRank(x_desc, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string error_msg = GetShapeErrMsg(0, DebugString(x_desc->GetShape().GetDims()), "1D");
error_msg = string("failed to call WithRank function, ") + error_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), error_msg);
return GRAPH_FAILED;
}

if (WithRank(y_desc, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string error_msg = GetShapeErrMsg(1, DebugString(y_desc->GetShape().GetDims()), "1D");
error_msg = string("failed to call WithRank function, ") + error_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), error_msg);
return GRAPH_FAILED;
}

DataType output_type = x_desc->GetDataType();
DataType index_type;
if (op.GetAttr("out_idx", index_type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("failed to get attr[out_idx]."));
return GRAPH_FAILED;
}

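// Output lengths are data dependent, so both outputs are inferred as 1-D unknown-dim vectors.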
GeShape result({ge::UNKNOWN_DIM});
auto output_desc = op_desc->MutableOutputDesc(0);
output_desc->SetShape(GeShape(result));
output_desc->SetDataType(output_type);

auto index_desc = op_desc->MutableOutputDesc(1);
index_desc->SetShape(GeShape(result));
index_desc->SetDataType(index_type);

return GRAPH_SUCCESS;
}

INFER_FUNC_REG(ListDiff, ListDiffInfer);
// -------------------------------ListDiff END------------------------------- //

// ----------------HammingWindow Begin---------------------
IMPLEMT_COMMON_INFERFUNC(HammingWindowInferShape) {
std::vector<int64_t> input_dim = op.GetInputDesc(0).GetShape().GetDims();
if (input_dim.size() != 1) {
OP_LOGE(TbeGetName(op).c_str(), "Tensor length input must be 1D.");
return GRAPH_FAILED;
}

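// When "length" is not a graph-time constant, the output dim is left as UNKNOWN_DIM.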
Tensor length_tensor;
int64_t length_data;
if (op.GetInputConstData("length", length_tensor) == GRAPH_SUCCESS) {
if (length_tensor.GetTensorDesc().GetDataType() == DT_INT64) {
length_data = *reinterpret_cast<const int64_t *>(length_tensor.GetData());
} else {
length_data = static_cast<int64_t>(*reinterpret_cast<const int32_t *>(length_tensor.GetData()));
}
} else {
length_data = UNKNOWN_DIM;
}
std::vector<int64_t> output_dim;
if (length_data != UNKNOWN_DIM && length_data < 0) {
OP_LOGE(TbeGetName(op).c_str(), "Non-negative window length required, got [%ld].", length_data);
return GRAPH_FAILED;
}
if (length_data != 0) {
output_dim.push_back(length_data);
}
ge::Shape output_shape = ge::Shape(output_dim);

Operator::OpInt dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
dtype = 0;
}
DataType output_dtype = static_cast<DataType>(dtype);

TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetShape(output_shape);
output_desc.SetDataType(output_dtype);
op.UpdateOutputDesc("y", output_desc);

return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(HammingWindow, HammingWindowInferShape);
// ----------------HammingWindow End---------------------

// ----------------Mvlgamma Begin-------------------
CUST_IMPLEMT_INFERFUNC(Mvlgamma, MvlgammaInferShape) {
const char *op_name = "Mvlgamma";
OP_LOGD(op_name, "MvlgammaInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("x");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();

TensorDesc tensordesc_output1 = op.GetOutputDescByName("y");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));

(void)op.UpdateOutputDesc("y", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaInferShape end.");
return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(Mvlgamma, MvlgammaVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(Mvlgamma, MvlgammaInferShape);
CUST_VERIFY_FUNC_REG(Mvlgamma, MvlgammaVerify);
// ----------------Mvlgamma END---------------------

// ----------------MvlgammaGrad Begin-------------------
CUST_IMPLEMT_INFERFUNC(MvlgammaGrad, MvlgammaGradInferShape) {
const char *op_name = "MvlgammaGrad";
OP_LOGD(op_name, "MvlgammaGradInferShape begin.");
TensorDesc tensordesc_input = op.GetInputDescByName("y_grad");
Shape input_shape = tensordesc_input.GetShape();
std::vector<int64_t> dims_input = input_shape.GetDims();
DataType input_dtype = tensordesc_input.GetDataType();

TensorDesc tensordesc_output1 = op.GetOutputDescByName("x_grad");
tensordesc_output1.SetDataType(input_dtype);
tensordesc_output1.SetShape(ge::Shape(dims_input));

(void)op.UpdateOutputDesc("x_grad", tensordesc_output1);
OP_LOGD(op_name, "MvlgammaGradInferShape end.");
return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(MvlgammaGrad, MvlgammaGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(MvlgammaGrad, MvlgammaGradInferShape);
CUST_VERIFY_FUNC_REG(MvlgammaGrad, MvlgammaGradVerify);
// ----------------MvlgammaGrad END---------------------

// --------------------------LogSpace---------------------
static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
int64_t steps = 0;
int64_t steps_ori = 100;
if (ge::GRAPH_SUCCESS != op.GetAttr(attr_num_steps.c_str(), steps)) {
steps = steps_ori;
}
if (steps < 0) {
return false;
}
return true;
}

CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
AscendString opName;
op.GetName(opName);
if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input start must be 1-D.");
return GRAPH_FAILED;
}
if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
OP_LOGE(opName.GetString(), "Input end must be 1-D.");
return GRAPH_FAILED;
}
DataType input_type_start = op.GetInputDescByName("start").GetDataType();
DataType input_type_end = op.GetInputDescByName("end").GetDataType();
if (input_type_start != input_type_end) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
// Infer function that fills in the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogSpaceInferShape) {
AscendString opName1;
op.GetName(opName1);
TensorDesc v_output_desc = op.GetOutputDescByName("y");
int64_t steps = 100;
int64_t num_rows = 1;
op.GetAttr("steps", steps);
if (!CheckSteps(op, "steps")) {
OP_LOGE(opName1.GetString(), "The attr 'steps' should be greater than or equal to 0.");
return GRAPH_FAILED;
}
std::vector<int64_t> dim_vec;
dim_vec.push_back(num_rows);
dim_vec.push_back(steps);
v_output_desc.SetShape(ge::Shape(dim_vec));
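// Attr "dtype" selects the output element type: 0 maps to DT_FLOAT, 1 (the default) to DT_FLOAT16.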
int64_t dtype = 1;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
v_output_desc.SetDataType(DT_FLOAT16);
} else {
if (dtype == 1) {
v_output_desc.SetDataType(DT_FLOAT16);
}
if (dtype == 0) {
v_output_desc.SetDataType(DT_FLOAT);
}
}
(void)op.UpdateOutputDesc("y", v_output_desc);
return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------
} // namespace ge

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "inc/bartlett_window_op.h"
#include "custom_op_proto/cust_spectral_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

@@ -22,14 +22,14 @@
namespace ge {
// ---------BartlettWindow-------------------
IMPLEMT_COMMON_INFERFUNC(BartlettWindowInferShape) {
int dtype;
Operator::OpType dtype;
Format input_format = op.GetInputDescByName("window_length").GetFormat();
TensorDesc out_desc = op.GetOutputDescByName("y");
Shape unused;
std::string err_msg;

// Set output shape
if (WithRankAtMost(op.GetInputDescByName("window_length"), 1, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRankAtMost(op.GetInputDescByName("window_length"), 1, unused, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(op.GetInputDescByName("window_length").GetShape().GetDims()), "at most 1D");
err_msg = std::string("failed to call WithRankAtMost, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op).c_str(), err_msg);

@@ -39,19 +39,19 @@ IMPLEMT_COMMON_INFERFUNC(BartlettWindowInferShape) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor tensor;
Shape output_shape({UNKNOWN_DIM});
if (op.GetInputConstData("window_length", tensor) == GRAPH_SUCCESS) {
auto tensor_data = reinterpret_cast<int64_t *>(tensor.GetData());
if (*tensor_data == 0) {
std::vector<int64_t> dim_vector = {};
Shape output_shape(dim_vector);
out_desc.SetShape(output_shape);
} else {
std::vector<int64_t> dim_vector = {};
dim_vector.push_back(*tensor_data);
Shape output_shape(dim_vector);
out_desc.SetShape(output_shape);
int64_t length;
if (MakeDimForScalarInput(tensor, length, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (Vector(length, output_shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(),
string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}
}
out_desc.SetShape(output_shape);

// Set output dtype
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {

@@ -0,0 +1,66 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_const.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(Betainc, BetaincInfer) {
const int num_inputs = 3;
Shape output(UNKNOWN_RANK);
int num_scalars = 0;
Shape some_non_scalar;
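// Merge every non-scalar input shape into the output; scalars broadcast and are only counted here.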
for (int i = 0; i < num_inputs; ++i) {
TensorDesc input_desc = op.GetInputDesc(i);
Shape input_shape = input_desc.GetShape();
if (!RankKnown(input_shape)) {
some_non_scalar = input_shape;
} else if (input_shape.GetDimNum() == 0) {
++num_scalars;
} else {
if (Merge(output, input_shape, output, op) != GRAPH_SUCCESS) {
std::string err_msg =
ConcatString("failed to call Merge function to merge ", i, "th input shape ",
DebugString(input_shape.GetDims()), " and output[z] shape ", DebugString(output.GetDims()));
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
some_non_scalar = output;
}
}

if (num_scalars == num_inputs - 1) {
output = some_non_scalar;
} else if (num_scalars == num_inputs) {
TensorDesc a_desc = op.GetInputDescByName("a");
output = a_desc.GetShape();
}
DataType a_type = op.GetInputDescByName("a").GetDataType();
TensorDesc z_desc = op.GetOutputDescByName("z");
z_desc.SetShape(output);
z_desc.SetDataType(a_type);
if (op.UpdateOutputDesc("z", z_desc) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to update output[z] desc."));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Betainc, BetaincInfer);
} // namespace ge
@@ -0,0 +1,63 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(Bincount, BincountInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends({"size"});

GeShape unused;
auto size_desc = op_desc->MutableInputDesc(1);
if (WithRank(size_desc, 0, unused, op) != GRAPH_SUCCESS) {
std::string err_msg = GetShapeErrMsg(1, DebugString(size_desc->GetShape().GetDims()), "scalar");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}

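// The output is a 1-D vector whose length is the const value of input "size", or UNKNOWN_DIM when "size" is not const.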
Tensor tensor;
int64_t bins = 0;
if (op.GetInputConstData("size", tensor) != GRAPH_SUCCESS) {
bins = UNKNOWN_DIM;
}

if (bins != UNKNOWN_DIM) {
if (MakeDimForScalarInput(tensor, bins, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[size]."));
return GRAPH_FAILED;
}
}

Shape bins_shape;
if (Vector(bins, bins_shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}

auto bins_desc = op_desc->MutableOutputDesc(0);
bins_desc->SetShape(GeShape(bins_shape.GetDims()));
bins_desc->SetDataType(op_desc->MutableInputDesc(2)->GetDataType());

return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Bincount, BincountInfer);
} // namespace ge

@@ -0,0 +1,43 @@
/**
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/bitwise_ops.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "utils/util.h"
|
||||
|
||||
namespace ge {
|
||||
// -----------------------------LeftShift-----------------------------
|
||||
IMPLEMT_VERIFIER(LeftShift, LeftShiftVerify) {
|
||||
if (!CheckTwoInputDtypeSame(op, "x", "y")) {
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPLEMT_INFERFUNC(LeftShift, LeftShiftInferShape) {
|
||||
Shape data_shape = op.GetInputDescByName("x").GetShape();
|
||||
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
|
||||
TensorDesc td = op.GetOutputDescByName("z");
|
||||
td.SetShape(ge::Shape(data_shape));
|
||||
td.SetDataType(input_dtype);
|
||||
(void)op.UpdateOutputDesc("z", td);
|
||||
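// The final output shape is delegated to the common x/y broadcast rule.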
return BROADCAST_INFER("x", "y", "z")(op);
}

INFER_FUNC_REG(LeftShift, LeftShiftInferShape);
VERIFY_FUNC_REG(LeftShift, LeftShiftVerify);
// -----------------------------LeftShift END-----------------------------
} // namespace ge

@@ -26,29 +26,34 @@ CUST_IMPLEMT_VERIFIER(BlackmanWindow, BlackmanWindowVerify) { return GRAPH_SUCCE
IMPLEMT_COMMON_INFERFUNC(BlackmanWindowInferShape) {
Shape shape;
Shape unused;
Operator::OpType dtype;
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Get attr dtype failed.");
}

if (WithRank(op.GetInputDesc(0), 0, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (WithRank(op.GetInputDesc(0), 0, unused, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
std::vector<std::string> input_infer_depends = {"window_length"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);
Tensor window_length_tensor;
if (op.GetInputConstData("window_length", window_length_tensor) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
auto output_desc = op.GetOutputDescByName("y");
output_desc.SetDataType(dtype);
output_desc.SetShape(Shape(ge::UNKNOWN_SHAPE));
return op.UpdateOutputDesc("y", output_desc);
}
int64_t length;
if (MakeDimForScalarInput(window_length_tensor, length, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
if (MakeDimForScalarInput(window_length_tensor, length, op) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
if (Vector(length, shape) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), string("fail to gen vector shape according dim bins."));
return GRAPH_FAILED;
}
Operator::OpType type;
if (op.GetAttr("dtype", type) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("get attr[dtype] failed"));
return GRAPH_FAILED;
}
TensorDesc output_desc = op.GetOutputDescByName("y");
output_desc.SetDataType(type);
output_desc.SetDataType(dtype);
output_desc.SetShape(shape);
return op.UpdateOutputDesc("y", output_desc);
}

@@ -0,0 +1,84 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/pad_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
template <typename T>
static void CalcDims(const Tensor &data, std::vector<int64_t> &vec_dim) {
int32_t size = data.GetSize() / sizeof(T);
for (int32_t i = 0; i < size; i++) {
T dim = *(reinterpret_cast<const T *>(data.GetData()) + i);
vec_dim.push_back(dim);
}
}

// -------------------BroadcastTo-----------------------
IMPLEMT_INFERFUNC(BroadcastTo, BroadcastToInferShape) {
const vector<string> depend_names = {"shape"};
PREPARE_DYNAMIC_SHAPE(depend_names);
Tensor data;
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) {
OP_LOGI(TbeGetName(op).c_str(), "Failed to get const value of input [shape]");
auto shape_desc = op_info->MutableInputDesc("shape");
vector<int64_t> shapedims = shape_desc->MutableShape().GetDims();
size_t dim_num = shapedims.size();

DataType input_dtype = op.GetInputDescByName("x").GetDataType();

if (dim_num > 1) {
std::string err_msg = ConcatString("the rank[", dim_num, "] of input[shape] should not be more than 1");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}

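// The const value of "shape" is unavailable: the output rank is known from its length, so emit -1 dims with range [1, unbounded).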
std::vector<int64_t> shape_vector;
|
||||
std::vector<std::pair<int64_t, int64_t>> range_vector;
|
||||
for (int64_t item = 0; item < shapedims[0]; ++item) {
|
||||
shape_vector.push_back(-1);
|
||||
range_vector.push_back(std::make_pair(1, -1));
|
||||
}
|
||||
auto output_desc = op_info->MutableOutputDesc("y");
|
||||
output_desc->SetShape(GeShape(shape_vector));
|
||||
output_desc->SetShapeRange(range_vector);
|
||||
output_desc->SetDataType(input_dtype);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
DataType data_type = data.GetTensorDesc().GetDataType();
|
||||
std::vector<int64_t> vec_dim;
|
||||
if (data_type == DT_INT32) {
|
||||
CaclDims<int32_t>(data, vec_dim);
|
||||
} else if (data_type == DT_INT64) {
|
||||
CaclDims<int64_t>(data, vec_dim);
|
||||
} else {
|
||||
return GRAPH_PARAM_INVALID;
|
||||
}
|
||||
OP_LOGI(TbeGetName(op).c_str(), "the op infer shape and dtype");
|
||||
DataType input_dtype = op.GetInputDescByName("x").GetDataType();
|
||||
|
||||
auto output_desc = op_info->MutableOutputDesc("y");
|
||||
output_desc->SetShape(GeShape(vec_dim));
|
||||
output_desc->SetDataType(input_dtype);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(BroadcastTo, BroadcastToInferShape);
|
||||
// --------------------BroadcastTo END-----------------------
|
||||
} // namespace ge
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "inc/ops/math_ops.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "utils/util.h"
|
||||
|
||||
namespace ge {
|
||||
// ---------Bucketize------------------
|
||||
IMPLEMT_INFERFUNC(Bucketize, BucketizeInfer) {
|
||||
std::string str_name = TbeGetName(op);
|
||||
const char *opname = str_name.c_str();
|
||||
OP_LOGD(opname, "Enter Bucketize inferfunction!");
|
||||
|
||||
// set output shape
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
|
||||
CHECK(op_desc == nullptr, OP_LOGE(opname, "op desc is null."), return GRAPH_FAILED);
|
||||
std::vector<int64_t> x_shape = op_desc->MutableInputDesc("x")->MutableShape().GetDims();
|
||||
auto output_desc = op_desc->MutableOutputDesc(0);
|
||||
output_desc->SetShape(GeShape(x_shape));
|
||||
|
||||
// set output dtype
|
||||
DataType dtype;
|
||||
if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
|
||||
std::string err_msg = GetInputInvalidErrMsg("dtype");
|
||||
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[dtype] failed"));
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if ((dtype != DT_INT32) && (dtype != DT_INT64)) {
|
||||
std::string err_msg = GetInputInvalidErrMsg("dtype");
|
||||
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("The attr [dtype] must be one of DT_INT32 or DT_INT64"));
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
output_desc->SetDataType(dtype);
|
||||
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPLEMT_VERIFIER(Bucketize, BucketizeVerify) { return GRAPH_SUCCESS; }
|
||||
INFER_FUNC_REG(Bucketize, BucketizeInfer);
|
||||
VERIFY_FUNC_REG(Bucketize, BucketizeVerify);
|
||||
// ---------Bucketize End-------------------
|
||||
} // namespace ge
|
|
@ -14,30 +14,32 @@
|
|||
* limitations under the License.
*/

#include "inc/log_normal_reverse.h"
#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"

namespace ge {
// ----------------LogNormalReverse-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogNormalReverseInferShape) {
TensorDesc v_output_desc = op.GetOutputDescByName("y");
IMPLEMT_INFERFUNC(CholeskyGrad, CholeskyGradInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);

DataType input_dtype = op.GetInputDescByName("x").GetDataType();
Format input_format = op.GetInputDescByName("x").GetFormat();
ge::Shape shape_input = op.GetInputDescByName("x").GetShape();

v_output_desc.SetShape(shape_input);
v_output_desc.SetDataType(input_dtype);
v_output_desc.SetFormat(input_format);

if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(),
"Op CholeskyGrad first input x tensor make batch square matrix "
"failed.");
return GRAPH_FAILED;
}

DataType type = x_desc->GetDataType();
auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);

return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LogNormalReverse, LogNormalReverseInferShape);
// ----------------LogNormalReverse-------------------
INFER_FUNC_REG(CholeskyGrad, CholeskyGradInfer);
} // namespace ge

@@ -22,8 +22,7 @@ namespace ge {
// -----------------------CholeskyInverse---------------------
IMPLEMT_COMMON_INFERFUNC(CholeskyInverseInferShape) {
TensorDesc out_desc = op.GetOutputDescByName("x");
out_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
if (op.UpdateOutputDesc("x", out_desc) != GRAPH_SUCCESS) {
if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;

@@ -0,0 +1,43 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/linalg_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/linalg_ops_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(Cholesky, CholeskyInfer) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto x_desc = op_desc->MutableInputDesc(0);

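// MakeBatchSquareMatrix is expected to check that the two innermost dims of x match and to produce the batched square output shape (see utils/linalg_ops_shape_fns).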
GeShape y_shape;
if (MakeBatchSquareMatrix(x_desc, y_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Op Cholesky first input x's tensor make batch square matrix failed.");
return GRAPH_FAILED;
}
DataType type = x_desc->GetDataType();

auto y_desc = op_desc->MutableOutputDesc(0);
y_desc->SetShape(y_shape);
y_desc->SetDataType(type);

return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Cholesky, CholeskyInfer);
} // namespace ge

@@ -0,0 +1,199 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/op_log.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(CombinedNonMaxSuppression, CombinedNonMaxSuppressionInfer) {
DYNAMIC_SHAPE_NOT_SUPPORTED(op);
Shape boxes;
Shape scores;
Shape max_output_size_per_class;
Shape max_total_size;
Shape unused_shape;

std::vector<std::string> input_infer_depends = {"max_total_size", "max_output_size_per_class"};
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
op_desc->SetOpInferDepends(input_infer_depends);

if (WithRank(op.GetInputDesc(0), 4, boxes, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(1), 3, scores, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
GetShapeErrMsg(1, DebugString(op.GetInputDesc(1).GetShape().GetDims()), "3D"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(2), 0, max_output_size_per_class, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(2, DebugString(op.GetInputDesc(2).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(3), 0, max_total_size, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(3, DebugString(op.GetInputDesc(3).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(4), 0, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(4, DebugString(op.GetInputDesc(4).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}
if (WithRank(op.GetInputDesc(5), 0, unused_shape, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), GetShapeErrMsg(5, DebugString(op.GetInputDesc(5).GetShape().GetDims()), "scalar"));
return GRAPH_FAILED;
}

int64_t unused = 0;
int64_t dim1 = boxes.GetDim(0);
int64_t dim2 = scores.GetDim(0);
if (Merge(dim1, dim2, unused) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
ConcatString("call Merge function failed to merge 0th dim of input[boxes]"
" and input[scores], ",
dim1, " and ", dim2));
return GRAPH_FAILED;
}
int64_t dim3 = boxes.GetDim(1);
int64_t dim4 = scores.GetDim(1);
if (Merge(dim3, dim4, unused) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
ConcatString("call Merge function failed to merge 1st dim of input[boxes]"
" and input[scores], ",
dim3, " and ", dim4));
return GRAPH_FAILED;
}

if (boxes.GetDim(3) != 4) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
ConcatString("invalid 3rd dim value[", boxes.GetDim(3), "], it should be 4"));
return GRAPH_FAILED;
}

Shape boxes_shape = op.GetInputDesc(0).GetShape();
Shape scores_shape = op.GetInputDesc(1).GetShape();
if (ValueKnown(boxes_shape, 2) && ValueKnown(scores_shape, 2)) {
if (boxes_shape.GetDim(2) != 1 && boxes_shape.GetDim(2) != scores_shape.GetDim(2)) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), ConcatString("2nd dim of input[boxes] and input[scores] are not equal, ", boxes_shape.GetDim(2),
" and ", scores_shape.GetDim(2)));
return GRAPH_FAILED;
}
}

Tensor maxTotalSizeTensor;
Tensor maxOutputSizePerClassTensor;
if ((op.GetInputConstData("max_total_size", maxTotalSizeTensor) != GRAPH_SUCCESS) ||
(op.GetInputConstData("max_output_size_per_class", maxOutputSizePerClassTensor) != GRAPH_SUCCESS)) {
Shape out_shape0({-1, -1, 4});
Shape out_shape1({-1, -1});
Shape out_shape2({-1, -1});
Shape out_shape3({-1});
TensorDesc out_desc0 = op.GetOutputDesc(0);
out_desc0.SetShape(out_shape0);
(void)op.UpdateOutputDesc("nmsed_boxes", out_desc0);
TensorDesc out_desc1 = op.GetOutputDesc(1);
out_desc1.SetShape(out_shape1);
(void)op.UpdateOutputDesc("nmsed_scores", out_desc1);
TensorDesc out_desc2 = op.GetOutputDesc(2);
out_desc2.SetShape(out_shape2);
(void)op.UpdateOutputDesc("nmsed_classes", out_desc2);
TensorDesc out_desc3 = op.GetOutputDesc(3);
out_desc3.SetShape(out_shape3);
(void)op.UpdateOutputDesc("valid_detections", out_desc3);
return GRAPH_SUCCESS;
}
int64_t maxTotalSize;
if (MakeDimForScalarInput(maxTotalSizeTensor, maxTotalSize, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op), ConcatString("call MakeDimForScalarInput failed to get value from input[max_total_size] tensor"));
return GRAPH_FAILED;
}
if (maxTotalSize <= 0) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op), ConcatString("invalid value[", maxTotalSize, "] of input[max_total_size], should be > 0"));
return GRAPH_FAILED;
}

int64_t maxOutputSizePerClass;
if (MakeDimForScalarInput(maxOutputSizePerClassTensor, maxOutputSizePerClass, op) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(
TbeGetName(op),
ConcatString("call MakeDimForScalarInput failed to get value from input[max_output_size_per_class] tensor"));
return GRAPH_FAILED;
}

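// With pad_per_class set, detections per batch are capped at min(max_total_size, max_output_size_per_class * num_classes).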
int64_t output_size;
bool pad_per_class;
if (op.GetAttr("pad_per_class", pad_per_class) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("get attr[pad_per_class] failed"));
return GRAPH_FAILED;
}
if (!pad_per_class) {
output_size = maxTotalSize;
} else {
if (maxOutputSizePerClass <= 0) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(
TbeGetName(op),
ConcatString("invalid value[", maxOutputSizePerClass, "] of input[max_output_size_per_class], should be > 0"));
return GRAPH_FAILED;
}
if (maxTotalSize <= maxOutputSizePerClass * scores_shape.GetDim(2)) {
output_size = maxTotalSize;
} else {
output_size = maxOutputSizePerClass * scores_shape.GetDim(2);
}
}

int64_t batch_dim = boxes.GetDim(0);
Shape shape1({batch_dim, output_size, 4});
Shape shape2({batch_dim, output_size});
Shape shape3({batch_dim, output_size});
Shape shape4({batch_dim});

TensorDesc desc1 = op.GetOutputDescByName("nmsed_boxes");
desc1.SetShape(shape1);
desc1.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_boxes", desc1) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_boxes] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc2 = op.GetOutputDescByName("nmsed_scores");
desc2.SetShape(shape2);
desc2.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_scores", desc2) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_scores] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc3 = op.GetOutputDescByName("nmsed_classes");
desc3.SetShape(shape3);
desc3.SetDataType(DT_FLOAT);
if (op.UpdateOutputDesc("nmsed_classes", desc3) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), std::string("update output[nmsed_classes] desc failed"));
return GRAPH_FAILED;
}
TensorDesc desc4 = op.GetOutputDescByName("valid_detections");
desc4.SetShape(shape4);
desc4.SetDataType(DT_INT32);
if (op.UpdateOutputDesc("valid_detections", desc4) != GRAPH_SUCCESS) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(),
std::string("update output[valid_detections] desc failed"));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

INFER_FUNC_REG(CombinedNonMaxSuppression, CombinedNonMaxSuppressionInfer);
} // namespace ge

@@ -0,0 +1,418 @@
/**
* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/split_combination_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "utils/op_common_util.h"
#include "register/infer_axis_slice_registry.h"
#include "utils/op_const.h"

namespace ge {
static void JoinShapeRanges(vector<pair<int64_t, int64_t>> &dest_ranges,
const vector<pair<int64_t, int64_t>> &src_ranges) {
auto dest_size = dest_ranges.size();
auto src_size = src_ranges.size();
if (dest_size != src_size) {
return;
}

for (size_t i = 0; i < dest_size; i++) {
dest_ranges[i].first = std::max(dest_ranges[i].first, src_ranges[i].first);
dest_ranges[i].second = std::min(dest_ranges[i].second, src_ranges[i].second);
}
}

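// With an unknown concat_dim every output dim may grow, so upper bounds are summed across inputs (-1 stays unbounded).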
static vector<pair<int64_t, int64_t>> GetShapeRangesWithUnKnowConcatDim(Operator &op, int64_t num_concat) {
|
||||
vector<pair<int64_t, int64_t>> input_shape_ranges;
|
||||
vector<vector<pair<int64_t, int64_t>>> all_input_shape_ranges;
|
||||
vector<pair<int64_t, int64_t>> output_shape_ranges;
|
||||
bool has_shape_ranges = false;
|
||||
for (int32_t i = 0; i < num_concat; i++) {
|
||||
const auto input_desc = op.GetDynamicInputDesc("x", i);
|
||||
(void)input_desc.GetShapeRange(input_shape_ranges);
|
||||
OP_LOGD(TbeGetName(op).c_str(), "input shape range:%s", to_string(input_shape_ranges).c_str());
|
||||
if (input_shape_ranges.empty()) {
|
||||
auto shape_dims = input_desc.GetShape().GetDims();
|
||||
MakeUpShapeRange(shape_dims, input_shape_ranges);
|
||||
} else {
|
||||
has_shape_ranges = true;
|
||||
}
|
||||
|
||||
all_input_shape_ranges.push_back(input_shape_ranges);
|
||||
}
|
||||
|
||||
if (has_shape_ranges) {
|
||||
output_shape_ranges = all_input_shape_ranges[0];
|
||||
for (size_t i = 1; i < all_input_shape_ranges.size(); i++) {
|
||||
if (output_shape_ranges.size() != all_input_shape_ranges[i].size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < output_shape_ranges.size(); j++) {
|
||||
output_shape_ranges[j].first = std::max(output_shape_ranges[j].first, all_input_shape_ranges[i][j].first);
|
||||
if (output_shape_ranges[j].second == -1 || all_input_shape_ranges[i][j].second == -1) {
|
||||
output_shape_ranges[j].second = -1;
|
||||
} else {
|
||||
output_shape_ranges[j].second = output_shape_ranges[j].second + all_input_shape_ranges[i][j].second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return output_shape_ranges;
|
||||
}
|
||||
|
||||
bool JoinShapes(vector<int64_t> &dst_shape, const vector<int64_t> &src_shape, int64_t axis) {
|
||||
if (dst_shape == src_shape) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (dst_shape.empty() || IsUnknownRankShape(dst_shape)) {
|
||||
dst_shape = src_shape;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsUnknownRankShape(src_shape)) {
|
||||
if (dst_shape.size() != src_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
auto shape_dims = dst_shape.size();
|
||||
for (size_t i = 0; i < shape_dims; i++) {
|
||||
if (dst_shape[i] == src_shape[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (axis != static_cast<int64_t>(i) && dst_shape[i] != UNKNOWN_DIM && src_shape[i] != UNKNOWN_DIM) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_shape[i] != UNKNOWN_DIM) {
|
||||
dst_shape[i] = src_shape[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConcatInferShapeCommonStatic(Operator &op, const int64_t dynamic_input_start_idx, int64_t num_concat,
|
||||
int64_t axis) {
|
||||
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
|
||||
auto input_desc = op_info->MutableInputDesc(dynamic_input_start_idx);
|
||||
auto output_desc = op_info->MutableOutputDesc(0);
|
||||
const GeShape &input_shape = input_desc->MutableShape();
|
||||
GeShape &output_shape = output_desc->MutableShape();
|
||||
output_shape = input_shape;
|
||||
if (output_shape.IsUnknownShape() || num_concat == 1) {
|
||||
// dynamic case or the input only one will use dynamic infer func
|
||||
return false;
|
||||
}
|
||||
|
||||
if (output_shape.IsScalar()) {
|
||||
// scalar to shape [1]
|
||||
output_shape.SetDimNum(1);
|
||||
output_shape.SetDim(0, 1);
|
||||
}
|
||||
size_t output_dim = output_shape.GetDimNum();
|
||||
if ((axis < -static_cast<int64_t>(output_dim)) || (axis >= static_cast<int64_t>(output_dim))) {
|
||||
// axes is valid
|
||||
return false;
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += output_dim;
|
||||
}
|
||||
int64_t concat_dim_size = output_shape.GetDim(axis);
|
||||
|
||||
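// Accumulate the concat-dim size over the remaining inputs; any rank or dim mismatch falls back to the dynamic infer path.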
for (int64_t input_idx = 1; input_idx < num_concat; input_idx++) {
auto input_i_desc = op_info->MutableInputDesc(input_idx + dynamic_input_start_idx);
const GeShape &input_i_shape = input_i_desc->MutableShape();
if (input_i_shape.IsScalar() && output_dim == 1) {
concat_dim_size += 1;
continue;
}
if (input_i_shape.IsUnknownShape()) {
// dynamic case
return false;
}
if (input_i_shape.GetDimNum() != output_dim) {
// input rank does not match the output rank
return false;
}
// check whether the non concat dims are equal
for (int64_t check_dim = 0; check_dim < static_cast<int64_t>(output_dim); check_dim++) {
if (check_dim != axis && input_i_shape.GetDim(check_dim) != output_shape.GetDim(check_dim)) {
return false;
}
}
concat_dim_size += input_i_shape.GetDim(axis);
}
output_shape.SetDim(axis, concat_dim_size);

// set data type
output_desc->SetDataType(input_desc->GetDataType());
return true;
}

static graphStatus ConcatInferShapeCommon(Operator &op, const int64_t dy_input_start_idx, int64_t num_concat,
int64_t axis, bool unknown_axis) {
if (num_concat <= 0) {
std::string err_msg = GetAttrValueErrMsg("num_concat", std::to_string(num_concat), ConcatString("num_concat > 0"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}

// try static infershape directly
if (!unknown_axis) {
if (ConcatInferShapeCommonStatic(op, dy_input_start_idx, num_concat, axis)) {
return GRAPH_SUCCESS;
}
}
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
size_t dim_num = 0;
std::vector<GeTensorDescPtr> input_x_desc;
const string input_name = "x";
string input_name_i = "x63";
for (int64_t input_idx = 0; input_idx < num_concat; input_idx++) {
input_name_i = input_name + std::to_string(input_idx);
auto input_desc = op_info->MutableInputDesc(input_name_i);
if (!input_desc) {
std::string err_msg = GetInputInvalidErrMsg(input_name_i.c_str());
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}

input_x_desc.emplace_back(input_desc);
}

bool all_unknown_rank_shape = true;
for (const auto &desc : input_x_desc) {
dim_num = std::max(dim_num, desc->MutableShape().GetDimNum());
all_unknown_rank_shape = IsUnknownRankShape(desc->MutableShape()) && all_unknown_rank_shape;
}

if (all_unknown_rank_shape) {
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
output_desc->SetShape(ge::GeShape(UNKNOWN_RANK));
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());
return GRAPH_SUCCESS;
}

if (unknown_axis) {
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
vector<int64_t> dimVector(dim_num, -1);
output_desc->SetShape(ge::GeShape(dimVector));
auto output_shape_ranges = GetShapeRangesWithUnknownConcatDim(op, num_concat);
if (!output_shape_ranges.empty()) {
output_desc->SetShapeRange(output_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "output shape range:%s", to_string(output_shape_ranges).c_str());
}
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());
return GRAPH_SUCCESS;
}

if ((axis < -static_cast<int64_t>(dim_num)) || (axis >= static_cast<int64_t>(dim_num))) {
OP_LOGE(TbeGetName(op).c_str(), "the parameter [axis] should be in the range of [%ld, %ld), but actually is %ld",
-static_cast<int64_t>(dim_num), static_cast<int64_t>(dim_num), axis);
return GRAPH_FAILED;
}

int64_t non_negative_axis = axis;
if (non_negative_axis < 0) {
non_negative_axis += dim_num;
}

vector<int64_t> output_shape_dims;
for (const auto &desc : input_x_desc) {
auto input_shape_dims = desc->MutableShape().GetDims();
if (!JoinShapes(output_shape_dims, input_shape_dims, non_negative_axis)) {
vector<vector<int64_t>> shapes = {output_shape_dims, input_shape_dims};
std::string err_msg =
OtherErrMsg(ConcatString("the input shape dims should be equal except the merge axis, "
"shapes:",
ops::to_string(shapes), ", axis:", std::to_string(axis)));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
}

int32_t size = 0;
for (const auto &desc : input_x_desc) {
if (IsUnknownRankShape(desc->MutableShape())) {
size = -1;
break;
}

auto dim_value = desc->MutableShape().GetDim(non_negative_axis);
if (dim_value == -1) {
size = -1;
break;
}

if (size != -1) {
size += dim_value;
}
}

output_shape_dims[non_negative_axis] = size;
DataType input_dtype = input_x_desc[0]->GetDataType();
auto output_desc = op_info->MutableOutputDesc(0);
output_desc->SetDataType(input_dtype);
output_desc->SetShape(ge::GeShape(output_shape_dims));
OP_LOGD(TbeGetName(op).c_str(), "output shape:%s", to_string(output_desc->GetShape()).c_str());

if (IsUnKnownShape(output_shape_dims)) {
vector<pair<int64_t, int64_t>> input_shape_ranges;
vector<pair<int64_t, int64_t>> output_shape_ranges;
pair<int64_t, int64_t> output_concat_dim_range(0, 0);
for (const auto &input_desc : input_x_desc) {
if (IsUnknownRankShape(input_desc->MutableShape())) {
output_concat_dim_range = {0, -1};
continue;
}

input_shape_ranges.clear();
input_desc->GetShapeRange(input_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "input shape range:%s", to_string(input_shape_ranges).c_str());
if (input_shape_ranges.empty()) {
MakeUpShapeRange(input_desc->MutableShape(), input_shape_ranges);
}

if (static_cast<int64_t>(input_shape_ranges.size()) > non_negative_axis) {
output_concat_dim_range.first += input_shape_ranges[non_negative_axis].first;
if (input_shape_ranges[non_negative_axis].second == -1 || output_concat_dim_range.second == -1) {
output_concat_dim_range.second = -1;
} else {
output_concat_dim_range.second += input_shape_ranges[non_negative_axis].second;
}
}

if (output_shape_ranges.empty()) {
output_shape_ranges = input_shape_ranges;
} else {
JoinShapeRanges(output_shape_ranges, input_shape_ranges);
}
}

if (output_concat_dim_range.second != 0 && static_cast<uint64_t>(non_negative_axis) < output_shape_ranges.size()) {
output_shape_ranges[non_negative_axis] = output_concat_dim_range;
}

output_desc->SetShapeRange(output_shape_ranges);
OP_LOGD(TbeGetName(op).c_str(), "output shape range:%s", to_string(output_shape_ranges).c_str());
}

return GRAPH_SUCCESS;
}

static graphStatus ConcatInputsVerify(const Operator &op) {
std::vector<std::string> inputs;
const string input_name = "x";
string input_name_i = "x63";
int64_t N;
if (op.GetAttr("N", N) == GRAPH_FAILED) {
OP_LOGE(TbeGetName(op), "get attr N failed");
return GRAPH_FAILED;
}
for (int64_t input_idx = 0; input_idx < N; input_idx++) {
input_name_i = input_name + std::to_string(input_idx);
inputs.emplace_back(input_name_i);
}

if (!CheckInputDtypeSame(op, inputs)) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

// ----------------Concat OP Begin-------------------
IMPLEMT_VERIFIER(Concat, ConcatVerify) { return ConcatInputsVerify(op); }

IMPLEMT_COMMON_INFERFUNC(ConcatInferShape) {
const vector<string> depend_names = {"concat_dim"};
PREPARE_DYNAMIC_SHAPE(depend_names);

int64_t N;
if (op.GetAttr("N", N) == GRAPH_FAILED) {
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("get attr[N] failed"));
return GRAPH_FAILED;
}
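// concat_dim arrives as an input tensor rather than an attr: read it as const data when available, else infer dynamically.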
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
|
||||
auto shape_idx = static_cast<uint32_t>(op_desc->GetInputIndexByName("concat_dim"));
|
||||
const GeTensor *data = OpDescUtils::GetInputConstData(op, shape_idx);
|
||||
bool is_unknown_axis = data == nullptr;
|
||||
OP_LOGD(TbeGetName(op), "concat_dim is unknown[%s].", is_unknown_axis ? "true" : "false");
|
||||
int64_t axis = 0;
|
||||
if (!is_unknown_axis) {
|
||||
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
|
||||
DataType dtype = op_info->MutableInputDesc(0)->GetDataType();
|
||||
std::vector<int64_t> const_vec;
|
||||
if (!GetConstValue(op, data, dtype, const_vec)) {
|
||||
is_unknown_axis = true;
|
||||
OP_LOGW(TbeGetName(op), "Get concat_dim value failed.");
|
||||
} else {
|
||||
axis = const_vec[0];
|
||||
}
|
||||
}
|
||||
|
||||
return ConcatInferShapeCommon(op, 1, N, axis, is_unknown_axis);
|
||||
}
|
||||
|
||||
IMPLEMT_COMMON_INFER_AXIS_TYPE_INFO(InferAxisType4Concat) {
|
||||
OP_LOGD(TbeGetName(op), "Infer axis type for %s begin.", TbeGetOpType(op).c_str());
|
||||
std::vector<int64_t> allowed_split_inputs = {};
|
||||
std::vector<int64_t> allowed_split_outputs = {0};
|
||||
std::vector<int64_t> excepted_axes = {};
|
||||
bool concat_dim_done = true;
|
||||
static const int64_t concat_dim_input_index = 0;
|
||||
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
|
||||
CHECK(op_desc == nullptr,
|
||||
INFER_AXIS_TYPE_ERR_REPORT(TbeGetName(op), "Failed to get desc from operator. Please check the node info."),
|
||||
return GRAPH_FAILED);
|
||||
vector<string> depends = op_desc->GetOpInferDepends();
|
||||
if (find(depends.begin(), depends.end(), "concat_dim") == depends.end()) {
|
||||
depends.emplace_back("concat_dim");
|
||||
op_desc->SetOpInferDepends(depends);
|
||||
}
|
||||
if (!(ops::GetConstIntData(op, concat_dim_input_index, excepted_axes))) {
|
||||
concat_dim_done = false;
|
||||
OP_LOGD(TbeGetName(op), "Get const concat_dim value failed. Can not infer axis type.");
|
||||
}
|
||||
CHECK(!concat_dim_done, OP_LOGD(TbeGetName(op), "Concat dim is not const node. Can not infer axis type."),
|
||||
return GRAPH_SUCCESS);
|
||||
|
||||
int64_t attr_n;
|
||||
CHECK(op.GetAttr("N", attr_n) == GRAPH_FAILED, OP_LOGD(TbeGetName(op), "Get attr N failed. Can not infer axis type."),
|
||||
return GRAPH_SUCCESS);
|
||||
for (auto i = 0; i < attr_n; ++i) {
|
||||
allowed_split_inputs.emplace_back(i + 1);
|
||||
}
|
||||
|
||||
return InferElementwiseAxisTypeHelper(op, axis_type, allowed_split_inputs, allowed_split_outputs, excepted_axes);
|
||||
}
|
||||
|
||||
COMMON_INFER_FUNC_REG(Concat, ConcatInferShape);
|
||||
VERIFY_FUNC_REG(Concat, ConcatVerify);
|
||||
INFER_VALUE_RANGE_DEFAULT_REG(Concat);
|
||||
INFER_AXIS_TYPE_INFO_REG(Concat, InferAxisType4Concat);
|
||||
// ----------------Concat OP End-------------------
|
||||
} // namespace ge
|
|
@ -25,41 +25,41 @@ CUST_IMPLEMT_INFERFUNC(CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTen

   GeShape x_dense_shape_shape;
   auto x_dense_shape_desc = op_desc->MutableInputDesc(0);
-  if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, op) != GRAPH_SUCCESS) {
     OP_LOGE(TbeGetName(op).c_str(), "Input x_dense_shape must be 1-D.");
     return GRAPH_FAILED;
   }

   GeShape x_batch_pointers_shape;
   auto x_batch_pointers_desc = op_desc->MutableInputDesc(1);
-  if (WithRank(x_batch_pointers_desc, 1, x_batch_pointers_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (WithRank(x_batch_pointers_desc, 1, x_batch_pointers_shape, op) != GRAPH_SUCCESS) {
     OP_LOGE(TbeGetName(op).c_str(), "Input x_batch_pointers must be 1-D.");
     return GRAPH_FAILED;
   }

   GeShape x_row_pointers_shape;
   auto x_row_pointers_desc = op_desc->MutableInputDesc(2);
-  if (WithRank(x_row_pointers_desc, 1, x_row_pointers_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (WithRank(x_row_pointers_desc, 1, x_row_pointers_shape, op) != GRAPH_SUCCESS) {
     OP_LOGE(TbeGetName(op).c_str(), "Input x_row_pointers must be 1-D.");
     return GRAPH_FAILED;
   }

   GeShape x_col_indices_shape;
   auto x_col_indices_desc = op_desc->MutableInputDesc(3);
-  if (WithRank(x_col_indices_desc, 1, x_col_indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (WithRank(x_col_indices_desc, 1, x_col_indices_shape, op) != GRAPH_SUCCESS) {
     OP_LOGE(TbeGetName(op).c_str(), "Input x_col_indices must be 1-D.");
     return GRAPH_FAILED;
   }

   GeShape x_values_shape;
   auto x_values_desc = op_desc->MutableInputDesc(4);
-  if (WithRank(x_values_desc, 1, x_values_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (WithRank(x_values_desc, 1, x_values_shape, op) != GRAPH_SUCCESS) {
     OP_LOGE(TbeGetName(op).c_str(), "Input x_values must be 1-D.");
     return GRAPH_FAILED;
   }

   GeShape unused;
-  if (Merge(x_col_indices_shape, x_values_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+  if (Merge(x_col_indices_shape, x_values_shape, unused, op) != GRAPH_SUCCESS) {
     return GRAPH_FAILED;
   }

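The only change in this hunk is the last argument of `WithRank` and `Merge`: the helpers now take the `Operator` itself instead of a pre-formatted name string, presumably so they can derive the op name (and richer error context) on demand. A self-contained sketch of that overload migration, with mock types standing in for the GE ones:

#include <cstdio>
#include <string>

// Standalone sketch of the overload migration seen in this hunk: a rank-check
// helper that once took a pre-formatted name string and now takes the object
// itself, so the helper can pull the name on demand.
struct MockOp {
  std::string name;
};

bool WithRankSketch(int rank, int expected, const char *op_name) {
  if (rank != expected) {
    std::printf("[%s] rank mismatch\n", op_name);
    return false;
  }
  return true;
}

bool WithRankSketch(int rank, int expected, const MockOp &op) {
  return WithRankSketch(rank, expected, op.name.c_str());
}

int main() {
  MockOp op{"CSRSparseMatrixToSparseTensor"};
  return WithRankSketch(1, 1, op) ? 0 : 1;
}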
@ -0,0 +1,30 @@
/**
 * Copyright (c) 2022-2023 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/selection_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------Cumprod-------------------
IMPLEMT_COMMON_INFERFUNC(CumprodInferShape) {
  TensorDesc desc = op.GetInputDescByName("x");
  return op.UpdateOutputDesc("y", desc);
}

COMMON_INFER_FUNC_REG(Cumprod, CumprodInferShape);
// ----------------Cumprod END-------------------
} // namespace ge
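Shape inference for Cumprod is a pure pass-through: `y` copies `x`'s descriptor, because a cumulative product never changes shape or dtype. For reference, the kernel-level semantics the op name implies (not part of this proto file):

#include <cstdio>
#include <vector>

// What the kernel itself computes (shape inference above only forwards the
// descriptor): a running product along an axis, here for a 1-D vector.
int main() {
  std::vector<double> x = {1, 2, 3, 4};
  std::vector<double> y(x.size());
  double acc = 1.0;
  for (size_t i = 0; i < x.size(); ++i) {
    acc *= x[i];
    y[i] = acc;  // y = {1, 2, 6, 24}
  }
  std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);
  return 0;
}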
@ -0,0 +1,402 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/elewise_calculation_ops.h"
#include "custom_op_proto/cust_math_ops.h"

#include "register/infer_axis_slice_registry.h"
#include <string>
#include <vector>
#include "utils/op_attr.h"
#include "utils/op_log.h"
#include "utils/op_const.h"
#include "utils/util.h"
#include "utils/error_util.h"
#include "utils/reduce_infer_util.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/node_utils_ex.h"
#include "graph/utils/op_desc_utils.h"
#include "register/infer_data_slice_registry.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/axis_type_info.h"

namespace ge {
IMPLEMT_COMMON_INFERFUNC(TwoInOneOutCommonInferShape) {
  bool is_dynamic_output = true;
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, 0, 1, 0, is_dynamic_output)) {
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(OneInOneOutCommonInferShape) {
  static const int64_t input_x_idx = 0;
  static const int64_t output_y_idx = 0;
  if (OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}

// ----------------------------------OneInOneOutCommonInfer-----------------------------
COMMON_INFER_FUNC_REG(CheckNumerics, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Conj, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Cos, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Expm1, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log1p, OneInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(Log, OneInOneOutCommonInferShape);
// ----------------------------------OneInOneOutCommonInfer END-----------------------------

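`InferShapeAndTypeTwoInOneOutBroadcast` and `OneInOneOutDynamicInfer` are shared helpers from `utils/util.h`; the former presumably merges the two input shapes under the usual NumPy-style broadcast rule. A minimal standalone sketch of that rule (static dims only; the dynamic -1/-2 handling of the real helper is omitted):

#include <algorithm>
#include <cstdio>
#include <vector>

// Minimal sketch of two-input broadcast shape inference, assuming the usual
// NumPy-style rule: dims align from the right, and a dim of 1 stretches.
std::vector<long> BroadcastShape(std::vector<long> a, std::vector<long> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<long> out(a);
  size_t offset = a.size() - b.size();
  for (size_t i = 0; i < b.size(); ++i) {
    long da = a[offset + i], db = b[i];
    out[offset + i] = (da == 1) ? db : da;  // assumes the shapes are compatible
  }
  return out;
}

int main() {
  auto s = BroadcastShape({4, 1, 3}, {5, 3});  // -> {4, 5, 3}
  std::printf("%ld %ld %ld\n", s[0], s[1], s[2]);
  return 0;
}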
// ----------------------------------TwoInOneOutCommonInfer-----------------------------
COMMON_INFER_FUNC_REG(Div, TwoInOneOutCommonInferShape);
COMMON_INFER_FUNC_REG(DivNoNan, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Gcd, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Heaviside, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Hypot, TwoInOneOutCommonInferShape);
CUST_COMMON_INFER_FUNC_REG(Lcm, TwoInOneOutCommonInferShape);
// ----------------------------------TwoInOneOutCommonInfer END-----------------------------

// --------------AcosGrad----------------
IMPLEMT_VERIFIER(AcosGrad, AcosGradVerify) {
  if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AcosGrad, AcosGradVerify);
COMMON_INFER_FUNC_REG(AcosGrad, TwoInOneOutCommonInferShape);
INFER_AXIS_TYPE_INFO_REG(AcosGrad, InferAxisType4ElementwiseOp);
// ------------AcosGrad END----------------

// ----------------AcoshGrad-------------------
IMPLEMT_VERIFIER(AcoshGrad, AcoshGradVerify) {
  if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AcoshGrad, AcoshGradVerify);

IMPLEMT_COMMON_INFERFUNC(AcoshGradInferShape) {
  bool is_dynamic_output = true;
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(AcoshGrad, AcoshGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AcoshGrad, InferAxisType4ElementwiseOp);
// --------------AcoshGrad END-----------------

// ----------------AsinGrad---------------
IMPLEMT_VERIFIER(AsinGrad, AsinGradVerify) {
  if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AsinGrad, AsinGradVerify);

IMPLEMT_COMMON_INFERFUNC(AsinGradInferShape) {
  bool is_dynamic_output = true;
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(AsinGrad, AsinGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AsinGrad, InferAxisType4ElementwiseOp);
// --------------AsinGrad END-------------

// ----------------AsinhGrad-------------------
IMPLEMT_VERIFIER(AsinhGrad, AsinhGradVerify) {
  if (!CheckTwoInputDtypeSame(op, "y", "dy")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(AsinhGrad, AsinhGradVerify);
IMPLEMT_COMMON_INFERFUNC(AsinhGradInferShape) {
  bool is_dynamic_output = true;
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(AsinhGrad, AsinhGradInferShape);
INFER_AXIS_TYPE_INFO_REG(AsinhGrad, InferAxisType4ElementwiseOp);
// --------------AsinhGrad END-----------------

// ----------------AddN-------------------
int64_t GetAddNConstValue(const ge::Operator &op) {
  int64_t tensor_num;
  if (ge::GRAPH_SUCCESS != op.GetAttr("N", tensor_num)) {
    std::string err_msg = GetInputInvalidErrMsg("N");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
  }
  return tensor_num;
}

int64_t AddNInferClassify(ge::Operator &op, int64_t &tensor_num) {
  const int64_t infer_condition_one_one = 11;
  const int64_t infer_condition_one_two = 12;
  const int64_t infer_condition_two = 2;
  const int64_t infer_condition_three = 3;

  int64_t empty_num = 0;
  int64_t static_num = 0;
  int64_t dynamic_shape_num = 0;
  int64_t dynamic_dim_num = 0;

  for (int64_t i = 0; i < tensor_num; i++) {
    vector<int64_t> tempVector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
    if (tempVector.empty()) {
      empty_num++;
    } else if (std::find(tempVector.begin(), tempVector.end(), ge::UNKNOWN_DIM) != tempVector.end()) {
      dynamic_shape_num++;
    } else if (std::find(tempVector.begin(), tempVector.end(), ge::UNKNOWN_DIM_NUM) != tempVector.end()) {
      dynamic_dim_num++;
    } else {
      static_num++;
    }
  }
  if (tensor_num == empty_num + dynamic_dim_num) {
    if (tensor_num == empty_num) {
      return infer_condition_one_one;
    } else {
      return infer_condition_one_two;
    }
  } else if (tensor_num == static_num || tensor_num == empty_num + static_num ||
             tensor_num == static_num + dynamic_dim_num || tensor_num == empty_num + static_num + dynamic_dim_num) {
    return infer_condition_two;
  } else {
    return infer_condition_three;
  }
}

IMPLEMT_COMMON_INFERFUNC(AddNInferShape) {
  /*
  add_n has four types of input:
  1.empty 2.static shape 3.-1 4.-2
  The combinations bring 15 scenes, and the 15 scenes can be classified into 4 categories:
  1.input with no range and output no need range, and it can be divided in half:
  1.1 all inputs are empty
  1.2 input only contains empty and -2 shape
  2.input contains static shape and with no -1 shape
  3.input contains -1 shape
  */
  int64_t tensor_num = GetAddNConstValue(op);
  int64_t infer_classify = AddNInferClassify(op, tensor_num);
  // condition 1: all input shapes are empty
  if (infer_classify == 11) {
    std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
    TensorDesc y_desc = op.GetOutputDescByName("y");
    y_desc.SetShape(Shape(shape_vector));
    y_desc.SetDataType(x_dtype);
    (void)op.UpdateOutputDesc("y", y_desc);
    // condition 2: all inputs are -2, or only empty and -2
  } else if (infer_classify == 12) {
    std::vector<int64_t> shape_vector = {-2};
    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
    TensorDesc y_desc = op.GetOutputDescByName("y");
    y_desc.SetShape(Shape(shape_vector));
    y_desc.SetDataType(x_dtype);
    (void)op.UpdateOutputDesc("y", y_desc);
    // condition 3: contains static shape and no -1 shape
  } else if (infer_classify == 2) {
    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
    std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
    for (int64_t i = 0; i < tensor_num; i++) {
      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
      if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) {
        shape_vector = temp_vector;
        break;
      }
    }
    TensorDesc y_desc = op.GetOutputDescByName("y");
    y_desc.SetShape(ge::Shape(shape_vector));
    y_desc.SetDataType(x_dtype);
    std::vector<std::pair<int64_t, int64_t>> out_range;
    MakeUpShapeRange(shape_vector, out_range);
    y_desc.SetShapeRange(out_range);
    (void)op.UpdateOutputDesc("y", y_desc);
    // condition 4: contains -1 shape, the range needs to take the intersection
  } else {
    Shape out_shape = op.GetDynamicInputDesc("x", 0).GetShape();
    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
    std::vector<int64_t> out_vector;
    std::vector<std::pair<int64_t, int64_t>> out_range;
    // Init the output shape and range
    for (int64_t i = 0; i < tensor_num; i++) {
      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
      if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) {
        out_vector = temp_vector;
        op.GetDynamicInputDesc("x", i).GetShapeRange(out_range);
        MakeUpShapeRange(out_vector, out_range);
        break;
      }
    }
    // compute the shape dims and range intersection
    for (int64_t i = 0; i < tensor_num; i++) {
      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
      if (temp_vector.empty() || IsUnknownRankShape(temp_vector)) {
        continue;
      }
      std::vector<std::pair<int64_t, int64_t>> temp_range;
      op.GetDynamicInputDesc("x", i).GetShapeRange(temp_range);
      MakeUpShapeRange(temp_vector, temp_range);
      for (size_t j = 0; j < temp_vector.size(); j++) {
        // two conditions: const == const; const > -1
        if (temp_vector[j] >= out_vector[j]) {
          out_vector[j] = temp_vector[j];
          // update range: the left end takes the max value
          if (temp_range[j].first >= out_range[j].first) {
            out_range[j].first = temp_range[j].first;
          }
          // update range: the right end takes the smaller value, but only when it is > 0
          if ((temp_range[j].second <= out_range[j].second && temp_range[j].second > 0) ||
              (out_range[j].second == -1 && temp_range[j].second != -1)) {
            out_range[j].second = temp_range[j].second;
          }
        }
      }
    }
    TensorDesc y_desc = op.GetOutputDescByName("y");
    out_shape = Shape(out_vector);
    y_desc.SetShape(out_shape);
    y_desc.SetDataType(x_dtype);
    y_desc.SetShapeRange(out_range);
    (void)op.UpdateOutputDesc("y", y_desc);
  }
  return GRAPH_SUCCESS;
}

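The last AddN branch intersects dim values and shape ranges across all dynamic inputs: each output dim takes the larger known value, the range lower bound takes the max, and the upper bound takes the smallest positive candidate, with -1 meaning unbounded. A standalone rerun of just the range-merge rule:

#include <cstdio>
#include <utility>

// Standalone rerun of the range-merging rule in the last AddN branch:
// lower bound -> max of the two, upper bound -> smaller positive value,
// where -1 stands for "unbounded". The per-dim value guard is omitted.
using Range = std::pair<long, long>;

Range MergeRange(Range out, const Range &in) {
  if (in.first >= out.first) out.first = in.first;
  if ((in.second <= out.second && in.second > 0) || (out.second == -1 && in.second != -1)) {
    out.second = in.second;
  }
  return out;
}

int main() {
  Range r = MergeRange({1, -1}, {2, 8});  // -> {2, 8}: tighter on both ends
  std::printf("[%ld, %ld]\n", r.first, r.second);
  return 0;
}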
COMMON_INFER_FUNC_REG(AddN, AddNInferShape);
INFER_AXIS_TYPE_INFO_REG(AddN, InferAxisType4ElementwiseOp);
// ----------------AddN END-------------------

// --------------------------------BiasAdd-------------------------------------
IMPLEMT_VERIFIER(BiasAdd, BiasAddVerify) {
  std::string data_format;
  if (op.GetAttr("data_format", data_format) == GRAPH_FAILED) {
    std::string err_msg = GetInputInvalidErrMsg("data_format");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
  }
  if (data_format != "NHWC" && data_format != "NCHW" && data_format != "NDHWC" && data_format != "NCDHW") {
    string expected_format_list = ConcatString("NHWC, NCHW, NDHWC, NCDHW");
    std::string err_msg = GetInputFormatNotSupportErrMsg(TbeGetName(op).c_str(), expected_format_list, data_format);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(BiasAddInferShape) {
  const int64_t input_x_idx = 0;
  const int64_t output_y_idx = 0;
  if (!OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(BiasAdd, BiasAddInferShape);
VERIFY_FUNC_REG(BiasAdd, BiasAddVerify);
INFER_AXIS_TYPE_INFO_REG(BiasAdd, InferAxisType4BroadcastOp);
// ----------------------------------BiasAdd END-----------------------------

// --------------MulNoNan--------------
IMPLEMT_VERIFIER(MulNoNan, MulNoNanVerify) {
  DataType input_type_x1 = op.GetInputDescByName("x1").GetDataType();
  DataType input_type_x2 = op.GetInputDescByName("x2").GetDataType();
  if (input_type_x1 != input_type_x2) {
    string err_msg1 =
      ConcatString("the dtype of input_type_x1 and input_type_x2 must be same! input_type_x1:", input_type_x1,
                   ", input_type_x2:", input_type_x2);
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
VERIFY_FUNC_REG(MulNoNan, MulNoNanVerify);

IMPLEMT_COMMON_INFERFUNC(MulNoNanInferShape) {
  bool is_dynamic_output = true;
  if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
COMMON_INFER_FUNC_REG(MulNoNan, MulNoNanInferShape);
// ------------MulNoNan END--------------

// -------------------LessEqual---------------------
IMPLEMT_VERIFIER(LessEqual, LessEqualVerify) {
  if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(LessEqualInferShape) {
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
    return GRAPH_FAILED;
  }
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims();
  if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) {
    if (!InferShapeRangeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
      return GRAPH_FAILED;
    }
  }

  op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(LessEqual, LessEqualInferShape);
VERIFY_FUNC_REG(LessEqual, LessEqualVerify);
// --------------------LessEqual END-----------------------

// --------------------Mul-----------------------
IMPLEMT_VERIFIER(Mul, MulVerify) {
  if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(Mul, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(Mul, MulVerify);
// --------------------Mul END-----------------------

// -------------------FloorDiv-----------------------
IMPLEMT_VERIFIER(FloorDiv, FloorDivVerify) {
  if (!CheckTwoInputDtypeSame(op, "x1", "x2")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(FloorDiv, TwoInOneOutCommonInferShape);
VERIFY_FUNC_REG(FloorDiv, FloorDivVerify);
// ----------------FloorDiv END------------------------
} // namespace ge
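Worth noting in the LessEqual infer function above: comparison ops reuse the broadcast shape helpers but then force the output dtype to `DT_BOOL`, since the element type of a comparison is independent of the input dtypes. The scalar contract, restated standalone:

#include <cstdio>

// Sketch of the comparison-op contract: the output shape follows the
// broadcast of the inputs, while the element type is always bool
// regardless of the input dtype.
template <typename T>
bool LessEqualScalar(T x1, T x2) {
  return x1 <= x2;
}

int main() {
  std::printf("%d %d\n", LessEqualScalar(2, 3), LessEqualScalar(3.5, 1.0));  // 1 0
  return 0;
}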
@ -0,0 +1,51 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/nn_pooling_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(FractionalAvgPoolGrad, FractionalAvgPoolGradInfer) {
  Tensor tensor;
  if (op.GetInputConstData("orig_input_tensor_shape", tensor) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
                                       ConcatString("get const data from input[orig_input_tensor_shape] failed"));
    return GRAPH_FAILED;
  }

  Shape result;
  if (MakeShapeFromShapeTensor(tensor, result, op) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op),
                                      ConcatString("call MakeShapeFromShapeTensor function failed to make",
                                                   " shape from input[orig_input_tensor_shape] data"));
    return GRAPH_FAILED;
  }

  DataType y_type = op.GetInputDescByName("out_backprop").GetDataType();
  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetShape(Shape(result));
  out_desc.SetDataType(y_type);
  if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed."));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(FractionalAvgPoolGrad, FractionalAvgPoolGradInfer);
} // namespace ge
@ -0,0 +1,83 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/nn_pooling_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(FractionalAvgPool, FractionalAvgPoolInfer) {
  Shape input;
  if (WithRank(op.GetInputDesc(0), 4, input, op) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(
      TbeGetName(op), ConcatString("Call WithRank function failed, ",
                                   GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D")));
    return GRAPH_FAILED;
  }
  std::vector<float> pooling_ratio;
  op.GetAttr("pooling_ratio", pooling_ratio);
  if (pooling_ratio.size() != 4) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op),
                                       GetAttrValueErrMsg("pooling_ratio", ConcatString(pooling_ratio.size()), "4"));
    return GRAPH_PARAM_INVALID;
  }
  auto x_dims = op.GetInputDesc(0).GetShape().GetDims();
  std::vector<int64_t> dims;
  dims.reserve(4);
  for (int i = 0; i < 4; ++i) {
    int64_t val = ge::UNKNOWN_DIM;
    if (x_dims[i] != ge::UNKNOWN_DIM) {
      val = static_cast<int64_t>(x_dims[i] / pooling_ratio[i]);
      if (val < 0) {
        AICPU_INFER_SHAPE_INNER_ERR_REPORT(
          TbeGetName(op), ConcatString("size computed for ", i, "th dim is ", val, ", should be >= 0"));
        return GRAPH_PARAM_INVALID;
      }
    }

    OP_LOGI(TbeGetName(op).c_str(), "i = %d, x_dims[i] = %ld, pooling_ratio[i] = %f, val = %ld", i, x_dims[i],
            pooling_ratio[i], val);
    dims.push_back(val);
  }
  Shape out(dims);
  Shape row_pooling_sequence;
  (void)Vector(dims[1] + 1, row_pooling_sequence);
  Shape col_pooling_sequence;
  (void)Vector(dims[2] + 1, col_pooling_sequence);

  DataType type = op.GetInputDescByName("x").GetDataType();

  TensorDesc y_desc = op.GetOutputDescByName("y");
  y_desc.SetShape(out);
  y_desc.SetDataType(type);
  op.UpdateOutputDesc("y", y_desc);

  TensorDesc row_desc = op.GetOutputDescByName("row_pooling_sequence");
  row_desc.SetShape(row_pooling_sequence);
  row_desc.SetDataType(DT_INT64);
  op.UpdateOutputDesc("row_pooling_sequence", row_desc);

  TensorDesc col_desc = op.GetOutputDescByName("col_pooling_sequence");
  col_desc.SetShape(col_pooling_sequence);
  col_desc.SetDataType(DT_INT64);
  op.UpdateOutputDesc("col_pooling_sequence", col_desc);

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(FractionalAvgPool, FractionalAvgPoolInfer);
} // namespace ge
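The output-size rule in FractionalAvgPoolInfer is a truncating division of each input dim by its pooling ratio, and the two pooling-sequence outputs get one more entry than the corresponding output dim. A worked numeric check under assumed values:

#include <cstdint>
#include <cstdio>

// Worked example of the FractionalAvgPool output-size rule above:
// out_dim = trunc(in_dim / pooling_ratio), and each pooling sequence has
// out_dim + 1 entries.
int main() {
  int64_t in_h = 10;
  float ratio = 1.5f;
  int64_t out_h = static_cast<int64_t>(in_h / ratio);  // 10 / 1.5 -> 6
  std::printf("out_h = %ld, row_pooling_sequence length = %ld\n",
              static_cast<long>(out_h), static_cast<long>(out_h + 1));
  return 0;
}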
@ -1,73 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/geqrf_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ---------------Geqrf-------------------
CUST_IMPLEMT_INFERFUNC(Geqrf, GeqrfInfer) {
  auto tensor = op.get_input_desc_x();
  Shape input;
  if (WithRank(tensor, 2, input, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  int dim_num = input.GetDimNum();
  int m = input.GetDim(dim_num - 2);
  int n = input.GetDim(dim_num - 1);
  Shape r_shape;
  Shape tau_shape;
  int p = m > n ? n : m;
  Matrix(m, n, r_shape);
  Vector(p, tau_shape);

  DataType type = op.GetInputDescByName("x").GetDataType();
  TensorDesc r_desc = op.GetOutputDescByName("r");
  r_desc.SetShape(Shape(r_shape));
  r_desc.SetDataType(type);
  if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
    return GRAPH_FAILED;
  }

  TensorDesc tau_desc = op.GetOutputDescByName("tau");
  tau_desc.SetShape(Shape(tau_shape));
  tau_desc.SetDataType(type);
  if (op.UpdateOutputDesc("tau", tau_desc) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Update tau desc failed.");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_INFER_FUNC_REG(Geqrf, GeqrfInfer);

CUST_IMPLEMT_VERIFIER(Geqrf, GeqrfVerify) {
  DataType type = op.GetInputDescByName("x").GetDataType();
  if (type != DT_FLOAT16 && type != DT_FLOAT && type != DT_DOUBLE && type != DT_COMPLEX64 && type != DT_COMPLEX128) {
    OP_LOGE(TbeGetName(op).c_str(), "Expect a floating point or complex tensor as input.");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_VERIFY_FUNC_REG(Geqrf, GeqrfVerify);
// ---------------Geqrf End---------------
} // namespace ge
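For reference, the shape arithmetic this (now relocated) GeqrfInfer implements: for a 2-D input x of shape (m, n), QR factorization yields r with the same (m, n) shape and tau with min(m, n) Householder scalars, which is exactly what the `Matrix` and `Vector` calls build. Restated standalone:

#include <algorithm>
#include <cstdio>

// Shape arithmetic of GeqrfInfer, restated standalone: x is (m, n),
// r mirrors x, tau holds one Householder scalar per reflection.
int main() {
  int m = 5, n = 3;
  int p = std::min(m, n);  // the `m > n ? n : m` above
  std::printf("r: (%d, %d), tau: (%d)\n", m, n, p);
  return 0;
}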
@ -1,64 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/hamming_window_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// ----------------HammingWindow Begin---------------------
IMPLEMT_COMMON_INFERFUNC(HammingWindowInferShape) {
  std::vector<int64_t> input_dim = op.GetInputDesc(0).GetShape().GetDims();
  if (input_dim.size() != 1) {
    OP_LOGE(TbeGetName(op).c_str(), "Tensor length input must be 1D.");
    return GRAPH_FAILED;
  }

  Tensor length_tensor;
  int64_t length_data;
  if (op.GetInputConstData("length", length_tensor) == GRAPH_SUCCESS) {
    uint8_t *length = length_tensor.GetData();
    length_data = static_cast<int64_t>(*length);
  } else {
    length_data = UNKNOWN_DIM;
  }
  std::vector<int64_t> output_dim;
  if (length_data != UNKNOWN_DIM && length_data < 0) {
    OP_LOGE(TbeGetName(op).c_str(), "Non-negative window length required, got [%ld].", length_data);
    return GRAPH_FAILED;
  }
  if (length_data != 0) {
    output_dim.push_back(length_data);
  }
  ge::Shape output_shape = ge::Shape(output_dim);

  Operator::OpInt dtype;
  if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
    dtype = 0;
  }
  DataType output_dtype = static_cast<DataType>(dtype);

  TensorDesc output_desc = op.GetOutputDescByName("y");
  output_desc.SetShape(output_shape);
  output_desc.SetDataType(output_dtype);
  op.UpdateOutputDesc("y", output_desc);

  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(HammingWindow, HammingWindowInferShape);
// ----------------HammingWindow End---------------------
} // namespace ge
@ -0,0 +1,202 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
#include "graph/common_error_codes.h"
namespace ge {
const std::string ATTR_NAME_DATA_SLICE = "_data_slice";

static bool CheckListEmptyAndValue(const std::string &op_name, const std::vector<int64_t> &list,
                                   const std::string &attr_name) {
  if (list.size() < 1) {
    OP_LOGE(op_name.c_str(), "The %s does not have enough elements(%lu)!", attr_name.c_str(), list.size());
    return false;
  }
  return true;
}

static std::vector<int64_t> GetAttrValue(const Operator &op, const std::string &key_name) {
  std::vector<int64_t> list;
  if (op.GetAttr(key_name.c_str(), list) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ConstValue failed!");
  }
  return list;
}

// -----------------Im2col Op-------------------------
IMPLEMT_VERIFIER(Im2col, Im2colVerify) {
  std::vector<int64_t> ksize;
  ksize = GetAttrValue(op, "ksizes");
  if (ksize.size() < 2) {
    OP_LOGE(TbeGetName(op).c_str(), "The ksizes does not have enough elements(%lu)!", ksize.size());
    return GRAPH_FAILED;
  }

  std::vector<int64_t> stride;
  stride = GetAttrValue(op, "strides");
  if (!CheckListEmptyAndValue(TbeGetName(op), stride, "strides")) {
    return GRAPH_FAILED;
  }

  std::vector<int64_t> dilation;
  dilation = GetAttrValue(op, "dilations");
  if (!CheckListEmptyAndValue(TbeGetName(op), dilation, "dilations")) {
    return GRAPH_FAILED;
  }

  std::string padding_mode;
  if (op.GetAttr("padding_mode", padding_mode) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Get padding_mode failed!");
    return GRAPH_FAILED;
  }
  if (padding_mode != "CALCULATED" && padding_mode != "SAME" && padding_mode != "VALID") {
    OP_LOGE(TbeGetName(op).c_str(), "padding_mode only support CALCULATED, SAME and VALID!");
    return GRAPH_FAILED;
  }

  std::vector<int64_t> pad;
  pad = GetAttrValue(op, "pads");
  if (!CheckListEmptyAndValue(TbeGetName(op), pad, "pads")) {
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(Im2colInferShape) {
  OP_LOGI(TbeGetName(op).c_str(), "Enter op_proto inferfunction!");

  std::vector<int64_t> ksize;
  ksize = GetAttrValue(op, "ksizes");
  std::vector<int64_t> stride;
  stride = GetAttrValue(op, "strides");
  std::vector<int64_t> dilation;
  dilation = GetAttrValue(op, "dilations");
  std::string padding_mode;
  if (op.GetAttr("padding_mode", padding_mode) == GRAPH_FAILED) {
    OP_LOGE(TbeGetName(op).c_str(), "GetOpAttr ConstValue padding_mode failed!");
    return GRAPH_FAILED;
  }
  std::vector<int64_t> pad;
  pad = GetAttrValue(op, "pads");

  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeTensorDescPtr desc_in_ptr = op_desc->MutableInputDesc("x");
  GeTensorDescPtr desc_out_ptr = op_desc->MutableOutputDesc("y");
  auto dtype = desc_in_ptr->GetDataType();
  auto shape_in = desc_in_ptr->GetShape();
  auto x_format = desc_in_ptr->GetOriginFormat();
  if (x_format != FORMAT_NHWC && x_format != FORMAT_NCHW) {
    OP_LOGE(TbeGetName(op).c_str(), "Attr x_format only support NHWC, NCHW.");
    return GRAPH_FAILED;
  }

  std::map<char, int> idx_map{{'N', 0}, {'H', 1}, {'W', 2}, {'C', 3}};
  if (x_format == FORMAT_NCHW) {
    idx_map = {{'N', 0}, {'C', 1}, {'H', 2}, {'W', 3}};
  }

  int64_t in_n = shape_in.GetDim(idx_map['N']);
  int64_t in_h = shape_in.GetDim(idx_map['H']);
  int64_t in_w = shape_in.GetDim(idx_map['W']);
  int64_t in_c = shape_in.GetDim(idx_map['C']);

  if (ksize.size() != 2) {
    OP_LOGE(TbeGetName(op).c_str(), "The size of ksizes must be 2 when x_format only support NHWC, NCHW.");
    return GRAPH_FAILED;
  }
  int64_t filter_h = ksize[0];
  int64_t filter_w = ksize[1];

  int64_t stride_h = stride[0];
  int64_t stride_w = stride[0];
  if (stride.size() == 2) {
    stride_h = stride[0];
    stride_w = stride[1];
  } else if (stride.size() != 1) {
    OP_LOGE(TbeGetName(op).c_str(), "The size of strides must be 1 or 2 when x_format only support NHWC, NCHW.");
    return GRAPH_FAILED;
  }
  if (stride_h == 0 || stride_w == 0) {
    OP_LOGE(TbeGetName(op).c_str(), "The stride_h or stride_w should not be 0.");
    return GRAPH_FAILED;
  }

  int64_t dilation_h = dilation[0];
  int64_t dilation_w = dilation[0];
  if (dilation.size() == 2) {
    dilation_h = dilation[0];
    dilation_w = dilation[1];
  } else if (dilation.size() != 1) {
    OP_LOGE(TbeGetName(op).c_str(), "The size of dilations must be 1 or 2 when x_format only support NHWC, NCHW.");
    return GRAPH_FAILED;
  }

  int64_t effective_filter_h = (filter_h - 1) * dilation_h + 1;
  int64_t effective_filter_w = (filter_w - 1) * dilation_w + 1;
  int64_t out_h{0};
  int64_t out_w{0};
  int64_t out_c{0};
  if (padding_mode == "VALID") {
    out_h = (in_h - effective_filter_h + stride_h) / stride_h;
    out_w = (in_w - effective_filter_w + stride_w) / stride_w;
  } else if (padding_mode == "SAME") {
    out_h = (in_h + stride_h - 1) / stride_h;
    out_w = (in_w + stride_w - 1) / stride_w;
  } else if (padding_mode == "CALCULATED") {
    int64_t pad_h_top;
    int64_t pad_h_bottom;
    int64_t pad_w_before;
    int64_t pad_w_after;
    if (pad.size() == 1) {
      pad_h_top = pad[0];
      pad_h_bottom = pad[0];
      pad_w_before = pad[0];
      pad_w_after = pad[0];
    } else if (pad.size() == 4) {
      pad_h_top = pad[0];
      pad_h_bottom = pad[1];
      pad_w_before = pad[2];
      pad_w_after = pad[3];
    } else {
      OP_LOGE(TbeGetName(op).c_str(), "The size of pads must be 1 or 4 when x_format only support NHWC, NCHW.");
      return GRAPH_FAILED;
    }
    out_h = (in_h + pad_h_top + pad_h_bottom - (dilation_h * (filter_h - 1) + 1)) / stride_h + 1;
    out_w = (in_w + pad_w_before + pad_w_after - (dilation_w * (filter_w - 1) + 1)) / stride_w + 1;
  } else {
    OP_LOGE(TbeGetName(op).c_str(), "The padding_mode only support VALID, SAME and CALCULATED.");
    return GRAPH_FAILED;
  }
  out_c = in_c * filter_h * filter_w;

  std::vector<int64_t> out_dim{in_n, out_h, out_w, out_c};
  if (x_format == FORMAT_NCHW) {
    out_dim = {in_n, out_c, out_h, out_w};
  }

  desc_out_ptr->SetShape(ge::GeShape(out_dim));
  desc_out_ptr->SetDataType(dtype);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(Im2col, Im2colInferShape);
VERIFY_FUNC_REG(Im2col, Im2colVerify);
// -----------------Im2col END-------------------------
} // namespace ge
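The three padding modes in Im2colInferShape reduce to the standard convolution output-size formulas with the dilated (effective) kernel extent. A numeric check under assumed values (1x5x5x2 NHWC input, 2x2 kernel, stride 2, dilation 1):

#include <cstdint>
#include <cstdio>

// Numeric check of the Im2col output-size formulas above, for an assumed
// 1x5x5x2 NHWC input, 2x2 kernel, stride 2, dilation 1.
int main() {
  int64_t in_h = 5, in_w = 5, in_c = 2;
  int64_t filter_h = 2, filter_w = 2, stride_h = 2, stride_w = 2, dilation = 1;
  int64_t eff_h = (filter_h - 1) * dilation + 1;  // effective kernel extent
  int64_t eff_w = (filter_w - 1) * dilation + 1;
  int64_t valid_h = (in_h - eff_h + stride_h) / stride_h;  // VALID: 2
  int64_t valid_w = (in_w - eff_w + stride_w) / stride_w;
  int64_t same_h = (in_h + stride_h - 1) / stride_h;       // SAME: 3
  int64_t same_w = (in_w + stride_w - 1) / stride_w;
  int64_t out_c = in_c * filter_h * filter_w;              // 8
  std::printf("VALID: %ldx%ld  SAME: %ldx%ld  out_c: %ld\n", (long)valid_h, (long)valid_w,
              (long)same_h, (long)same_w, (long)out_c);
  return 0;
}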
@ -0,0 +1,142 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/image_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/image_ops_shape_fns.h"
namespace ge {
// ----------------AdjustHue Start-------------------
IMPLEMT_INFERFUNC(AdjustHue, AdjustHueInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto images_desc = op_desc->MutableInputDesc(0);

  GeShape out;
  auto ret = WithRankAtLeast(images_desc, 3, out, op);
  if (ret != GRAPH_SUCCESS) {
    std::string err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
    err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<std::pair<int64_t, int64_t>> range;
  if (images_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  auto y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetShape(out);
  y_desc->SetShapeRange(range);
  y_desc->SetDataType(images_desc->GetDataType());

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(AdjustHue, AdjustHueInfer);
// ----------------AdjustHue End-------------------

// ----------------AdjustSaturation Start-------------------
static graphStatus AdjustSaturationCommInferShape(const Operator &op) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto images_desc = op_desc->MutableInputDesc(0);

  GeShape out;
  auto ret = WithRankAtLeast(images_desc, 3, out, op);
  if (ret != GRAPH_SUCCESS) {
    std::string err_msg = GetShapeErrMsg(0, DebugString(images_desc->GetShape().GetDims()), "at least 3D");
    err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<std::pair<int64_t, int64_t>> range;
  if (images_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  auto y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetShape(out);
  y_desc->SetShapeRange(range);
  y_desc->SetDataType(images_desc->GetDataType());

  return GRAPH_SUCCESS;
}

IMPLEMT_INFERFUNC(AdjustSaturation, AdjustSaturationInfer) { return AdjustSaturationCommInferShape(op); }

INFER_FUNC_REG(AdjustSaturation, AdjustSaturationInfer);
// ----------------AdjustSaturation END-------------------

// ----------------ExtractGlimpse-------------------
IMPLEMT_INFERFUNC(ExtractGlimpse, ExtractGlimpseInfer) {
  Shape x_shape;
  auto ret = WithRank(op.GetInputDesc(0), 4, x_shape, op);
  if (ret != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op), "input x must be 4-D");
    return GRAPH_PARAM_INVALID;
  }
  Shape offsets_shape;
  ret = WithRank(op.GetInputDesc(2), 2, offsets_shape, op);
  if (ret != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op), "input offsets must be 2-D");
    return GRAPH_PARAM_INVALID;
  }
  auto x_dims = op.GetInputDesc(0).GetShape().GetDims();
  auto offsets_dims = op.GetInputDesc(2).GetShape().GetDims();
  CHECK(x_dims.size() < 4 || offsets_dims.size() < 2, OP_LOGE(TbeGetName(op), "invalid x_dims or offsets_dims."),
        return GRAPH_FAILED);
  int64_t batch_dim;
  if (Merge(x_dims[0], offsets_dims[0], batch_dim) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op), "x dim-0 or offsets dim-0 is invalid");
    return GRAPH_PARAM_INVALID;
  }
  if (offsets_dims[1] != 2) {
    OP_LOGE(TbeGetName(op), "offsets dim-1 must be 2");
    return GRAPH_PARAM_INVALID;
  }

  bool uniform_noise = false;
  if (op.GetAttr("uniform_noise", uniform_noise) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op), "get attr uniform_noise failed");
    return GRAPH_FAILED;
  }
  std::string noise;
  if (op.GetAttr("noise", noise) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op), "get attr noise failed");
    return GRAPH_FAILED;
  }
  if (uniform_noise && (!noise.empty() && noise != "uniform")) {
    OP_LOGE(TbeGetName(op), "The uniform_noise and noise should not be specified at the same time");
    return GRAPH_FAILED;
  }

  TensorDesc desc = op.GetOutputDescByName("y");
  desc.SetDataType(DT_FLOAT);
  if (op.UpdateOutputDesc("y", desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  auto channel_dim = x_dims[3];
  TensorDesc input_td = op.GetInputDesc(0);
  if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
    channel_dim = x_dims[1];
  }
  return SetOutputToSizedImage(op, batch_dim, "size", channel_dim, "y");
}

INFER_FUNC_REG(ExtractGlimpse, ExtractGlimpseInfer);
// ----------------ExtractGlimpse End-------------------
} // namespace ge
@ -1,45 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/index_fill.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------IndexFill-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
  TensorDesc v_output_desc = op.GetOutputDescByName("y");

  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  Format input_format = op.GetInputDescByName("x").GetFormat();
  // shape of output y is the same as input x
  ge::Shape shape_input = op.GetInputDescByName("x").GetShape();

  v_output_desc.SetShape(shape_input);
  v_output_desc.SetDataType(input_dtype);
  v_output_desc.SetFormat(input_format);

  if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

// Registered inferfunction
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------
} // namespace ge
@ -0,0 +1,407 @@
|
|||
/**
|
||||
* Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "inc/ops/linalg_ops.h"
|
||||
#include "custom_op_proto/cust_linalg_ops.h"
|
||||
#include "register/op_impl_registry.h"
|
||||
#include "utils/util.h"
|
||||
#include "utils/linalg_ops_shape_fns.h"
|
||||
#include "utils/common_shape_fns.h"
|
||||
|
||||
namespace ge {
|
||||
// ----------------MatrixSolve-------------------
|
||||
IMPLEMT_INFERFUNC(MatrixSolve, MatrixSolveInfer) {
|
||||
auto matrix_tensor = op.get_input_desc_matrix();
|
||||
auto rhs_tensor = op.get_input_desc_rhs();
|
||||
Shape result;
|
||||
if (MatrixSolve(matrix_tensor, rhs_tensor, true, result, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixSolve Call MatrixSolve Infer Shape fns Failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
DataType type = op.GetInputDescByName("matrix").GetDataType();
|
||||
|
||||
TensorDesc y_desc = op.GetOutputDescByName("y");
|
||||
y_desc.SetShape(Shape(result));
|
||||
y_desc.SetDataType(type);
|
||||
op.UpdateOutputDesc("y", y_desc);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(MatrixSolve, MatrixSolveInfer);
|
||||
// ----------------MatrixSolve End-------------------
|
||||
|
||||
// ----------------MatrixDeterminant-------------------
|
||||
IMPLEMT_INFERFUNC(MatrixDeterminant, MatrixDeterminantInfer) {
|
||||
auto tensor = op.get_input_desc_x();
|
||||
Shape s;
|
||||
if (WithRankAtLeast(tensor, 2, s, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "The rank of x must be at least 2.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int64_t existing = s.GetDimNum();
|
||||
int64_t dim1 = s.GetDim(existing - 1);
|
||||
int64_t dim2 = s.GetDim(existing - 2);
|
||||
int64_t unused_dim = 0;
|
||||
|
||||
if (Merge(dim1, dim2, unused_dim) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Merge two dimension failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
Shape result;
|
||||
if (SubShape(s, 0, -2, 1, result, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixDeterminant Get SubShape Failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
DataType type = op.GetInputDescByName("x").GetDataType();
|
||||
|
||||
TensorDesc y_desc = op.GetOutputDescByName("y");
|
||||
y_desc.SetShape(Shape(result));
|
||||
y_desc.SetDataType(type);
|
||||
op.UpdateOutputDesc("y", y_desc);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
INFER_FUNC_REG(MatrixDeterminant, MatrixDeterminantInfer);
|
||||
// ----------------MatrixDeterminant End-------------------
|
||||
|
||||
// ----------------MatrixTriangularSolve-------------------
|
||||
IMPLEMT_INFERFUNC(MatrixTriangularSolve, MatrixTriangularSolveInfer) {
|
||||
auto matrix_tensor = op.get_input_desc_matrix();
|
||||
auto rhs_tensor = op.get_input_desc_rhs();
|
||||
Shape result;
|
||||
if (MatrixSolve(matrix_tensor, rhs_tensor, true, result, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Op MatrixTriangularSolve Call MatrixSolve Infer Shape fns Failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
DataType type = op.GetInputDescByName("matrix").GetDataType();
|
||||
|
||||
TensorDesc y_desc = op.GetOutputDescByName("y");
|
||||
y_desc.SetShape(Shape(result));
|
||||
y_desc.SetDataType(type);
|
||||
op.UpdateOutputDesc("y", y_desc);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
INFER_FUNC_REG(MatrixTriangularSolve, MatrixTriangularSolveInfer);
|
||||
// ----------------MatrixTriangularSolve END-------------------
|
||||
|
||||
// ---------------Geqrf-------------------
|
||||
CUST_IMPLEMT_INFERFUNC(Geqrf, GeqrfInfer) {
|
||||
auto tensor = op.get_input_desc_x();
|
||||
Shape input;
|
||||
if (WithRank(tensor, 2, input, op) != GRAPH_SUCCESS) {
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int dim_num = input.GetDimNum();
|
||||
int m = input.GetDim(dim_num - 2);
|
||||
int n = input.GetDim(dim_num - 1);
|
||||
Shape r_shape;
|
||||
Shape tau_shape;
|
||||
int p = m > n ? n : m;
|
||||
Matrix(m, n, r_shape);
|
||||
Vector(p, tau_shape);
|
||||
|
||||
DataType type = op.GetInputDescByName("x").GetDataType();
|
||||
TensorDesc r_desc = op.GetOutputDescByName("r");
|
||||
r_desc.SetShape(Shape(r_shape));
|
||||
r_desc.SetDataType(type);
|
||||
if (op.UpdateOutputDesc("r", r_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Update r desc failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
TensorDesc tau_desc = op.GetOutputDescByName("tau");
|
||||
tau_desc.SetShape(Shape(tau_shape));
|
||||
tau_desc.SetDataType(type);
|
||||
if (op.UpdateOutputDesc("tau", tau_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Update tau desc failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_INFER_FUNC_REG(Geqrf, GeqrfInfer);
|
||||
|
||||
CUST_IMPLEMT_VERIFIER(Geqrf, GeqrfVerify) {
|
||||
DataType type = op.GetInputDescByName("x").GetDataType();
|
||||
if (type != DT_FLOAT16 && type != DT_FLOAT && type != DT_DOUBLE && type != DT_COMPLEX64 && type != DT_COMPLEX128) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Expect a floating point or complex tensor as input.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_VERIFY_FUNC_REG(Geqrf, GeqrfVerify);
|
||||
// ---------------Geqrf End---------------
|
||||
|
||||
// ---------------LuUnpack---------------
|
||||
CUST_IMPLEMT_INFERFUNC(LuUnpack, LuUnpackInferShape) {
|
||||
Shape LU_data;
|
||||
if (WithRankAtLeast(op.GetInputDesc(0), 2, LU_data, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "LU_data rank must be at least 2.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
int64_t existing = LU_data.GetDimNum();
|
||||
int64_t dim1 = LU_data.GetDim(existing - 2);
|
||||
int64_t dim2 = LU_data.GetDim(existing - 1);
|
||||
|
||||
Shape batch_shape;
|
||||
if (SubShape(LU_data, 0, -2, 1, batch_shape, op) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Op LuUnpack Get SubShape Failed.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
Shape L_shape;
|
||||
vector<int64_t> L_dims;
|
||||
L_dims.reserve(2);
|
||||
if (dim1 >= dim2) {
|
||||
L_dims.push_back(dim1);
|
||||
L_dims.push_back(dim2);
|
||||
} else {
|
||||
L_dims.push_back(dim1);
|
||||
L_dims.push_back(dim1);
|
||||
}
|
||||
Shape L_sec_shape(L_dims);
|
||||
if (Concatenate(batch_shape, L_sec_shape, L_shape) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Concatenate L_shape failed!");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
Shape U_shape;
|
||||
vector<int64_t> U_dims;
|
||||
U_dims.reserve(2);
|
||||
if (dim1 >= dim2) {
|
||||
U_dims.push_back(dim2);
|
||||
U_dims.push_back(dim2);
|
||||
} else {
|
||||
U_dims.push_back(dim1);
|
||||
U_dims.push_back(dim2);
|
||||
}
|
||||
Shape U_sec_shape(U_dims);
|
||||
if (Concatenate(batch_shape, U_sec_shape, U_shape) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Concatenate U_shape failed!");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
Shape pivots_shape;
|
||||
vector<int64_t> pivots_dims;
|
||||
pivots_dims.reserve(2);
|
||||
pivots_dims.push_back(dim1);
|
||||
pivots_dims.push_back(dim1);
|
||||
Shape pivots_sec_shape(pivots_dims);
|
||||
if (Concatenate(batch_shape, pivots_sec_shape, pivots_shape) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "Concatenate pivots_shape failed!");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
|
||||
TensorDesc L_desc = op.GetOutputDescByName("L");
|
||||
L_desc.SetShape(Shape(L_shape));
|
||||
DataType L_type = op.GetInputDescByName("LU_data").GetDataType();
|
||||
L_desc.SetDataType(L_type);
|
||||
if (L_desc.GetDataType() != L_type) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "the type of L must be the same as the type of LU_data.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (op.UpdateOutputDesc("L", L_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "fail to update output L.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
TensorDesc U_desc = op.GetOutputDescByName("U");
|
||||
U_desc.SetShape(Shape(U_shape));
|
||||
DataType U_type = op.GetInputDescByName("LU_data").GetDataType();
|
||||
U_desc.SetDataType(U_type);
|
||||
if (U_desc.GetDataType() != U_type) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "the type of U must be the same as the type of LU_data.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (op.UpdateOutputDesc("U", U_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "fail to update output U.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
TensorDesc pivots_desc = op.GetOutputDescByName("pivots");
|
||||
pivots_desc.SetShape(Shape(pivots_shape));
|
||||
DataType pivots_type = op.GetInputDescByName("LU_data").GetDataType();
|
||||
pivots_desc.SetDataType(pivots_type);
|
||||
if (pivots_desc.GetDataType() != pivots_type) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "the type of pivots must be the same as the type of LU_data.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (op.UpdateOutputDesc("pivots", pivots_desc) != GRAPH_SUCCESS) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "fail to update output pivots.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_IMPLEMT_VERIFIER(LuUnpack, LuUnpackVerify) {
|
||||
DataType LU_data_type = op.GetInputDescByName("LU_data").GetDataType();
|
||||
DataType LU_pivots_type = op.GetInputDescByName("LU_pivots").GetDataType();
|
||||
if (LU_data_type != DT_FLOAT16 && LU_data_type != DT_FLOAT && LU_data_type != DT_DOUBLE && LU_data_type != DT_INT8 &&
|
||||
LU_data_type != DT_UINT8 && LU_data_type != DT_INT16 && LU_data_type != DT_INT32 && LU_data_type != DT_INT64) {
|
||||
std::string err_msg;
|
||||
err_msg = ConcatString("Op LuUnpack first input LU_data_type's data type should be of the follows: ",
|
||||
"DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE,",
|
||||
"but this type is ", LU_data_type, ".");
|
||||
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (LU_pivots_type != DT_INT8 && LU_pivots_type != DT_UINT8 && LU_pivots_type != DT_INT16 &&
|
||||
LU_pivots_type != DT_INT32 && LU_pivots_type != DT_INT64) {
|
||||
std::string err_msg;
|
||||
err_msg =
|
||||
ConcatString("Op LuUnpack first input LU_data_type's data type should be of the follows: ",
|
||||
"DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64,", "but this type is ", LU_pivots_type, ".");
|
||||
AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_INFER_FUNC_REG(LuUnpack, LuUnpackInferShape);
|
||||
CUST_VERIFY_FUNC_REG(LuUnpack, LuUnpackVerify);
|
||||
// ---------------LuUnpack END---------------

// ---------------LuUnpackGrad---------------
IMPLEMT_COMMON_INFERFUNC(LuUnpackGradInferShape) {
  TensorDesc L_grad;
  TensorDesc U_grad;
  TensorDesc LU_data;
  if (op.TryGetInputDesc("LU_data", LU_data) == GRAPH_FAILED) {
    OP_LOGE(TbeGetName(op).c_str(), "LU_data cannot be null.");
    return GRAPH_FAILED;
  }
  TensorDesc L_data_grad = op.GetOutputDescByName("L_data_grad");
  L_data_grad.SetDataType(op.GetInputDescByName("LU_data").GetDataType());
  L_data_grad.SetShape(op.GetInputDescByName("LU_data").GetShape());
  if (op.UpdateOutputDesc("L_data_grad", L_data_grad) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "failed to update output L_data_grad.");
    return GRAPH_FAILED;
  }
  TensorDesc U_data_grad = op.GetOutputDescByName("U_data_grad");
  U_data_grad.SetDataType(op.GetInputDescByName("LU_data").GetDataType());
  U_data_grad.SetShape(op.GetInputDescByName("LU_data").GetShape());
  if (op.UpdateOutputDesc("U_data_grad", U_data_grad) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "failed to update output U_data_grad.");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
CUST_IMPLEMT_VERIFIER(LuUnpackGrad, LuUnpackGradVerify) {
  DataType LU_data_type = op.GetInputDescByName("LU_data").GetDataType();
  Shape LU_data_shape = op.GetInputDescByName("LU_data").GetShape();
  TensorDesc L_grad;
  TensorDesc U_grad;
  if (op.TryGetInputDesc("L_grad", L_grad) == GRAPH_SUCCESS) {
    DataType L_grad_type = op.GetInputDescByName("L_grad").GetDataType();
    Shape L_data_shape = op.GetInputDescByName("L_grad").GetShape();
    auto L_data_dim1 = L_data_shape.GetDim(-2);
    auto L_data_dim2 = L_data_shape.GetDim(-1);
    auto LU_data_dim1 = LU_data_shape.GetDim(-2);
    auto LU_data_dim2 = LU_data_shape.GetDim(-1);
    int64_t LU_data_min = std::min(LU_data_dim1, LU_data_dim2);
    if (LU_data_dim1 != L_data_dim1) {
      OP_LOGE(TbeGetName(op).c_str(), "L_grad's dim[-2] and LU_data's dim[-2] should be the same.");
      return GRAPH_FAILED;
    }
    if (LU_data_min != L_data_dim2) {
      OP_LOGE(TbeGetName(op).c_str(), "L_grad's dim[-1] and LU_data's minimum dim should be the same.");
      return GRAPH_FAILED;
    }
    if (LU_data_type != L_grad_type) {
      OP_LOGE(TbeGetName(op).c_str(), "L_grad's data type and LU_data's type should be the same.");
      return GRAPH_FAILED;
    }
  }
  if (op.TryGetInputDesc("U_grad", U_grad) == GRAPH_SUCCESS) {
    DataType U_grad_type = op.GetInputDescByName("U_grad").GetDataType();
    Shape U_data_shape = op.GetInputDescByName("U_grad").GetShape();
    auto U_data_dim1 = U_data_shape.GetDim(-2);
    auto U_data_dim2 = U_data_shape.GetDim(-1);
    auto LU_data_dim1 = LU_data_shape.GetDim(-2);
    auto LU_data_dim2 = LU_data_shape.GetDim(-1);
    int64_t LU_data_min = std::min(LU_data_dim1, LU_data_dim2);
    if (U_data_dim2 != LU_data_dim2) {
      OP_LOGE(TbeGetName(op).c_str(), "U_grad's dim[-1] and LU_data's dim[-1] should be the same.");
      return GRAPH_FAILED;
    }
    if (LU_data_min != U_data_dim1) {
      OP_LOGE(TbeGetName(op).c_str(), "U_grad's dim[-2] and LU_data's minimum dim should be the same.");
      return GRAPH_FAILED;
    }
    if (LU_data_type != U_grad_type) {
      OP_LOGE(TbeGetName(op).c_str(), "U_grad's data type and LU_data's type should be the same.");
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
CUST_COMMON_INFER_FUNC_REG(LuUnpackGrad, LuUnpackGradInferShape);
CUST_VERIFY_FUNC_REG(LuUnpackGrad, LuUnpackGradVerify);
// ---------------LuUnpackGrad End---------------
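
// Editor's note: both gradients mirror LU_data, so for LU_data (2, 4, 3) the
// verifier expects L_grad (2, 4, 3) and U_grad (2, 3, 3) (min(4, 3) = 3), and
// L_data_grad / U_data_grad are both inferred as (2, 4, 3) with LU_data's dtype.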

// -----------------------LuSolve---------------------------------
IMPLEMT_COMMON_INFERFUNC(LuSolveInferShape) {
  Shape b_shape = op.GetInputDescByName("x").GetShape();
  Shape lu_shape = op.GetInputDescByName("lu_data").GetShape();
  size_t b_dims = b_shape.GetDimNum();
  size_t lu_dims = lu_shape.GetDimNum();
  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  Format input_format = op.GetInputDescByName("x").GetFormat();
  std::vector<int64_t> dim_vector;
  if (b_dims >= lu_dims) {
    Shape output_shape = b_shape;
    TensorDesc td = op.GetOutputDescByName("y");
    td.SetShape(output_shape);
    td.SetDataType(input_dtype);
    td.SetFormat(input_format);
    (void)op.UpdateOutputDesc("y", td);
    return GRAPH_SUCCESS;
  } else {
    for (size_t i = 0; i <= lu_dims - b_dims - 1; i++) {
      dim_vector.push_back(lu_shape.GetDim(i));
    }
    for (size_t i = 0; i <= b_dims - 1; i++) {
      dim_vector.push_back(b_shape.GetDim(i));
    }
    Shape output_shape(dim_vector);
    TensorDesc td = op.GetOutputDescByName("y");
    td.SetShape(output_shape);
    td.SetDataType(input_dtype);
    td.SetFormat(input_format);
    (void)op.UpdateOutputDesc("y", td);
    return GRAPH_SUCCESS;
  }
}

CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
  DataType input_type_x = op.GetInputDescByName("x").GetDataType();
  DataType input_type_y = op.GetInputDescByName("lu_data").GetDataType();
  if (input_type_x != input_type_y) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------
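
// Worked example for LuSolveInferShape (editor's sketch): when x has fewer
// dims than lu_data, the extra leading lu_data dims are prepended to x's
// shape, e.g. x (4, 1) with lu_data (2, 4, 4) -> y (2, 4, 1); otherwise y
// simply takes x's shape.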
} // namespace ge

@ -1,88 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/logspace.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// --------------------------LogSpace---------------------
static bool CheckSteps(const Operator &op, const string &attr_num_steps) {
  int64_t steps = 0;
  int64_t steps_ori = 100;
  if (ge::GRAPH_SUCCESS != op.GetAttr(attr_num_steps.c_str(), steps)) {
    steps = steps_ori;
  }
  if (steps < 0) {
    return false;
  }
  return true;
}

CUST_IMPLEMT_VERIFIER(LogSpace, LogSpaceVerify) {
  AscendString opName;
  op.GetName(opName);
  if (op.GetInputDescByName("start").GetShape().GetDims().size() != 1) {
    OP_LOGE(opName.GetString(), "The rank of input start must be 1.");
    return GRAPH_FAILED;
  }
  if (op.GetInputDescByName("end").GetShape().GetDims().size() != 1) {
    OP_LOGE(opName.GetString(), "The rank of input end must be 1.");
    return GRAPH_FAILED;
  }
  DataType input_type_start = op.GetInputDescByName("start").GetDataType();
  DataType input_type_end = op.GetInputDescByName("end").GetDataType();
  if (input_type_start != input_type_end) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogSpaceInferShape) {
  AscendString opName1;
  op.GetName(opName1);
  TensorDesc v_output_desc = op.GetOutputDescByName("y");
  int64_t steps = 0;
  int64_t num_rows = 1;
  op.GetAttr("steps", steps);
  if (!CheckSteps(op, "steps")) {
    OP_LOGE(opName1.GetString(), "the attr 'steps' should be greater than or equal to 0.");
    return GRAPH_FAILED;
  }
  std::vector<int64_t> dim_vec;
  dim_vec.push_back(num_rows);
  dim_vec.push_back(steps);
  v_output_desc.SetShape(ge::Shape(dim_vec));
  int64_t dtype = 1;
  if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
    v_output_desc.SetDataType(DT_FLOAT16);
  } else {
    if (dtype == 1) {
      v_output_desc.SetDataType(DT_FLOAT16);
    }
    if (dtype == 0) {
      v_output_desc.SetDataType(DT_FLOAT);
    }
  }
  (void)op.UpdateOutputDesc("y", v_output_desc);
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LogSpace, LogSpaceInferShape);
// Registered verify function
CUST_VERIFY_FUNC_REG(LogSpace, LogSpaceVerify);
// --------------------------LogSpace END---------------------
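
// Editor's note: y is inferred as a (1, steps) tensor; attr dtype == 0 selects
// DT_FLOAT and dtype == 1 (or a missing attr) selects DT_FLOAT16, while any
// other value leaves the dtype untouched. E.g. steps = 5, dtype = 0 -> y float (1, 5).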
} // namespace ge

@ -1,68 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/lu_solve_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// -----------------------LuSolve---------------------------------
IMPLEMT_COMMON_INFERFUNC(LuSolveInferShape) {
  Shape b_shape = op.GetInputDescByName("x").GetShape();
  Shape lu_shape = op.GetInputDescByName("lu_data").GetShape();
  size_t b_dims = b_shape.GetDimNum();
  size_t lu_dims = lu_shape.GetDimNum();
  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  Format input_format = op.GetInputDescByName("x").GetFormat();
  std::vector<int64_t> dim_vector;
  if (b_dims >= lu_dims) {
    Shape output_shape = b_shape;
    TensorDesc td = op.GetOutputDescByName("y");
    td.SetShape(output_shape);
    td.SetDataType(input_dtype);
    td.SetFormat(input_format);
    (void)op.UpdateOutputDesc("y", td);
    return GRAPH_SUCCESS;
  } else {
    for (size_t i = 0; i <= lu_dims - b_dims - 1; i++) {
      dim_vector.push_back(lu_shape.GetDim(i));
    }
    for (size_t i = 0; i <= b_dims - 1; i++) {
      dim_vector.push_back(b_shape.GetDim(i));
    }
    Shape output_shape(dim_vector);
    TensorDesc td = op.GetOutputDescByName("y");
    td.SetShape(output_shape);
    td.SetDataType(input_dtype);
    td.SetFormat(input_format);
    (void)op.UpdateOutputDesc("y", td);
    return GRAPH_SUCCESS;
  }
}

CUST_IMPLEMT_VERIFIER(LuSolve, LuSolveVerify) {
  DataType input_type_x = op.GetInputDescByName("x").GetDataType();
  DataType input_type_y = op.GetInputDescByName("lu_data").GetDataType();
  if (input_type_x != input_type_y) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LuSolve, LuSolveInferShape);
CUST_VERIFY_FUNC_REG(LuSolve, LuSolveVerify);
// -----------------------LuSolve END---------------------------------
} // namespace ge

@ -0,0 +1,215 @@
/*
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ----------------ComplexAbs-------------------
IMPLEMT_INFERFUNC(ComplexAbs, ComplexAbsInfer) {
  TensorDesc x_desc = op.GetInputDescByName("x");
  DataType x_type = x_desc.GetDataType();
  DataType out_type;
  switch (x_type) {
    case DT_COMPLEX64:
      out_type = DT_FLOAT;
      break;
    case DT_COMPLEX128:
      out_type = DT_DOUBLE;
      break;
    default:
      OP_LOGE("ComplexAbs", "Invalid input dtype: %s", DTypeStr(x_type).c_str());
      return GRAPH_FAILED;
  }
  x_desc.SetDataType(out_type);
  return op.UpdateOutputDesc("y", x_desc);
}

INFER_FUNC_REG(ComplexAbs, ComplexAbsInfer);
// ----------------ComplexAbs End-------------------
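
// Editor's note: ComplexAbs keeps x's shape and maps dtype complex64 -> float
// and complex128 -> double, e.g. x complex64 (3, 2) -> y float (3, 2).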

// ----------------Complex-------------------
IMPLEMT_INFERFUNC(Complex, ComplexInfer) {
  bool is_dynamic_output = true;
  if (!InferShapeAndTypeTwoInOneOutBroadcast(op, 0, 1, 0, is_dynamic_output)) {
    return GRAPH_FAILED;
  }
  TensorDesc x_desc = op.GetInputDescByName("real");
  DataType x_type = x_desc.GetDataType();
  DataType out_type;
  switch (x_type) {
    case DT_FLOAT:
      out_type = DT_COMPLEX64;
      break;
    case DT_DOUBLE:
      out_type = DT_COMPLEX128;
      break;
    default:
      OP_LOGE("Complex", "Invalid input dtype: %s", DTypeStr(x_type).c_str());
      return GRAPH_FAILED;
  }
  TensorDesc out_desc = op.GetOutputDescByName("out");
  out_desc.SetDataType(out_type);
  return op.UpdateOutputDesc("out", out_desc);
}
INFER_FUNC_REG(Complex, ComplexInfer);
// ----------------Complex End-------------------

// ----------------IsNan-------------------
IMPLEMT_INFERFUNC(IsNan, IsNanInfer) {
  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetDataType(DT_BOOL);
  if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update output[y] failed."));
    return GRAPH_FAILED;
  }
  return UnchangedShape(op, "x", "y");
}

INFER_FUNC_REG(IsNan, IsNanInfer);
// ----------------IsNan End-------------------

// ----------------NextAfter-------------------
IMPLEMT_INFERFUNC(NextAfter, NextAfterInfer) {
  Shape x_shape = op.GetInputDescByName("x1").GetShape();
  Shape y_shape = op.GetInputDescByName("x2").GetShape();
  TensorDesc out_desc = op.GetOutputDescByName("output");
  DataType x_type = op.GetInputDescByName("x1").GetDataType();
  DataType y_type = op.GetInputDescByName("x2").GetDataType();
  if (x_type != y_type) {
    OP_LOGE(TbeGetName(op).c_str(), "the type of x1 is different from that of x2!");
    return GRAPH_FAILED;
  }

  out_desc.SetDataType(x_type);
  if ((!RankKnown(x_shape)) || (!RankKnown(y_shape))) {
    Shape out_shape(UNKNOWN_SHAPE);
    out_desc.SetShape(out_shape);
    if (op.UpdateOutputDesc("output", out_desc) != GRAPH_SUCCESS) {
      OP_LOGE(TbeGetName(op).c_str(), "update output failed");
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  }

  const size_t rank_x = x_shape.GetDimNum();
  const size_t rank_y = y_shape.GetDimNum();
  const size_t rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, zip together x_shape and y_shape
  // and pad with 1 to make them the same length.
  std::vector<int64_t> dims;
  int64_t dim_one = 1;
  if (rank_x != rank_y) {
    OP_LOGI(TbeGetName(op).c_str(), "x1 shape is not equal to x2 shape!");
    dim_one = 1;
  }
  for (size_t i = 0; i < rank_out; ++i) {
    int64_t dim_x;
    if (i < (rank_out - rank_x)) {
      dim_x = dim_one;
    } else {
      // rank_out = rank_x or i >= rank_y - rank_x.
      for (size_t j = 0; j < x_shape.GetDimNum(); ++j) {
        if (x_shape.GetDim(j) == UNKNOWN_DIM) {
          dim_x = UNKNOWN_DIM;
          break;
        }
      }
      // Compare in signed arithmetic: the operands are size_t, so the original
      // unsigned difference could never be negative.
      if (static_cast<int64_t>(i) - static_cast<int64_t>(rank_out - rank_x) < 0) {
        dim_x = x_shape.GetDim(rank_x + i - (rank_out - rank_x));
      } else {
        dim_x = x_shape.GetDim(i - (rank_out - rank_x));
      }
    }

    const bool dim_y_is_one = (i < (rank_out - rank_y));
    int64_t dim_y;
    if (dim_y_is_one) {
      dim_y = dim_one;
    } else {
      // rank_out = rank_y or i >= rank_x - rank_y.
      for (size_t j = 0; j < y_shape.GetDimNum(); ++j) {
        if (y_shape.GetDim(j) == UNKNOWN_DIM) {
          dim_y = UNKNOWN_DIM;
          break;
        }
      }
      if (static_cast<int64_t>(i) - static_cast<int64_t>(rank_out - rank_y) < 0) {
        dim_y = y_shape.GetDim(rank_y + i - (rank_out - rank_y));
      } else {
        dim_y = y_shape.GetDim(i - (rank_out - rank_y));
      }
    }

    if ((dim_x == UNKNOWN_DIM) || (dim_y == UNKNOWN_DIM)) {
      /* One or both dimensions is unknown.
       * If either dimension is greater than 1, assume that the program is
       * correct, and the other dimension will be broadcast to match it.
       * For shape inference, if the shape checks were eliminated from this
       * code, we would assert that the unknown dim is either 1 or the same
       * as the known dim.
       * If either dimension is 1, the other dimension is the output.
       */
      if (dim_x > 1) {
        dims.push_back(dim_x);
      } else if (dim_y > 1) {
        dims.push_back(dim_y);
      } else if (dim_x == 1) {
        dims.push_back(dim_y);
      } else if (dim_y == 1) {
        dims.push_back(dim_x);
      } else if (dim_x == dim_y) {
        dims.push_back(dim_x);
      } else {
        dims.push_back(UNKNOWN_DIM);
      }
    } else if ((dim_x == 1) || (dim_y == 1)) {
      // dim_x is dim_one or dim_y is dim_one.
      if ((dim_x == 1) && (!dim_y_is_one)) {
        // broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        if (dim_y == 1) {
          // broadcast dim_y to dim_x.
          dims.push_back(dim_x);
        }
      }
    } else {
      int64_t dim;
      if (Merge(dim_x, dim_y, dim) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
      dims.push_back(dim);
    }
  }
  Shape out_shape(dims);
  out_desc.SetShape(out_shape);
  if (op.UpdateOutputDesc("output", out_desc) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "update output failed");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(NextAfter, NextAfterInfer);
// ----------------NextAfter End-------------------
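
// Worked broadcast example (editor's sketch): x1 (2, 1, 3) vs x2 (4, 3): x2 is
// implicitly padded to (1, 4, 3), and merging each dim pair gives output (2, 4, 3).
// An UNKNOWN_DIM (-1) paired with a dim > 1 resolves to that dim; -1 vs -1 stays -1.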

// ----------------IsInf------------------------
IMPLEMT_INFERFUNC(IsInf, IsInfInfer) {
  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetDataType(DT_BOOL);
  if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("update output[y] failed."));
    return GRAPH_FAILED;
  }
  return UnchangedShape(op, "x", "y");
}

INFER_FUNC_REG(IsInf, IsInfInfer);
// ----------------IsInf END------------------------
} // namespace ge

@ -0,0 +1,162 @@
/*
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/matrix_calculation_ops.h"
#include "custom_op_proto/cust_math_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_COMMON_INFERFUNC(OneInOneOutCommonInferShape) {
  static const int64_t input_x_idx = 0;
  static const int64_t output_y_idx = 0;
  if (OneInOneOutDynamicInfer(op, input_x_idx, {output_y_idx})) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}

// ----------------DiagPart-------------------
IMPLEMT_COMMON_INFERFUNC(DiagPartInferShape) {
  ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
  CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("DiagPart", GetInputInvalidErrMsg("op_desc")),
        return GRAPH_FAILED);
  ge::ConstGeTensorDescPtr input_x_desc = op_desc->GetInputDescPtr(0);
  CHECK(input_x_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("DiagPart", GetInputInvalidErrMsg("x")),
        return GRAPH_FAILED);
  const GeShape &input_shape = input_x_desc->GetShape();
  const size_t input_to_output_dims_times = 2;
  size_t output_shape_len = input_shape.GetDimNum() / input_to_output_dims_times;
  ge::GeTensorDescPtr output_desc = op_desc->MutableOutputDesc(0);
  GeShape &output_shape = output_desc->MutableShape();
  DataType input_dtype = input_x_desc->GetDataType();

  if (input_shape.IsUnknownDimNum()) {
    output_desc->SetShape(input_shape);
  } else {
    output_shape.SetDimNum(output_shape_len);
    for (size_t i = 0; i < output_shape_len; i++) {
      output_shape.SetDim(i, input_shape.GetDim(i));
    }
  }
  if (input_shape.IsUnknownShape()) {
    std::vector<std::pair<int64_t, int64_t>> shape_range;
    input_x_desc->GetShapeRange(shape_range);
    // Only the leading half of the input dims survives in the output, so trim
    // the range to the output rank (the original per-element copy was a no-op).
    if (shape_range.size() > output_shape_len) {
      shape_range.resize(output_shape_len);
    }
    output_desc->SetShapeRange(shape_range);
  }
  output_desc->SetShape(output_shape);
  output_desc->SetDataType(input_dtype);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(DiagPart, DiagPartInferShape);
// ----------------DiagPart END-------------------
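
// Editor's note: DiagPart keeps the leading half of the input rank, so an
// input shaped (d1, ..., dk, d1, ..., dk) infers output (d1, ..., dk),
// e.g. (4, 4) -> (4) and (2, 3, 2, 3) -> (2, 3).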

// ---------------Eye----------------------------
static bool CheckRows(const Operator &op, const string &attr_num_rows) {
  int64_t num_rows = 0;
  op.GetAttr(attr_num_rows.c_str(), num_rows);
  if (num_rows <= 0) {
    return false;
  }
  return true;
}

static bool CheckBatchShape(const Operator &op, const string &attr_batch_shape) {
  const std::string opName = TbeGetName(op);
  std::vector<int64_t> batch_shape;
  op.GetAttr(attr_batch_shape.c_str(), batch_shape);
  for (size_t i = 0; i < batch_shape.size(); ++i) {
    if (batch_shape[i] <= 0) {
      OP_LOGE(opName, "each value in batch_shape must be greater than 0.");
      return false;
    }
  }
  return true;
}

IMPLEMT_COMMON_INFERFUNC(EyeInferShape) {
  TensorDesc td = op.GetOutputDescByName("y");
  int64_t num_rows = 0;
  int64_t num_columns = 0;
  std::vector<int64_t> batch_shape;
  op.GetAttr("num_rows", num_rows);
  op.GetAttr("num_columns", num_columns);
  op.GetAttr("batch_shape", batch_shape);

  if (!CheckRows(op, "num_rows") || !CheckBatchShape(op, "batch_shape")) {
    return GRAPH_FAILED;
  }
  if (num_columns <= 0) {
    num_columns = num_rows;
  }
  std::vector<int64_t> dim_vec;
  for (size_t i = 0; i < batch_shape.size(); ++i) {
    dim_vec.push_back(batch_shape[i]);
  }
  dim_vec.push_back(num_rows);
  dim_vec.push_back(num_columns);
  td.SetShape(ge::Shape(dim_vec));
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

IMPLEMT_VERIFIER(Eye, EyeVerify) { return GRAPH_SUCCESS; }

COMMON_INFER_FUNC_REG(Eye, EyeInferShape);

VERIFY_FUNC_REG(Eye, EyeVerify);
// --------------Eye END-------------------------------
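
// Editor's note: the inferred y shape is batch_shape + [num_rows, num_columns],
// where num_columns falls back to num_rows when the attr is unset or <= 0,
// e.g. num_rows = 3, batch_shape = {2} -> y (2, 3, 3).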

// ----------------FillDiagonal-------------------
IMPLEMT_COMMON_INFERFUNC(FillDiagonalInferShape) {
  Shape x_shape = op.GetInputDescByName("x").GetShape();
  DataType x_dtype = op.GetInputDescByName("x").GetDataType();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(ge::Shape(x_shape));
  td.SetDataType(x_dtype);
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(FillDiagonal, FillDiagonalInferShape);
// ----------------FillDiagonal END-------------------

// ----------------MatrixLogarithm--------------------
IMPLEMT_COMMON_INFERFUNC(MatrixLogarithmInferShaper) {
  auto x_shape = op.GetInputDescByName("x").GetShape().GetDims();
  Shape input_shape = op.GetInputDescByName("x").GetShape();
  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  int64_t size_num = op.GetInputDescByName("x").GetShape().GetDimNum();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(ge::Shape(input_shape));
  td.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("y", td) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  if (size_num < 2) {
    string err_msg = ConcatString("the rank of input[x] must be at least 2, but got ", size_num, ".");
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
// Register the infer function.
CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixLogarithm END-------------------

// ----------------MatrixExp-------------------
CUST_COMMON_INFER_FUNC_REG(MatrixExp, OneInOneOutCommonInferShape);
// ----------------MatrixExp END-------------------
} // namespace ge

@ -1,44 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/matrix_logarithm.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"
namespace ge {
// ----------------MatrixLogarithm--------------------
IMPLEMT_COMMON_INFERFUNC(MatrixLogarithmInferShaper) {
  auto x_shape = op.GetInputDescByName("x").GetShape().GetDims();
  Shape input_shape = op.GetInputDescByName("x").GetShape();
  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  int64_t size_num = op.GetInputDescByName("x").GetShape().GetDimNum();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(ge::Shape(input_shape));
  td.SetDataType(input_dtype);
  if (op.UpdateOutputDesc("y", td) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  if (size_num < 2) {
    string err_msg = ConcatString("the rank of input[x] must be at least 2, but got ", size_num, ".");
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
// Register the infer function.
CUST_COMMON_INFER_FUNC_REG(MatrixLogarithm, MatrixLogarithmInferShaper);
// ----------------MatrixLogarithm END-------------------
} // namespace ge

@ -1,138 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/max_pool_3d_grad_with_argmax_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
CUST_IMPLEMT_VERIFIER(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify) {
  const size_t DIM_SIZE1 = 1;
  const size_t DIM_SIZE3 = 3;
  const size_t DIM_SIZE5 = 5;

  std::vector<int32_t> ksizeList;
  if (GRAPH_SUCCESS != op.GetAttr("ksize", ksizeList)) {
    std::string err_msg = GetInputInvalidErrMsg("ksize");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((ksizeList.size() != DIM_SIZE1) && (ksizeList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("ksizeList", ConcatString(ksizeList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> stridesList;
  if (GRAPH_SUCCESS != op.GetAttr("strides", stridesList)) {
    std::string err_msg = GetInputInvalidErrMsg("strides");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((stridesList.size() != DIM_SIZE1) && (stridesList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("stridesList", ConcatString(stridesList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> padsList;
  if (GRAPH_SUCCESS != op.GetAttr("pads", padsList)) {
    std::string err_msg = GetInputInvalidErrMsg("pads");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((padsList.size() != DIM_SIZE1) && (padsList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("padsList", ConcatString(padsList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> dilationList;
  if (GRAPH_SUCCESS != op.GetAttr("dilation", dilationList)) {
    std::string err_msg = GetInputInvalidErrMsg("dilation");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((dilationList.size() != DIM_SIZE1) && (dilationList.size() != DIM_SIZE3) && (dilationList.size() != DIM_SIZE5)) {
    string expected_value = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3, " or ", DIM_SIZE5);
    std::string err_msg = GetAttrSizeErrMsg("dilationList", ConcatString(dilationList.size()), expected_value);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  bool ceilMode = false;
  if (GRAPH_SUCCESS != op.GetAttr("ceil_mode", ceilMode)) {
    std::string err_msg = GetInputInvalidErrMsg("ceil_mode");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::string data_format;
  if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "get attr data_format failed.");
    return GRAPH_FAILED;
  }
  if (data_format != "NCDHW") {
    OP_LOGE(TbeGetName(op).c_str(), "Attr data_format(%s) only support NCDHW.", data_format.c_str());
    return GRAPH_FAILED;
  }

  int dtype = 0;
  if (GRAPH_SUCCESS != op.GetAttr("dtype", dtype)) {
    std::string err_msg = GetInputInvalidErrMsg("dtype");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK_PTR_NULL(op_desc, "op desc", return GRAPH_FAILED);
  auto grads_desc = op_desc->MutableInputDesc("grads");
  CHECK_PTR_NULL(grads_desc, "grads desc", return GRAPH_FAILED);
  vector<int64_t> grads_shape = grads_desc->MutableShape().GetDims();
  if (grads_shape.size() != DIM_SIZE5 && !IsUnknownRankShape(grads_shape)) {
    OP_LOGE(TbeGetName(op).c_str(), "grads_shape's dim expect: %lu, but real: %lu.", DIM_SIZE5, grads_shape.size());
    return GRAPH_FAILED;
  }

  TensorDesc inputDesc = op.GetInputDescByName("x");
  vector<int64_t> inputShape = inputDesc.GetShape().GetDims();
  if (inputShape.size() != DIM_SIZE5) {
    OP_LOGE(TbeGetName(op).c_str(), "input x's dim expect: %lu, but real: %lu.", DIM_SIZE5, inputShape.size());
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_INFERFUNC(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape) {
  auto shape = op.GetInputDescByName("x").GetShape();
  auto shape_dims = shape.GetDims();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(shape);
  td.SetDataType(op.GetInputDescByName("x").GetDataType());
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
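
// Editor's summary of the verifier above: ksize/strides/pads must have 1 or 3
// elements, dilation 1, 3 or 5; data_format must be "NCDHW"; x must be 5-D and
// grads 5-D (or unknown-rank). The inferred y then mirrors x's shape and dtype.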
} // namespace ge

@ -1,112 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/multi_margin_loss_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// ----------------------MultiMarginLossGrad------------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLossGrad, MultiMarginLossGradVerify) { return GRAPH_SUCCESS; }

CUST_VERIFY_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradVerify);

IMPLEMT_COMMON_INFERFUNC(MultiMarginLossGradInferShape) {
  Shape shape_x = op.GetInputDescByName("x").GetShape();
  Shape shape_target = op.GetInputDescByName("target").GetShape();
  TensorDesc tensordesc_weight;
  DataType x_dtype = op.GetInputDescByName("x").GetDataType();
  DataType y_grad_dtype = op.GetInputDescByName("y_grad").GetDataType();
  DataType target_dtype = op.GetInputDescByName("target").GetDataType();
  if (y_grad_dtype != x_dtype) {
    string err_msg1 = ConcatString("dtype of input x must be the same as y_grad.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
    string err_msg1 = ConcatString("dtype of input x must be double, float or float16");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (target_dtype != DT_INT64) {
    string err_msg1 = ConcatString("dtype of input target must be int64.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
    Shape shape_w = op.GetInputDescByName("weight").GetShape();
    DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
    if (weight_dtype != x_dtype) {
      string err_msg1 = ConcatString("weight must have the same dtype as x.");
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
    if (shape_w.GetDimNum() != 1) {
      string err_msg1 = ConcatString("rank of weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
  }
  if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
    string err_msg2 =
      ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
                   ", shape_target.GetDimNum():", shape_target.GetDimNum());
    std::string err_msg = OtherErrMsg(err_msg2);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
    string err_msg3 = ConcatString(
      "shape[0] of x and shape[0] of target must be "
      "the same, shape_x.GetDim(0):",
      shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
    std::string err_msg = OtherErrMsg(err_msg3);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  std::string reduction;
  op.GetAttr("reduction", reduction);
  if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
    string expected_reduction_list = ConcatString("mean, sum, none");
    std::string err_msg = GetInputFormatNotSupportErrMsg("reduction", expected_reduction_list, reduction);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  // Default to 1 if the attr is not set, so the value is never read uninitialized.
  int64_t p = 1;
  op.GetAttr("p", p);
  if ((p != 1) && (p != 2)) {
    string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
    std::string err_msg = OtherErrMsg(err_msg4);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  TensorDesc tensordesc_output = op.GetOutputDescByName("x_grad");
  Shape x_grad_shape = Shape(shape_x);
  tensordesc_output.SetShape(x_grad_shape);
  TensorDesc input_desc = op.GetInputDescByName("x");
  tensordesc_output.SetDataType(input_desc.GetDataType());
  op.UpdateOutputDesc("x_grad", tensordesc_output);
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradInferShape);
// ----------------------MultiMarginLossGrad END------------------------
} // namespace ge

@ -1,117 +0,0 @@
/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/multi_margin_loss_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ----------------MultiMarginLoss Begin-------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLoss, MultiMarginLossVerify) {
  Shape shape_x = op.GetInputDescByName("x").GetShape();
  Shape shape_target = op.GetInputDescByName("target").GetShape();
  TensorDesc tensordesc_weight;
  DataType x_dtype = op.GetInputDescByName("x").GetDataType();
  DataType target_dtype = op.GetInputDescByName("target").GetDataType();
  if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
    string err_msg1 = ConcatString("dtype of input x must be double, float or float16.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (target_dtype != DT_INT64) {
    string err_msg1 = ConcatString("dtype of input target must be int64.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
    Shape shape_w = op.GetInputDescByName("weight").GetShape();
    DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
    if (weight_dtype != x_dtype) {
      string err_msg1 = ConcatString("weight must have the same dtype as x.");
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
    if (shape_w.GetDimNum() != 1) {
      string err_msg1 = ConcatString("rank of input weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
  }
  if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
    string err_msg2 =
      ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
                   ", shape_target.GetDimNum():", shape_target.GetDimNum());
    std::string err_msg = OtherErrMsg(err_msg2);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
    string err_msg3 = ConcatString(
      "shape[0] of x and shape[0] of target must be "
      "the same, shape_x.GetDim(0):",
      shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
    std::string err_msg = OtherErrMsg(err_msg3);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  std::string reduction;
  op.GetAttr("reduction", reduction);
  if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
    OP_LOGE(TbeGetName(op).c_str(), "The value of reduction is invalid.");
    return GRAPH_FAILED;
  }
  // Default to 1 if the attr is not set, so the value is never read uninitialized.
  int64_t p = 1;
  op.GetAttr("p", p);
  if ((p != 1) && (p != 2)) {
    string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
    std::string err_msg = OtherErrMsg(err_msg4);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(MultiMarginLossInferShape) {
  auto shape_x = op.GetInputDescByName("x").GetShape().GetDims();
  auto shape_target = op.GetInputDescByName("target").GetShape().GetDims();
  TensorDesc tensordesc_output = op.GetOutputDescByName("y");
  Shape y_shape = Shape(shape_target);
  std::string reduction;
  op.GetAttr("reduction", reduction);
  if ((reduction == "mean") || (reduction == "sum")) {
    Shape scalar_shape;
    Scalar(scalar_shape);
    tensordesc_output.SetShape(scalar_shape);
  }
  if (reduction == "none") {
    tensordesc_output.SetShape(y_shape);
  }
  TensorDesc input_desc = op.GetInputDescByName("x");
  tensordesc_output.SetDataType(input_desc.GetDataType());
  tensordesc_output.SetFormat(FORMAT_ND);
  op.UpdateOutputDesc("y", tensordesc_output);
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(MultiMarginLoss, MultiMarginLossInferShape);
CUST_VERIFY_FUNC_REG(MultiMarginLoss, MultiMarginLossVerify);
// ----------------MultiMarginLoss END---------------------
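
// Editor's note: x is (N, C) and target is (N); with reduction == "none" the
// inferred y keeps target's shape (N), while "mean" and "sum" produce a scalar.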
} // namespace ge

@ -1,45 +0,0 @@

/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/mvlgamma_grad_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// ----------------MvlgammaGrad Begin-------------------
CUST_IMPLEMT_INFERFUNC(MvlgammaGrad, MvlgammaGradInferShape) {
  const char *op_name = "MvlgammaGrad";
  OP_LOGD(op_name, "MvlgammaGradInferShape begin.");
  TensorDesc tensordesc_input = op.GetInputDescByName("y_grad");
  Shape input_shape = tensordesc_input.GetShape();
  std::vector<int64_t> dims_input = input_shape.GetDims();
  DataType input_dtype = tensordesc_input.GetDataType();

  TensorDesc tensordesc_output1 = op.GetOutputDescByName("x_grad");
  tensordesc_output1.SetDataType(input_dtype);
  tensordesc_output1.SetShape(ge::Shape(dims_input));

  (void)op.UpdateOutputDesc("x_grad", tensordesc_output1);
  OP_LOGD(op_name, "MvlgammaGradInferShape end.");
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(MvlgammaGrad, MvlgammaGradVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(MvlgammaGrad, MvlgammaGradInferShape);
CUST_VERIFY_FUNC_REG(MvlgammaGrad, MvlgammaGradVerify);
// ----------------MvlgammaGrad END---------------------
} // namespace ge

@ -1,45 +0,0 @@

/**
 * Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/mvlgamma_op.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"

namespace ge {
// ----------------Mvlgamma Begin-------------------
CUST_IMPLEMT_INFERFUNC(Mvlgamma, MvlgammaInferShape) {
  const char *op_name = "Mvlgamma";
  OP_LOGD(op_name, "MvlgammaInferShape begin.");
  TensorDesc tensordesc_input = op.GetInputDescByName("x");
  Shape input_shape = tensordesc_input.GetShape();
  std::vector<int64_t> dims_input = input_shape.GetDims();
  DataType input_dtype = tensordesc_input.GetDataType();

  TensorDesc tensordesc_output1 = op.GetOutputDescByName("y");
  tensordesc_output1.SetDataType(input_dtype);
  tensordesc_output1.SetShape(ge::Shape(dims_input));

  (void)op.UpdateOutputDesc("y", tensordesc_output1);
  OP_LOGD(op_name, "MvlgammaInferShape end.");
  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_VERIFIER(Mvlgamma, MvlgammaVerify) { return GRAPH_SUCCESS; }

CUST_INFER_FUNC_REG(Mvlgamma, MvlgammaInferShape);
CUST_VERIFY_FUNC_REG(Mvlgamma, MvlgammaVerify);
// ----------------Mvlgamma END---------------------
} // namespace ge

@ -0,0 +1,235 @@
/*
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/nn_norm_ops.h"
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ----------------KlDivLossGrad Begin-------------------
bool InferShapeAndTypeKlDivLossGrad(Operator &op, const string &input_name, const string &output_name) {
  TensorDesc output_desc = op.GetOutputDescByName(output_name.c_str());
  DataType input_dtype = op.GetInputDescByName(input_name.c_str()).GetDataType();
  Format input_format =
    static_cast<ge::Format>(ge::GetPrimaryFormat(op.GetInputDescByName(input_name.c_str()).GetFormat()));
  ge::Shape input_shape = op.GetInputDescByName(input_name.c_str()).GetShape();

  output_desc.SetShape(input_shape);
  output_desc.SetDataType(input_dtype);
  output_desc.SetFormat(input_format);
  op.UpdateOutputDesc(output_name.c_str(), output_desc);
  return true;
}

IMPLEMT_COMMON_INFERFUNC(KlDivLossGradInferShape) {
  if (InferShapeAndTypeKlDivLossGrad(op, "input", "y")) {
    return GRAPH_SUCCESS;
  }
  OP_LOGE(TbeGetName(op).c_str(), "KlDivLossGrad infer shape failed.");
  return GRAPH_FAILED;
}

IMPLEMT_VERIFIER(KlDivLossGrad, KlDivLossGradVerify) {
  if (op.GetInputDescByName("grad").GetDataType() != op.GetInputDescByName("input").GetDataType() ||
      op.GetInputDescByName("input").GetDataType() != op.GetInputDescByName("target").GetDataType()) {
    OP_LOGE(TbeGetName(op).c_str(), "the dtype of grad must be the same as that of input and target.");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(KlDivLossGrad, KlDivLossGradInferShape);
VERIFY_FUNC_REG(KlDivLossGrad, KlDivLossGradVerify);
// ----------------KlDivLossGrad END---------------------
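
// Editor's note: y simply mirrors input's shape, dtype and primary format, and
// the verifier requires grad, input and target to share a single dtype.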
|
||||
|
||||
// ----------------MultiMarginLoss Begin-------------------
|
||||
CUST_IMPLEMT_VERIFIER(MultiMarginLoss, MultiMarginLossVerify) {
|
||||
Shape shape_x = op.GetInputDescByName("x").GetShape();
|
||||
Shape shape_target = op.GetInputDescByName("target").GetShape();
|
||||
TensorDesc tensordesc_weight;
|
||||
DataType x_dtype = op.GetInputDescByName("x").GetDataType();
|
||||
DataType target_dtype = op.GetInputDescByName("target").GetDataType();
|
||||
if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
|
||||
string err_msg1 = ConcatString("dtype of input x must be double, float or float16.");
|
||||
std::string err_msg = OtherErrMsg(err_msg1);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (target_dtype != DT_INT64) {
|
||||
string err_msg1 = ConcatString("dtype of input target must be int64.");
|
||||
std::string err_msg = OtherErrMsg(err_msg1);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
|
||||
Shape shape_w = op.GetInputDescByName("weight").GetShape();
|
||||
DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
|
||||
if (weight_dtype != x_dtype) {
|
||||
string err_msg1 = ConcatString("weight should have the same dtype with x.");
|
||||
std::string err_msg = OtherErrMsg(err_msg1);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (shape_w.GetDimNum() != 1) {
|
||||
string err_msg1 = ConcatString("rank of input weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
|
||||
std::string err_msg = OtherErrMsg(err_msg1);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
}
|
||||
if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
|
||||
string err_msg2 =
|
||||
ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
|
||||
", shape_target.GetDimNum():", shape_target.GetDimNum());
|
||||
std::string err_msg = OtherErrMsg(err_msg2);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
|
||||
string err_msg3 = ConcatString(
|
||||
"shape[0] of x and shape[0] of target must be "
|
||||
"the same, shape_x.GetDim(0):",
|
||||
shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
|
||||
std::string err_msg = OtherErrMsg(err_msg3);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
std::string reduction;
|
||||
op.GetAttr("reduction", reduction);
|
||||
if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
|
||||
OP_LOGE(TbeGetName(op).c_str(), "The val of reduction is invalid.");
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
int64_t p;
|
||||
op.GetAttr("p", p);
|
||||
if ((p != 1) && (p != 2)) {
|
||||
string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
|
||||
std::string err_msg = OtherErrMsg(err_msg4);
|
||||
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
|
||||
return GRAPH_FAILED;
|
||||
}
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
IMPLEMT_COMMON_INFERFUNC(MultiMarginLossInferShape) {
|
||||
auto shape_x = op.GetInputDescByName("x").GetShape().GetDims();
|
||||
auto shape_target = op.GetInputDescByName("target").GetShape().GetDims();
|
||||
TensorDesc tensordesc_output = op.GetOutputDescByName("y");
|
||||
Shape y_shape = Shape(shape_target);
|
||||
std::string reduction;
|
||||
op.GetAttr("reduction", reduction);
|
||||
if ((reduction == "mean") || (reduction == "sum")) {
|
||||
Shape scalar_shape;
|
||||
Scalar(scalar_shape);
|
||||
tensordesc_output.SetShape(scalar_shape);
|
||||
}
|
||||
if (reduction == "none") {
|
||||
tensordesc_output.SetShape(y_shape);
|
||||
}
|
||||
TensorDesc input_desc = op.GetInputDescByName("x");
|
||||
tensordesc_output.SetDataType(input_desc.GetDataType());
|
||||
tensordesc_output.SetFormat(FORMAT_ND);
|
||||
op.UpdateOutputDesc("y", tensordesc_output);
|
||||
return GRAPH_SUCCESS;
|
||||
}
|
||||
|
||||
CUST_COMMON_INFER_FUNC_REG(MultiMarginLoss, MultiMarginLossInferShape);
|
||||
CUST_VERIFY_FUNC_REG(MultiMarginLoss, MultiMarginLossVerify);
|
||||
// ----------------MultiMarginLoss END---------------------

// ----------------------MultiMarginLossGrad------------------------
CUST_IMPLEMT_VERIFIER(MultiMarginLossGrad, MultiMarginLossGradVerify) { return GRAPH_SUCCESS; }

CUST_VERIFY_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradVerify);

IMPLEMT_COMMON_INFERFUNC(MultiMarginLossGradInferShape) {
  Shape shape_x = op.GetInputDescByName("x").GetShape();
  Shape shape_target = op.GetInputDescByName("target").GetShape();
  TensorDesc tensordesc_weight;
  DataType x_dtype = op.GetInputDescByName("x").GetDataType();
  DataType y_grad_dtype = op.GetInputDescByName("y_grad").GetDataType();
  DataType target_dtype = op.GetInputDescByName("target").GetDataType();
  if (y_grad_dtype != x_dtype) {
    string err_msg1 = ConcatString("dtype of input x must be the same as y_grad.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (x_dtype != DT_DOUBLE && x_dtype != DT_FLOAT && x_dtype != DT_FLOAT16) {
    string err_msg1 = ConcatString("dtype of input x must be double, float or float16");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (target_dtype != DT_INT64) {
    string err_msg1 = ConcatString("dtype of input target must be int64.");
    std::string err_msg = OtherErrMsg(err_msg1);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (op.TryGetInputDesc("weight", tensordesc_weight) == GRAPH_SUCCESS) {
    Shape shape_w = op.GetInputDescByName("weight").GetShape();
    DataType weight_dtype = op.GetInputDescByName("weight").GetDataType();
    if (weight_dtype != x_dtype) {
      string err_msg1 = ConcatString("weight should have the same dtype with x.");
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
    if (shape_w.GetDimNum() != 1) {
      string err_msg1 = ConcatString("rank of weight must be 1, shape_weight.GetDimNum():", shape_w.GetDimNum());
      std::string err_msg = OtherErrMsg(err_msg1);
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
      return GRAPH_FAILED;
    }
  }
  if ((shape_x.GetDimNum() != 2) || (shape_target.GetDimNum() != 1)) {
    string err_msg2 =
      ConcatString("Rank of x must be 2, rank of target must be 1, shape_x.GetDimNum():", shape_x.GetDimNum(),
                   ", shape_target.GetDimNum():", shape_target.GetDimNum());
    std::string err_msg = OtherErrMsg(err_msg2);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  if (shape_x.GetDim(0) != (shape_target.GetDim(0))) {
    string err_msg3 = ConcatString(
      "shape[0] of x and shape[0] of target must be "
      "the same, shape_x.GetDim(0):",
      shape_x.GetDim(0), ", shape_target.GetDim(0):", shape_target.GetDim(0));
    std::string err_msg = OtherErrMsg(err_msg3);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  std::string reduction;
  op.GetAttr("reduction", reduction);
  if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) {
    string expected_reduction_list = ConcatString("mean, sum, none");
    std::string err_msg = GetInputFormatNotSupportErrMsg("reduction", expected_reduction_list, reduction);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  int64_t p;
  op.GetAttr("p", p);
  if ((p != 1) && (p != 2)) {
    string err_msg4 = ConcatString("The value of p must be 1 or 2, p:", p);
    std::string err_msg = OtherErrMsg(err_msg4);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }
  TensorDesc tensordesc_output = op.GetOutputDescByName("x_grad");
  Shape x_grad_shape = Shape(shape_x);
  tensordesc_output.SetShape(x_grad_shape);
  TensorDesc input_desc = op.GetInputDescByName("x");
  tensordesc_output.SetDataType(input_desc.GetDataType());
  op.UpdateOutputDesc("x_grad", tensordesc_output);
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(MultiMarginLossGrad, MultiMarginLossGradInferShape);
// ----------------------MultiMarginLossGrad END------------------------
}  // namespace ge

@ -0,0 +1,287 @@
/*
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/nn_pooling_ops.h"
#include "custom_op_proto/cust_nn_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// -------------------DataFormatVecPermute---------------------
IMPLEMT_INFERFUNC(DataFormatVecPermute, DataFormatVecPermuteInfer) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x_desc = op_desc->MutableInputDesc(0);

  std::vector<std::pair<int64_t, int64_t>> range;
  if (x_desc->GetShapeRange(range) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  DataType y_type = x_desc->GetDataType();

  auto y_desc = op_desc->MutableOutputDesc(0);
  y_desc->SetShape(x_desc->GetShape());
  y_desc->SetShapeRange(range);
  y_desc->SetDataType(y_type);

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(DataFormatVecPermute, DataFormatVecPermuteInfer);
// -------------------DataFormatVecPermute End---------------------

// -------------------MaxPool3DWithArgmax---------------------
IMPLEMT_INFERFUNC(MaxPool3DWithArgmax, MaxPool3DWithArgmaxInferShape) {
  TensorDesc inputDesc = op.GetInputDescByName("x");
  auto inputShape = inputDesc.GetShape().GetDims();
  DataType inputDtype = inputDesc.GetDataType();
  TensorDesc argmaxDesc = op.GetOutputDescByName("argmax");
  TensorDesc outputDesc = op.GetOutputDescByName("y");
  std::vector<int64_t> stridesList;
  op.GetAttr("strides", stridesList);
  std::vector<int64_t> kernelList;
  op.GetAttr("ksize", kernelList);
  int64_t dOut = (inputShape[1] - kernelList[2]) / stridesList[2] + 1;
  int64_t hOut = (inputShape[3] - kernelList[3]) / stridesList[3] + 1;
  int64_t wOut = (inputShape[4] - kernelList[4]) / stridesList[4] + 1;
  int64_t alignedBmLine;
  alignedBmLine = (wOut * hOut % 16 == 0) ? (wOut * hOut) : (((int64_t)(wOut * hOut / 16) + 1) * 16);
  std::vector<int64_t> argShapeVec;
  argShapeVec.push_back(inputShape[0]);
  argShapeVec.push_back(dOut);
  argShapeVec.push_back(inputShape[2] * kernelList[2] * kernelList[3] * kernelList[4]);
  argShapeVec.push_back((int64_t)(alignedBmLine / 16));
  argShapeVec.push_back(inputShape[5]);
  Shape argmaxShape(argShapeVec);
  argmaxDesc.SetShape(argmaxShape);
  argmaxDesc.SetDataType(DT_UINT16);
  (void)op.UpdateOutputDesc("argmax", argmaxDesc);
  std::vector<int64_t> outShapeVec{inputShape[0], dOut, inputShape[2], hOut, wOut, inputShape[5]};
  Shape outputShape(outShapeVec);
  outputDesc.SetShape(outputShape);
  outputDesc.SetDataType(inputDtype);
  (void)op.UpdateOutputDesc("y", outputDesc);
  return GRAPH_SUCCESS;
}

IMPLEMT_VERIFIER(MaxPool3DWithArgmax, MaxPool3DWithArgmaxVerify) {
  // verify in infer func
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(MaxPool3DWithArgmax, MaxPool3DWithArgmaxInferShape);
VERIFY_FUNC_REG(MaxPool3DWithArgmax, MaxPool3DWithArgmaxVerify);
//-------------------MaxPool3DWithArgmax END---------------------
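The output-extent arithmetic above is the usual no-padding pooling formula, out = (in - k) / s + 1, with the argmax bitmask line padded up to a multiple of 16. A standalone check with made-up values (the input appears to be laid out as [N, D, C1, H, W, C0]):

#include <cstdint>
#include <iostream>

int main() {
  // Pool along D/H/W with no padding; extents below are made up.
  int64_t D = 16, H = 33, W = 33;    // spatial extents
  int64_t kD = 2, kH = 3, kW = 3;    // kernel
  int64_t sD = 2, sH = 2, sW = 2;    // strides
  int64_t dOut = (D - kD) / sD + 1;  // 8
  int64_t hOut = (H - kH) / sH + 1;  // 16
  int64_t wOut = (W - kW) / sW + 1;  // 16
  // The argmax bitmask line is padded to a multiple of 16 elements.
  int64_t line = wOut * hOut;        // 256, already aligned
  int64_t alignedBmLine = (line % 16 == 0) ? line : (line / 16 + 1) * 16;
  std::cout << dOut << " " << hOut << " " << wOut << " " << alignedBmLine << "\n";
  return 0;
}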

//-------------------FractionalMaxPool---------------------
IMPLEMT_INFERFUNC(FractionalMaxPool, FractionalMaxPoolInfer) {
  auto tensor = op.get_input_desc_x();
  Shape input_value;

  if (WithRank(tensor, 4, input_value, op) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(
      TbeGetName(op),
      ConcatString("call WithRank failed, ", GetShapeErrMsg(0, DebugString(tensor.GetShape().GetDims()), "4D")));
    return GRAPH_FAILED;
  }
  std::vector<float> pooling_ratio;
  pooling_ratio = op.get_attr_pooling_ratio();
  if (pooling_ratio.size() != 4) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(
      TbeGetName(op), GetAttrSizeErrMsg("pooling_ratio", DebugString(tensor.GetShape().GetDims()), "4D"));
    return GRAPH_FAILED;
  }

  std::vector<int64_t> output_dims;
  for (int i = 0; i < 4; ++i) {
    int64_t dim = input_value.GetDim(i);
    if (dim != UNKNOWN_DIM) {
      auto real_dim = static_cast<int64_t>(dim / pooling_ratio[i]);
      if (real_dim < 0) {
        string err_msg = ConcatString("size computed for ", i, "th dim of output[y] is ", real_dim);
        AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
        return GRAPH_FAILED;
      }
      output_dims.push_back(real_dim);
    } else {
      output_dims.push_back(UNKNOWN_DIM);
    }
  }

  TensorDesc y_desc = op.GetOutputDescByName("y");
  y_desc.SetShape(Shape(output_dims));
  y_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
  if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed."));
    return GRAPH_FAILED;
  }

  TensorDesc row_pooling_desc = op.GetOutputDescByName("row_pooling_sequence");
  row_pooling_desc.SetShape(Shape({output_dims[1] + 1}));
  row_pooling_desc.SetDataType(DT_INT64);
  if (op.UpdateOutputDesc("row_pooling_sequence", row_pooling_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[row_pooling_sequence] desc failed."));
    return GRAPH_FAILED;
  }

  TensorDesc col_pooling_desc = op.GetOutputDescByName("col_pooling_sequence");
  col_pooling_desc.SetShape(Shape({output_dims[2] + 1}));
  col_pooling_desc.SetDataType(DT_INT64);
  if (op.UpdateOutputDesc("col_pooling_sequence", col_pooling_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[col_pooling_sequence] desc failed."));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(FractionalMaxPool, FractionalMaxPoolInfer);
//-------------------FractionalMaxPool END---------------------
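Each known output dim above is the input dim divided by the (float) pooling ratio and truncated toward zero, and the two pooling sequences each get one extra boundary element. A standalone sketch with made-up values:

#include <cassert>
#include <cstdint>

int main() {
  // Mirrors the loop above: each output dim is trunc(input_dim / pooling_ratio).
  int64_t in_dims[4] = {1, 20, 30, 3};  // NHWC, made-up values
  float pooling_ratio[4] = {1.0f, 1.44f, 1.73f, 1.0f};
  int64_t out_dims[4];
  for (int i = 0; i < 4; ++i) {
    out_dims[i] = static_cast<int64_t>(in_dims[i] / pooling_ratio[i]);
  }
  assert(out_dims[1] == 13 && out_dims[2] == 17);
  // row/col_pooling_sequence hold one boundary per output row/col plus the
  // trailing edge, hence the "+ 1" in the shapes above.
  assert(out_dims[1] + 1 == 14 && out_dims[2] + 1 == 18);
  return 0;
}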

//-------------------FractionalMaxPoolGrad---------------------
IMPLEMT_INFERFUNC(FractionalMaxPoolGrad, FractionalMaxPoolGradInfer) {
  Shape input_shape;
  if (WithRank(op.GetInputDesc(0), 4, input_shape, op) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(
      TbeGetName(op), ConcatString("call WithRank failed, ",
                                   GetShapeErrMsg(0, DebugString(op.GetInputDesc(0).GetShape().GetDims()), "4D")));
    return GRAPH_FAILED;
  }

  auto type = op.GetInputDescByName("orig_input").GetDataType();
  TensorDesc output_desc = op.GetOutputDescByName("y");
  output_desc.SetShape(Shape(input_shape));
  output_desc.SetDataType(type);
  if (op.UpdateOutputDesc("y", output_desc) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), std::string("update output[y] desc failed"));
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(FractionalMaxPoolGrad, FractionalMaxPoolGradInfer);
//-------------------FractionalMaxPoolGrad END---------------------

//-------------------MaxPool3DGradWithArgmax---------------------
CUST_IMPLEMT_VERIFIER(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify) {
  const size_t DIM_SIZE1 = 1;
  const size_t DIM_SIZE3 = 3;
  const size_t DIM_SIZE5 = 5;

  std::vector<int32_t> ksizeList;
  if (GRAPH_SUCCESS != op.GetAttr("ksize", ksizeList)) {
    std::string err_msg = GetInputInvalidErrMsg("ksize");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((ksizeList.size() != DIM_SIZE1) && (ksizeList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("ksizeList", ConcatString(ksizeList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> stridesList;
  if (GRAPH_SUCCESS != op.GetAttr("strides", stridesList)) {
    std::string err_msg = GetInputInvalidErrMsg("strides");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((stridesList.size() != DIM_SIZE1) && (stridesList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("stridesList", ConcatString(stridesList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> padsList;
  if (GRAPH_SUCCESS != op.GetAttr("pads", padsList)) {
    std::string err_msg = GetInputInvalidErrMsg("pads");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((padsList.size() != DIM_SIZE1) && (padsList.size() != DIM_SIZE3)) {
    string expected_size = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3);
    std::string err_msg = GetAttrSizeErrMsg("padsList", ConcatString(padsList.size()), expected_size);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::vector<int32_t> dilationList;
  if (GRAPH_SUCCESS != op.GetAttr("dilation", dilationList)) {
    std::string err_msg = GetInputInvalidErrMsg("dilation");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  if ((dilationList.size() != DIM_SIZE1) && (dilationList.size() != DIM_SIZE3) && (dilationList.size() != DIM_SIZE5)) {
    string expected_value = ConcatString(DIM_SIZE1, " or ", DIM_SIZE3, " or ", DIM_SIZE5);
    std::string err_msg = GetAttrSizeErrMsg("dilationList", ConcatString(dilationList.size()), expected_value);
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  bool ceilMode = false;
  if (GRAPH_SUCCESS != op.GetAttr("ceil_mode", ceilMode)) {
    std::string err_msg = GetInputInvalidErrMsg("ceil_mode");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  std::string data_format;
  if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "get attr data_format failed.");
    return GRAPH_FAILED;
  }
  if (data_format != "NCDHW") {
    OP_LOGE(TbeGetName(op).c_str(), "Attr data_format(%s) only support NCDHW.", data_format.c_str());
    return GRAPH_FAILED;
  }

  int dtype = 0;
  if (GRAPH_SUCCESS != op.GetAttr("dtype", dtype)) {
    std::string err_msg = GetInputInvalidErrMsg("dtype");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op).c_str(), err_msg);
    return GRAPH_FAILED;
  }

  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK_PTR_NULL(op_desc, "op desc", return GRAPH_FAILED);
  auto grads_desc = op_desc->MutableInputDesc("grads");
  CHECK_PTR_NULL(grads_desc, "grads desc", return GRAPH_FAILED);
  vector<int64_t> grads_shape = grads_desc->MutableShape().GetDims();
  if (grads_shape.size() != DIM_SIZE5 && !IsUnknownRankShape(grads_shape)) {
    OP_LOGE(TbeGetName(op).c_str(), "grads_shape's dim expect: %lu, but real: %lu.", DIM_SIZE5, grads_shape.size());
    return GRAPH_FAILED;
  }

  TensorDesc inputDesc = op.GetInputDescByName("x");
  vector<int64_t> inputShape = inputDesc.GetShape().GetDims();
  if (inputShape.size() != DIM_SIZE5) {
    OP_LOGE(TbeGetName(op).c_str(), "input x's dim expect: %lu, but real: %lu.", DIM_SIZE5, inputShape.size());
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

CUST_IMPLEMT_INFERFUNC(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape) {
  auto shape = op.GetInputDescByName("x").GetShape();
  TensorDesc td = op.GetOutputDescByName("y");
  td.SetShape(shape);
  td.SetDataType(op.GetInputDescByName("x").GetDataType());
  (void)op.UpdateOutputDesc("y", td);
  return GRAPH_SUCCESS;
}
CUST_INFER_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxInferShape);
CUST_VERIFY_FUNC_REG(MaxPool3DGradWithArgmax, MaxPool3DGradWithArgmaxVerify);
//-------------------MaxPool3DGradWithArgmax END---------------------

}  // namespace ge

@ -0,0 +1,94 @@
/**
 * Copyright (c) 2023 Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/random_ops.h"
#include "inc/ops/stateful_random_ops.h"
#include "custom_op_proto/cust_random_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
IMPLEMT_INFERFUNC(NonDeterministicInts, NonDeterministicIntsInfer) {
  Shape shape;
  Tensor shape_tensor;
  if (op.GetInputConstData("shape", shape_tensor) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Get shape_tensor error.");
    return GRAPH_FAILED;
  }
  if (MakeShapeFromShapeTensor(shape_tensor, shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Get shape error.");
    return GRAPH_FAILED;
  }

  DataType dtype;
  if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Get attr dtype error.");
    return GRAPH_FAILED;
  }
  TensorDesc outputDesc = op.GetOutputDescByName("y");
  outputDesc.SetDataType(dtype);
  outputDesc.SetShape(shape);
  return op.UpdateOutputDesc("y", outputDesc);
}

INFER_FUNC_REG(NonDeterministicInts, NonDeterministicIntsInfer);

// ----------------LogNormalReverse-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(LogNormalReverseInferShape) {
  TensorDesc v_output_desc = op.GetOutputDescByName("y");

  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  Format input_format = op.GetInputDescByName("x").GetFormat();
  ge::Shape shape_input = op.GetInputDescByName("x").GetShape();

  v_output_desc.SetShape(shape_input);
  v_output_desc.SetDataType(input_dtype);
  v_output_desc.SetFormat(input_format);

  if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(LogNormalReverse, LogNormalReverseInferShape);
// ----------------LogNormalReverse END-------------------

// ----------------Dropout2D-------------------
IMPLEMT_COMMON_INFERFUNC(Dropout2DInferShape) {
  TensorDesc output_desc = op.GetOutputDescByName("output");
  TensorDesc mask_desc = op.GetOutputDescByName("mask");

  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  ge::Shape shape_input = op.GetInputDescByName("x").GetShape();

  output_desc.SetShape(shape_input);
  output_desc.SetDataType(input_dtype);
  mask_desc.SetShape(shape_input);
  mask_desc.SetDataType(DT_BOOL);

  if (op.UpdateOutputDesc("output", output_desc) != GRAPH_SUCCESS ||
      op.UpdateOutputDesc("mask", mask_desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

CUST_COMMON_INFER_FUNC_REG(Dropout2D, Dropout2DInferShape);
// ----------------Dropout2D END-------------------
}  // namespace ge

@ -0,0 +1,82 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime_util.h"
#include "utils/op_util.h"
#include "utils/op_log.h"
#include "utils/op_const.h"

using namespace ge;
namespace ops {
// -------------------Diag Ops START---------------------
static constexpr size_t DIAG_IN_X_IDX = 0;
static constexpr size_t DIAG_OUT_Y_IDX = 0;
static constexpr size_t INT_DATA_2 = 2;

ge::graphStatus Infershape4Diag(gert::InferShapeContext *context) {
  OP_LOGD(context->GetNodeName(), "Begin to do DiagInfershape.");
  const gert::Shape *input_x_shape = context->GetInputShape(DIAG_IN_X_IDX);
  OPS_CHECK_NULL_WITH_CONTEXT(context, input_x_shape);
  gert::Shape *output_y_shape = context->GetOutputShape(DIAG_OUT_Y_IDX);
  OPS_CHECK_NULL_WITH_CONTEXT(context, output_y_shape);

  size_t x_dim_num = input_x_shape->GetDimNum();

  output_y_shape->SetDimNum(0);
  for (size_t i = 0; i < INT_DATA_2; i++) {
    for (size_t j = 0; j < x_dim_num; j++) {
      output_y_shape->AppendDim(input_x_shape->GetDim(j));
    }
  }

  OP_LOGD(context->GetNodeName(), "output_y_shape = %s.", ToString(*output_y_shape).c_str());
  OP_LOGD(context->GetNodeName(), "End to do DiagInfershape.");

  return ge::GRAPH_SUCCESS;
}

IMPL_OP(Diag).InferShape(Infershape4Diag);
// -------------------Diag Ops END---------------------

// -------------------DiagPart Ops START---------------------
ge::graphStatus Infershape4DiagPart(gert::InferShapeContext *context) {
  OP_LOGD(context->GetNodeName(), "Begin to do DiagPartInfershape.");
  const gert::Shape *input_x_shape = context->GetInputShape(DIAG_IN_X_IDX);
  OPS_CHECK_NULL_WITH_CONTEXT(context, input_x_shape);

  int64_t input_to_output_dims_times = 2;
  OP_CHECK(input_x_shape->GetDimNum() % 2 != 0,
           VECTOR_INFER_SHAPE_INNER_ERR_REPORT(context->GetNodeName(),
                                               "input_x_shape->GetDimNum() % 2 != 0 is not supported."),
           return ge::GRAPH_FAILED);
  int64_t output_shape_len = input_x_shape->GetDimNum() / input_to_output_dims_times;

  gert::Shape *output_y_shape = context->GetOutputShape(DIAG_OUT_Y_IDX);
  OPS_CHECK_NULL_WITH_CONTEXT(context, output_y_shape);

  output_y_shape->SetDimNum(output_shape_len);
  for (int64_t i = 0; i < output_shape_len; i++) {
    output_y_shape->SetDim(i, input_x_shape->GetDim(i));
  }

  OP_LOGD(context->GetNodeName(), "output_y_shape = %s.", ToString(*output_y_shape).c_str());
  OP_LOGD(context->GetNodeName(), "End to do DiagPartInfershape.");

  return ge::GRAPH_SUCCESS;
}

IMPL_OP(DiagPart).InferShape(Infershape4DiagPart);
// -------------------DiagPart Ops END---------------------
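A standalone sanity check of the two rules above (helper names are made up): Diag appends a second copy of the input shape; DiagPart requires an even rank and keeps the first half.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors Infershape4Diag: y.shape = x.shape concatenated with x.shape.
std::vector<int64_t> DiagShape(const std::vector<int64_t> &x) {
  std::vector<int64_t> y = x;
  y.insert(y.end(), x.begin(), x.end());
  return y;
}

// Mirrors Infershape4DiagPart: rank must be even; keep the first half.
std::vector<int64_t> DiagPartShape(const std::vector<int64_t> &x) {
  assert(x.size() % 2 == 0);
  return std::vector<int64_t>(x.begin(), x.begin() + x.size() / 2);
}

int main() {
  assert(DiagShape({2, 3}) == (std::vector<int64_t>{2, 3, 2, 3}));
  assert(DiagPartShape({2, 3, 2, 3}) == (std::vector<int64_t>{2, 3}));
  return 0;
}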
}  // namespace ops

@ -0,0 +1,165 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime_util.h"
#include "utils/op_util.h"

using namespace ge;
namespace ops {
ge::graphStatus InferShape4Elewise(gert::InferShapeContext *context) {
  auto in_shape = context->GetInputShape(0);
  OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
  auto out_shape = context->GetOutputShape(0);
  OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);

  if (IsUnknownRank(in_shape)) {
    OP_LOGD(context->GetNodeName(), "input shape is UnknownRank, set output shape to (-2, )");
    return SetUnknownRank(out_shape);
  }

  *out_shape = *in_shape;
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus CopyShapeInput2OutputWithIdx(gert::InferShapeContext *context, int64_t input_idx, int64_t output_idx) {
  auto in_shape = context->GetInputShape(input_idx);
  OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
  auto out_shape = context->GetOutputShape(output_idx);
  OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);
  *out_shape = *in_shape;
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus InferShape4InIdxAndOutVector(gert::InferShapeContext *context, int64_t input_idx,
                                             const std::vector<int64_t> &output_idxs) {
  auto in_shape = context->GetInputShape(input_idx);
  OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
  for (int64_t idx : output_idxs) {
    auto out_shape = context->GetOutputShape(idx);
    OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);
    *out_shape = *in_shape;
  }
  return ge::GRAPH_SUCCESS;
}

std::string ShapeCannotBroadcastMsg(const gert::Shape &shape1, const gert::Shape &shape2) {
  std::string res = "shape ";
  res += ToString(shape1);
  res += " and ";
  res += ToString(shape2);
  res += " cannot broadcast!";
  return res;
}

static bool BroadcastDim(int64_t &dim1, const int64_t dim2) {
  if (dim1 == dim2) {
    return true;
  }
  /* column is dim1, row is dim2, matrix value is broadcast(dim1, dim2)
   *   dim   0    1    d2
   *   0     0    0    E
   *   1     0    1    d2
   *   d1    E    d1   E
   */
  if ((dim1 != 1) && (dim2 != 1)) {
    string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastDim", msg);
    return false;
  }
  dim1 = (dim1 == 1) ? dim2 : dim1;

  return true;
}
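Put plainly: equal dims pass through; otherwise one side must be 1 and the result is the other side (0 counts as a real extent here). A standalone mirror of BroadcastDim:

#include <cassert>
#include <cstdint>

// Mirrors BroadcastDim above, minus the error reporting.
bool BroadcastDimSketch(int64_t &dim1, int64_t dim2) {
  if (dim1 == dim2) return true;
  if (dim1 != 1 && dim2 != 1) return false;  // e.g. 3 vs 5, or 0 vs 4
  dim1 = (dim1 == 1) ? dim2 : dim1;
  return true;
}

int main() {
  int64_t d = 1;
  assert(BroadcastDimSketch(d, 7) && d == 7);  // 1 broadcasts up to 7
  d = 3;
  assert(BroadcastDimSketch(d, 3) && d == 3);  // equal dims are fine
  d = 3;
  assert(!BroadcastDimSketch(d, 5));           // 3 vs 5 cannot broadcast
  return 0;
}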

/*
 * @brief: broadcast new shape to output shape
 * @param [in] shape: const gert::Shape*, new shape to broadcast
 * @param [in/out] shape_output: gert::Shape*, output shape
 * @return succeed or not
 */
static bool BroadcastShapeToOutShape(const gert::Shape *shape, gert::Shape *shape_output) {
  OP_LOGD("BroadcastShapeToOutShape", "start broadcast %s to %s!", ToString(*shape).c_str(),
          ToString(*shape_output).c_str());
  size_t shape_len = shape->GetDimNum();
  size_t shape_y_len = shape_output->GetDimNum();
  if (shape_len > shape_y_len) {
    shape_output->SetDimNum(shape_len);
    size_t len_sub = shape_len - shape_y_len;
    for (size_t i = shape_y_len; i > 0; i--) {
      int64_t dim1 = shape->GetDim(len_sub + i - 1);
      int64_t dim2 = shape_output->GetDim(i - 1);
      if (!BroadcastDim(dim1, dim2)) {
        string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShapeToOutShape", msg);
        return false;
      }
      shape_output->SetDim(len_sub + i - 1, dim1);
    }
    for (size_t i = 0; i < len_sub; i++) {
      shape_output->SetDim(i, shape->GetDim(i));
    }
  } else {
    auto len_sub = shape_y_len - shape_len;
    for (size_t i = 0; i < shape_len; i++) {
      int64_t dim1 = shape_output->GetDim(len_sub + i);
      int64_t dim2 = shape->GetDim(i);
      if (!BroadcastDim(dim1, dim2)) {
        string msg = ConcatString(dim1, " and ", dim2, " cannot broadcast!");
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShapeToOutShape", msg);
        return false;
      }
      shape_output->SetDim(len_sub + i, dim1);
    }
  }
  return true;
}

bool BroadcastShape(const gert::Shape *in1_shape, const gert::Shape *in2_shape, gert::Shape *out_shape) {
  *out_shape = *in1_shape;

  OP_CHECK(!BroadcastShapeToOutShape(in2_shape, out_shape),
           VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in2_shape, *in1_shape)),
           return false);
  return true;
}

bool BroadcastShape(const std::vector<const gert::Shape *> &in_shapes, gert::Shape *out_shape) {
  size_t size = in_shapes.size();
  OP_CHECK(size == 0, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", "in_shapes is empty!"), return false);
  *out_shape = *in_shapes[0];

  for (size_t i = 1; i < size; i++) {
    OP_CHECK(!BroadcastShapeToOutShape(in_shapes[i], out_shape),
             VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in_shapes[i], *out_shape)),
             return false);
  }

  return true;
}

bool BroadcastShape(const gert::Shape **in_shapes, size_t size, gert::Shape *out_shape) {
  OP_CHECK(size == 0, VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", "in_shapes is empty!"), return false);
  *out_shape = *in_shapes[0];

  for (size_t i = 1; i < size; i++) {
    OP_CHECK(!BroadcastShapeToOutShape(in_shapes[i], out_shape),
             VECTOR_INFER_SHAPE_INNER_ERR_REPORT("BroadcastShape", ShapeCannotBroadcastMsg(*in_shapes[i], *out_shape)),
             return false);
  }

  return true;
}
}  // namespace ops
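End to end, the overloads above right-align the shapes and broadcast dim by dim, so (8, 1, 6) against (7, 1) yields (8, 7, 6). A self-contained mirror on plain vectors (gert::Shape itself needs the CANN runtime headers):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Right-aligned broadcast of two shapes, same rule as BroadcastShapeToOutShape.
bool BroadcastShapesSketch(std::vector<int64_t> a, std::vector<int64_t> b,
                           std::vector<int64_t> &out) {
  if (a.size() < b.size()) std::swap(a, b);  // make `a` the longer shape
  out = a;
  size_t offset = a.size() - b.size();       // b is right-aligned under a
  for (size_t i = 0; i < b.size(); ++i) {
    int64_t d1 = a[offset + i], d2 = b[i];
    if (d1 != d2 && d1 != 1 && d2 != 1) return false;
    out[offset + i] = (d1 == 1) ? d2 : d1;
  }
  return true;
}

int main() {
  std::vector<int64_t> out;
  assert(BroadcastShapesSketch({8, 1, 6}, {7, 1}, out));
  assert(out == (std::vector<int64_t>{8, 7, 6}));
  assert(!BroadcastShapesSketch({3, 4}, {5}, out));  // 4 vs 5 fails
  return 0;
}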

@ -0,0 +1,114 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file runtime_util.h
 * \brief
 */
#ifndef CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_
#define CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_

#include "utils/context_util.h"
#include "register/op_impl_registry.h"
#include "runtime/continuous_vector.h"
#include "runtime/infer_shape_context.h"
#include "runtime/storage_shape.h"
#include "error_util.h"
#include "op_util.h"

namespace ops {
using QuickVector = gert::Shape;
constexpr int64_t UNKNOWN_DIM_VALUE_ = -1;
constexpr int64_t UNKNOWN_RANK_DIM_VALUE_ = -2;

// Do infershape for ops that are single-input, single-output, with the output shape equal to the input shape.
ge::graphStatus InferShape4Elewise(gert::InferShapeContext *context);

/*
 * @brief: copy the shape of one input to one output
 * @param [in] context: gert::InferShapeContext
 * @param [in] input_idx: input index
 * @param [in] output_idx: output index
 * @return graphStatus: success or failed
 */
ge::graphStatus CopyShapeInput2OutputWithIdx(gert::InferShapeContext *context, int64_t input_idx, int64_t output_idx);

/*
 * @brief: copy the shape of one input to several outputs
 * @param [in] context: gert::InferShapeContext
 * @param [in] input_idx: input index
 * @param [in] output_idxs: output indexes, vector<int64_t>
 * @return graphStatus: success or failed
 */
ge::graphStatus InferShape4InIdxAndOutVector(gert::InferShapeContext *context, int64_t input_idx,
                                             const std::vector<int64_t> &output_idxs);

std::string ShapeCannotBroadcastMsg(const gert::Shape &shape1, const gert::Shape &shape2);
/*
 * @brief: broadcast new shape to output shape
 * @param [in] shape: const gert::Shape*, new shape to broadcast
 * @param [in/out] shape_output: gert::Shape*, output shape
 * @return succeed or not
 */
bool BroadcastShape(const gert::Shape *in1_shape, const gert::Shape *in2_shape, gert::Shape *out_shape);
bool BroadcastShape(const std::vector<const gert::Shape *> &in_shapes, gert::Shape *out_shape);
bool BroadcastShape(const gert::Shape **in_shapes, size_t size, gert::Shape *out_shape);

/*
 * @brief: set all the output shape dims to [-1, -1, ...] with the given rank
 * @param [in] rank: the output rank
 * @param [out] output_shape: the output shape ptr
 * @return ge::graphStatus
 */
inline ge::graphStatus SetAllUnknownDim(const int64_t rank, gert::Shape *output_shape) {
  OP_CHECK(output_shape == nullptr, OP_LOGD("SetAllUnknownDim", "the output_shape is nullptr, return failed"),
           return ge::GRAPH_FAILED);

  output_shape->SetDimNum(rank);
  for (int64_t i = 0; i < rank; ++i) {
    output_shape->SetDim(i, UNKNOWN_DIM_VALUE_);
  }
  OP_LOGD("SetAllUnknownDim", "set all dim = -1, output = %s", ToString(*output_shape).c_str());

  return ge::GRAPH_SUCCESS;
}

/*
 * @brief: set output shape to [-2], i.e. unknown rank
 * @param [out] output_shape: the output shape ptr
 * @return ge::graphStatus
 */
inline ge::graphStatus SetUnknownRank(gert::Shape *output_shape) {
  OP_CHECK(output_shape == nullptr, OP_LOGD("SetUnknownRank", "the output_shape is nullptr, return failed"),
           return ge::GRAPH_FAILED);
  output_shape->SetDimNum(0);
  output_shape->AppendDim(UNKNOWN_RANK_DIM_VALUE_);

  OP_LOGD("SetUnknownRank", "set unknown rank = -2, output = %s", ToString(*output_shape).c_str());
  return ge::GRAPH_SUCCESS;
}

/*
 * @brief: check whether the given shape is unknown rank, i.e. [-2]
 * @param [in] check_shape: the shape ptr to check
 * @return bool
 */
inline bool IsUnknownRank(const gert::Shape *check_shape) {
  return check_shape->GetDimNum() == 1 && check_shape->GetDim(0) == UNKNOWN_RANK_DIM_VALUE_;
}
}  // namespace ops

#endif  // CUSTOMIZE_OP_PROTO_RUNTIME_RUNTIME_UTIL_H_
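Typical use of these helpers inside a gert InferShape callback looks like the sketch below (MyOpInferShape is a hypothetical op, shown only to illustrate the -1/-2 conventions; it compiles only against the CANN gert headers listed above):

// Sketch only, assuming the CANN gert runtime environment.
#include "runtime_util.h"

namespace ops {
ge::graphStatus MyOpInferShape(gert::InferShapeContext *context) {
  auto in_shape = context->GetInputShape(0);
  OPS_CHECK_NULL_WITH_CONTEXT(context, in_shape);
  auto out_shape = context->GetOutputShape(0);
  OPS_CHECK_NULL_WITH_CONTEXT(context, out_shape);

  if (IsUnknownRank(in_shape)) {
    return SetUnknownRank(out_shape);  // propagate [-2]
  }
  // Rank is known but dims may not be: emit [-1, -1, ...] of matching rank.
  return SetAllUnknownDim(in_shape->GetDimNum(), out_shape);
}
}  // namespace ops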

@ -0,0 +1,165 @@
/*
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inc/ops/selection_ops.h"
#include "custom_op_proto/cust_array_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ----------------CumulativeLogsumexp-------------------
IMPLEMT_COMMON_INFERFUNC(CumulativeLogsumexpInferShape) {
  TensorDesc output_desc = op.GetOutputDescByName("y");
  output_desc.SetShape(op.GetInputDescByName("x").GetShape());
  output_desc.SetDataType(op.GetInputDescByName("x").GetDataType());
  op.UpdateOutputDesc("y", output_desc);

  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(CumulativeLogsumexp, CumulativeLogsumexpInferShape);
// ----------------CumulativeLogsumexp END-------------------

// ----------------GatherNd-------------------
bool CheckGatherNdInputIndicesSize(const Operator &op, const string &input_name) {
  auto indices_shape = OpDescUtils::GetOpDescFromOperator(op)->MutableInputDesc("indices")->GetShape();
  auto indices_shape_size = indices_shape.GetDimNum();
  int indices_last_element = indices_shape.GetDim(indices_shape_size - 1);
  int64_t indices_part{1};
  for (int i = 0; i < indices_last_element - 1; ++i) {
    indices_part *= static_cast<int64_t>(indices_shape.GetDim(i));
  }
  if (indices_part > std::numeric_limits<int>::max()) {
    OP_LOGE(TbeGetName(op).c_str(), "Indices has too many elements for int indexing");
    return false;
  }
  return true;
}

bool CheckGatherNdParamsSize(const Operator &op, int last_dim, int shape_size) {
  if (last_dim > shape_size) {
    OP_LOGE(TbeGetName(op).c_str(), "The last dim(%d) of indices must be <= params.rank(%d).", last_dim, shape_size);
    return false;
  }
  return true;
}

IMPLEMT_VERIFIER(GatherNd, GatherNdVerify) {
  if (!CheckGatherNdInputIndicesSize(op, "indices")) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(GatherNdInferShape) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeTensorDescPtr output_tensor_desc = op_desc->MutableOutputDesc("y");
  std::vector<std::pair<int64_t, int64_t>> shape_range_x;
  op_desc->MutableInputDesc("x")->GetShapeRange(shape_range_x);
  std::vector<std::pair<int64_t, int64_t>> shape_range_indices;
  op_desc->MutableInputDesc("indices")->GetShapeRange(shape_range_indices);
  std::vector<std::pair<int64_t, int64_t>> out_range;
  auto input_params = op_desc->MutableInputDesc("x");
  auto input_indices = op_desc->MutableInputDesc("indices");
  auto params_shape = input_params->GetShape();
  auto indices_shape = input_indices->GetShape();
  auto params_shape_size = params_shape.GetDimNum();
  int indices_shape_size = indices_shape.GetDimNum();
  vector<int64_t> dim_vec;
  vector<int64_t> params_shape_vec = params_shape.GetDims();
  vector<int64_t> indices_shape_vec = indices_shape.GetDims();
  MakeUpShapeRange(params_shape_vec, shape_range_x);
  MakeUpShapeRange(indices_shape_vec, shape_range_indices);
  int indices_last_element{-2};
  if (!IsUnknownRankShape(indices_shape_vec)) {
    indices_last_element = indices_shape.GetDim(indices_shape_size - 1);
  }
  DataType params_type = input_params->GetDataType();
  if (indices_last_element == -1 || indices_last_element == -2 || IsUnknownRankShape(params_shape_vec)) {
    dim_vec.push_back(-2);
  } else if (!CheckGatherNdParamsSize(op, indices_last_element, (int)params_shape_size)) {
    return GRAPH_FAILED;
  } else {
    for (int i = 0; i < indices_shape_size - 1; ++i) {
      dim_vec.push_back(indices_shape.GetDim(i));
      if ((size_t)i < shape_range_indices.size()) {
        out_range.push_back(shape_range_indices[i]);
      }
    }
    for (size_t i = indices_last_element; i < params_shape_size; ++i) {
      dim_vec.push_back(params_shape.GetDim(i));
      if (i < shape_range_x.size()) {
        out_range.push_back(shape_range_x[i]);
      }
    }
  }
  ge::GeShape output_shape = ge::GeShape(dim_vec);
  DataType output_dtype = params_type;
  output_tensor_desc->SetShape(output_shape);
  output_tensor_desc->SetDataType(output_dtype);
  TensorUtils::SetRealDimCnt(*output_tensor_desc, dim_vec.size());
  if (!IsUnknownRankShape(dim_vec)) {
    output_tensor_desc->SetShapeRange(out_range);
  }
  return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(GatherNd, GatherNdInferShape);
VERIFY_FUNC_REG(GatherNd, GatherNdVerify);
// ----------------GatherNd End-------------------
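The static branch above implements the usual GatherNd rule, y.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]. A standalone check (helper name made up):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors GatherNdInferShape's known-shape branch.
std::vector<int64_t> GatherNdShape(const std::vector<int64_t> &params,
                                   const std::vector<int64_t> &indices) {
  std::vector<int64_t> out(indices.begin(), indices.end() - 1);  // indices.shape[:-1]
  int64_t last = indices.back();                                 // index depth
  out.insert(out.end(), params.begin() + last, params.end());    // params.shape[last:]
  return out;
}

int main() {
  // indices [5, 2] gathers 2-deep into params [4, 6, 7] -> y is [5, 7].
  assert(GatherNdShape({4, 6, 7}, {5, 2}) == (std::vector<int64_t>{5, 7}));
  return 0;
}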

// ----------------MaskedSelect Begin-------------------
bool InferShapeAndTypeMaskedSelect(Operator &op) {
  OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeTensorDescPtr x_input = op_desc->MutableInputDesc(0);
  GeShape x_shape = x_input->GetShape();
  GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0);
  DataType input_dtype = x_input->GetDataType();
  y_desc->SetDataType(input_dtype);
  std::vector<std::pair<int64_t, int64_t>> range;
  y_desc->SetShape(GeShape({UNKNOWN_DIM}));
  y_desc->SetOriginShape(GeShape({UNKNOWN_DIM}));
  range.emplace_back(std::make_pair(1, x_shape.GetShapeSize()));
  y_desc->SetShapeRange(range);
  return true;
}

// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(MaskedSelectInferShape) {
  if (InferShapeAndTypeMaskedSelect(op)) {
    return GRAPH_SUCCESS;
  }
  OP_LOGE(TbeGetName(op).c_str(), "The shape of output y does not match that of x1 x2.");
  return GRAPH_FAILED;
}

COMMON_INFER_FUNC_REG(MaskedSelect, MaskedSelectInferShape);
// ----------------MaskedSelect END---------------------
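MaskedSelect's output length is data dependent, so compile time can only register a 1-D shape of [-1] with range [1, numel(x)], as above. A sketch of the runtime behavior that unknown dim stands for:

#include <cassert>
#include <vector>

// Keep x[i] wherever mask[i] is true; length is known only at runtime.
std::vector<float> MaskedSelectSketch(const std::vector<float> &x, const std::vector<bool> &mask) {
  std::vector<float> out;
  for (size_t i = 0; i < x.size(); ++i) {
    if (mask[i]) out.push_back(x[i]);
  }
  return out;
}

int main() {
  assert(MaskedSelectSketch({1.f, 2.f, 3.f, 4.f}, {true, false, true, false}).size() == 2);
  return 0;
}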

// ----------------IndexFill-------------------
// Obtains the processing function of the output tensor description.
IMPLEMT_COMMON_INFERFUNC(IndexFillInferShape) {
  TensorDesc v_output_desc = op.GetOutputDescByName("y");

  DataType input_dtype = op.GetInputDescByName("x").GetDataType();
  Format input_format = op.GetInputDescByName("x").GetFormat();
  // shape of output y is the same as input x
  ge::Shape shape_input = op.GetInputDescByName("x").GetShape();

  v_output_desc.SetShape(shape_input);
  v_output_desc.SetDataType(input_dtype);
  v_output_desc.SetFormat(input_format);

  if (op.UpdateOutputDesc("y", v_output_desc) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

// Registered infer function
CUST_COMMON_INFER_FUNC_REG(IndexFill, IndexFillInferShape);
// ----------------IndexFill END-------------------
}  // namespace ge

@ -93,7 +93,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyAdagradDA, SparseApplyAdagradDAVerify) {
  }

  Shape indices_shape;
  if (WithRank(op.GetInputDesc(4), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(op.GetInputDesc(4), 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

@ -90,7 +90,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyCenteredRMSProp, SparseApplyCenteredRMSPropVeri
  }

  Shape indices_shape;
  if (WithRank(op.GetInputDesc(9), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(op.GetInputDesc(9), 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

@ -90,7 +90,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyMomentum, SparseApplyMomentumVerify) {
  }

  Shape indices_shape;
  if (WithRank(op.GetInputDesc(4), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(op.GetInputDesc(4), 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

@ -74,7 +74,7 @@ CUST_IMPLEMT_VERIFIER(SparseApplyProximalGradientDescent, SparseApplyProximalGra
  }

  Shape indices_shape;
  if (WithRank(op.GetInputDesc(5), 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(op.GetInputDesc(5), 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

@ -31,32 +31,32 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentMeanWithNumSegments, SparseSegmentMeanWithNu

  GeShape indices_shape;
  auto indices_desc = op_desc->MutableInputDesc(1);
  if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape segment_ids_shape;
  auto segment_ids_desc = op_desc->MutableInputDesc(2);
  if (WithRank(segment_ids_desc, 1, segment_ids_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(segment_ids_desc, 1, segment_ids_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape num_segments_shape;
  auto num_segments_desc = op_desc->MutableInputDesc(3);
  if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input nums_segments should be at most 1-D.");
    return GRAPH_FAILED;
  }

  GeShape unused;
  if (Merge(indices_shape, segment_ids_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (Merge(indices_shape, segment_ids_shape, unused, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  GeShape subshape;
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

@ -67,7 +67,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentMeanWithNumSegments, SparseSegmentMeanWithNu
  }

  if (nums != UNKNOWN_DIM) {
    if (MakeDimForScalarInput(tensor, nums, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
    if (MakeDimForScalarInput(tensor, nums, op) != GRAPH_SUCCESS) {
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[num_segments]."));
      return GRAPH_FAILED;
    }

@ -30,7 +30,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {

  auto x_desc = op_desc->MutableInputDesc(0);
  GeShape x_ge_shape;
  if (WithRankAtLeast(x_desc, 1, x_ge_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtLeast(x_desc, 1, x_ge_shape, op) != GRAPH_SUCCESS) {
    err_msg = GetShapeErrMsg(0, DebugString(x_desc->GetShape().GetDims()), "at least 1D");
    err_msg = string("failed to call WithRankAtLeast function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

@ -39,7 +39,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {

  auto indices_desc = op_desc->MutableInputDesc(1);
  GeShape indices_shape;
  if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
    err_msg = GetShapeErrMsg(1, DebugString(indices_desc->GetShape().GetDims()), "1D");
    err_msg = string("failed to call WithRank function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

@ -48,7 +48,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {

  GeShape unused;
  GeShape segment_ids_shape(op_desc->MutableInputDesc(2)->GetShape());
  if (Merge(segment_ids_shape, indices_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (Merge(segment_ids_shape, indices_shape, unused, op) != GRAPH_SUCCESS) {
    err_msg = ConcatString("failed to call Merge function to merge input[segment_ids]'s shape",
                           DebugString(op_desc->MutableInputDesc(2)->GetShape().GetDims()),
                           " and input[indices]'s shape", DebugString(indices_shape.GetDims()));

@ -56,7 +56,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
    return GRAPH_FAILED;
  }
  auto unused_desc = op_desc->MutableInputDesc(3);
  if (WithRank(unused_desc, 0, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(unused_desc, 0, unused, op) != GRAPH_SUCCESS) {
    err_msg = GetShapeErrMsg(3, DebugString(unused_desc->GetShape().GetDims()), "scalar");
    err_msg = string("failed to call WithRank function, ") + err_msg;
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

@ -66,7 +66,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNGrad, SparseSegmentSqrtNGradInfer) {
  auto x_shape_dims = x_ge_shape.GetDims();
  Shape x_shape(x_shape_dims);
  Shape subshape;
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
    err_msg = ConcatString("failed to call SubShape function to get subshape from ", x_shape.GetDimNum(),
                           " to 1 in input[x] shape", DebugString(x_shape.GetDims()));
    AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);

@ -24,32 +24,32 @@ graphStatus SparseSegmentReductionShapeFn(Operator &op) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeShape x_shapes;
  auto x_desc = op_desc->MutableInputDesc(0);
  if (WithRankAtLeast(x_desc, 1, x_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtLeast(x_desc, 1, x_shapes, op) != GRAPH_SUCCESS) {
    AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input x should be at least 1-D.");
    return GRAPH_FAILED;
  }

  GeShape indices_shapes;
  auto indices_desc = op_desc->MutableInputDesc(1);
  if (WithRank(indices_desc, 1, indices_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(indices_desc, 1, indices_shapes, op) != GRAPH_SUCCESS) {
    AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape segment_ids_shapes;
  auto segment_ids_desc = op_desc->MutableInputDesc(2);
  if (WithRank(segment_ids_desc, 1, segment_ids_shapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(segment_ids_desc, 1, segment_ids_shapes, op) != GRAPH_SUCCESS) {
    AICPU_OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape unuse;
  if (Merge(indices_shapes, segment_ids_shapes, unuse, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (Merge(indices_shapes, segment_ids_shapes, unuse, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  GeShape subshapes;
  if (SubShape(x_shapes, 1, x_shapes.GetDimNum(), 1, subshapes, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (SubShape(x_shapes, 1, x_shapes.GetDimNum(), 1, subshapes, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

@ -24,39 +24,39 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNWithNumSegments, SparseSegmentSqrtNWith
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeShape x_shape;
  auto x_desc = op_desc->MutableInputDesc(0);
  if (WithRankAtLeast(x_desc, 1, x_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtLeast(x_desc, 1, x_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input x should be at least 1-D.");
    return GRAPH_FAILED;
  }

  GeShape indices_shape;
  auto indices_desc = op_desc->MutableInputDesc(1);
  if (WithRank(indices_desc, 1, indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(indices_desc, 1, indices_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input indices must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape segment_ids_shape;
  auto segment_ids_desc = op_desc->MutableInputDesc(2);
  if (WithRank(segment_ids_desc, 1, segment_ids_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRank(segment_ids_desc, 1, segment_ids_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input segment_ids must be 1-D.");
    return GRAPH_FAILED;
  }

  GeShape num_segments_shape;
  auto num_segments_desc = op_desc->MutableInputDesc(3);
  if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (WithRankAtMost(num_segments_desc, 1, num_segments_shape, op) != GRAPH_SUCCESS) {
    OP_LOGE(TbeGetName(op).c_str(), "Input nums_segments should be at most 1-D.");
    return GRAPH_FAILED;
  }

  GeShape unused;
  if (Merge(indices_shape, segment_ids_shape, unused, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (Merge(indices_shape, segment_ids_shape, unused, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  GeShape subshape;
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
  if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

@ -67,7 +67,7 @@ CUST_IMPLEMT_INFERFUNC(SparseSegmentSqrtNWithNumSegments, SparseSegmentSqrtNWith
  }

  if (nums != UNKNOWN_DIM) {
    if (MakeDimForScalarInput(tensor, nums, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
    if (MakeDimForScalarInput(tensor, nums, op) != GRAPH_SUCCESS) {
      AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("fail to get dim from tensor of input[num_segments]."));
      return GRAPH_FAILED;
    }

@@ -25,21 +25,21 @@ CUST_IMPLEMT_INFERFUNC(SparseTensorToCSRSparseMatrix, SparseTensorToCSRSparseMat

GeShape x_indices_shape;
auto x_indices_desc = op_desc->MutableInputDesc(0);
-if (WithRank(x_indices_desc, 2, x_indices_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x_indices_desc, 2, x_indices_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_indices_desc must be 2-D.");
return GRAPH_FAILED;
}

GeShape x_values_shape;
auto x_values_desc = op_desc->MutableInputDesc(1);
-if (WithRank(x_values_desc, 1, x_values_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x_values_desc, 1, x_values_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_values must be 1-D.");
return GRAPH_FAILED;
}

GeShape x_dense_shape_shape;
auto x_dense_shape_desc = op_desc->MutableInputDesc(2);
-if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x_dense_shape_desc, 1, x_dense_shape_shape, op) != GRAPH_SUCCESS) {
OP_LOGE(TbeGetName(op).c_str(), "Input x_dense_shape must be 1-D.");
return GRAPH_FAILED;
}
@@ -30,32 +30,32 @@ CUST_IMPLEMT_INFERFUNC(SparseAddmm, SparseAddmmInfer) {
auto x1_shape_tensor = op.get_input_desc_x1_shape();
auto x2_tensor = op.get_input_desc_x2();
std::string err_msg;
-if (WithRank(x1_indices_tensor, 2, unused_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x1_indices_tensor, 2, unused_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(x1_indices_tensor.GetShape().GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRank(x1_values_tensor, 1, unused_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x1_values_tensor, 1, unused_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(x1_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (MakeShapeFromShapeTensor(op, "x1_shape", x1_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (MakeShapeFromShapeTensor(op, "x1_shape", x1_shape) != GRAPH_SUCCESS) {
err_msg = ConcatString(
"failed to call MakeShapeFromShapeTensor function to make shape from "
"input[x1_shape]");
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRankShape(x1_shape, 2, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRankShape(x1_shape, 2, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(2, DebugString(x1_shape.GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRank(x2_tensor, 2, x2_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(x2_tensor, 2, x2_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(3, DebugString(x2_tensor.GetShape().GetDims()), "2D");
err_msg = string("failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@@ -111,19 +111,19 @@ CUST_IMPLEMT_INFERFUNC(Sspaddmm, SspaddmmInfer) {
}

// check dimension
-if (WithRank(mat1_values_tensor, 1, unused_shape1, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(mat1_values_tensor, 1, unused_shape1, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(mat1_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("MAT1 Values failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRank(input_indices_tensor, 2, unused_shape2, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(input_indices_tensor, 2, unused_shape2, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(0, DebugString(input_indices_tensor.GetShape().GetDims()), "2D");
err_msg = string("input indices failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRank(input_values_tensor, 1, unused_shape1, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(input_values_tensor, 1, unused_shape1, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(1, DebugString(input_values_tensor.GetShape().GetDims()), "1D");
err_msg = string("input values failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@@ -134,7 +134,7 @@ CUST_IMPLEMT_INFERFUNC(Sspaddmm, SspaddmmInfer) {
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
-if (WithRank(mat2_tensor, 2, mat2_shape, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRank(mat2_tensor, 2, mat2_shape, op) != GRAPH_SUCCESS) {
err_msg = GetShapeErrMsg(3, DebugString(mat2_tensor.GetShape().GetDims()), "2D");
err_msg = string("mat2 tensor failed to call WithRank function, ") + err_msg;
AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), err_msg);
@@ -0,0 +1,188 @@
/*
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "inc/ops/transformation_ops.h"
#include "register/op_impl_registry.h"
#include "utils/util.h"
#include "utils/common_shape_fns.h"

namespace ge {
// ------------------DepthToSpace------------------
static bool VerifyDepthToSpaceInputShape(const Operator &op, const int64_t &block_size,
const std::vector<int64_t> &input_dims, const std::string &data_format) {
bool check_format = (data_format == "NCHW" || data_format == "NHWC");
if (check_format && !IsUnknown(input_dims)) {
int64_t c_dim = 3;
c_dim = data_format == "NHWC" ? 3 : 1;
auto mod_res = input_dims[c_dim] % (block_size * block_size);
if (mod_res != 0) {
OP_LOGE(TbeGetName(op),
"Depth size must be divisible by block_size * block_size,"
"but got depth[%ld], block_size[%ld], data_format[%s]",
input_dims[c_dim], block_size, data_format.c_str());
return false;
}
}
return true;
}

IMPLEMT_VERIFIER(DepthToSpace, DepthToSpaceVerify) {
// verify input shape size
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_dims = input_desc->MutableShape().GetDims();
if (!IsUnknownRankShape(input_dims) && (input_dims.size() < 4)) {
std::string err_msg = GetAttrValueErrMsg("input_dims", std::to_string(input_dims.size()), ConcatString(">=4"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify block size
int64_t block_size;
if (op.GetAttr("block_size", block_size) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("block_size");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (block_size < 2) {
std::string err_msg = GetAttrValueErrMsg("block_size", std::to_string(block_size), ConcatString("=<2"));
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify mode
std::string mode;
if (op.GetAttr("mode", mode) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("mode");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (mode != "DCR" && mode != "CRD") {
string expected_format_list = ConcatString("DCR, CRD");
std::string err_msg = GetAttrValueErrMsg("mode", mode, expected_format_list);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify data_format
std::string data_format;
if (op.GetAttr("data_format", data_format) != GRAPH_SUCCESS) {
std::string err_msg = GetInputInvalidErrMsg("data_format");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
if (data_format != "NHWC" && data_format != "NCHW" && data_format != "NC1HWC0") {
string expected_format_list = ConcatString("NHWC, NCHW, NC1HWC0");
std::string err_msg = GetInputFormatNotSupportErrMsg("data_format", expected_format_list, data_format);
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}
// verify input shape
bool check_input_shape = true;
check_input_shape = VerifyDepthToSpaceInputShape(op, block_size, input_dims, data_format);
if (!check_input_shape) {
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

IMPLEMT_COMMON_INFERFUNC(DepthToSpaceInfer) {
auto op_info = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc = op_info->MutableInputDesc("x");
auto input_dims = input_desc->MutableShape().GetDims();
auto input_dtype = input_desc->GetDataType();
auto input_format = static_cast<ge::Format>(ge::GetPrimaryFormat(input_desc->GetFormat()));

auto output_desc = op_info->MutableOutputDesc("y");
output_desc->SetDataType(input_dtype);

// get attr block_size
int64_t block_size;
if (GRAPH_SUCCESS != op.GetAttr("block_size", block_size)) {
std::string err_msg = GetInputInvalidErrMsg("block_size");
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), err_msg);
return GRAPH_FAILED;
}

// not dynamic case, only set shape
if (!IsUnknown(input_dims)) {
std::vector<int64_t> output_dims;
output_dims.push_back(input_dims[0]);
if (input_format == FORMAT_NCHW) {
output_dims.push_back(input_dims[1] / block_size / block_size);
output_dims.push_back(input_dims[2] * block_size);
output_dims.push_back(input_dims[3] * block_size);
} else { // without NCHW all other format set as NHWC
output_dims.push_back(input_dims[1] * block_size);
output_dims.push_back(input_dims[2] * block_size);
output_dims.push_back(input_dims[3] / block_size / block_size);
}
output_desc->SetShape(GeShape(output_dims));
return GRAPH_SUCCESS;
}

// dynamic case, input shape is -2, output is -2
if (IsUnknownRankShape(input_dims)) {
output_desc->SetShape(GeShape(input_dims));
OP_LOGW(TbeGetName(op).c_str(), "input shape is UnknownRank, set output is UnknownRank.");
return GRAPH_SUCCESS;
}

// dynamic case, input shape is -1, output is -1
std::vector<std::pair<int64_t, int64_t>> input_range;
input_desc->GetShapeRange(input_range);
MakeUpShapeRange(input_dims, input_range);

// infer output shape and range
std::vector<int64_t> output_dims;
std::vector<std::pair<int64_t, int64_t>> output_range;
output_dims.push_back(input_dims[0]);
output_range.push_back(input_range[0]);
int64_t dim;
int64_t range_min;
int64_t range_max;
if (input_format == FORMAT_NCHW) {
dim = input_dims[1] == -1 ? -1 : input_dims[1] / block_size / block_size;
range_min = input_range[1].first / block_size / block_size;
range_min = std::max(int64_t(range_min), int64_t(1));
range_max = input_range[1].second == -1 ? -1 : input_range[1].second / block_size / block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[2] == -1 ? -1 : input_dims[2] * block_size;
range_min = input_range[2].first * block_size;
range_max = input_range[2].second == -1 ? -1 : input_range[2].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[3] == -1 ? -1 : input_dims[3] * block_size;
range_min = input_range[3].first * block_size;
range_max = input_range[3].second == -1 ? -1 : input_range[3].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
} else {
dim = input_dims[1] == -1 ? -1 : input_dims[1] * block_size;
range_min = input_range[1].first * block_size;
range_max = input_range[1].second == -1 ? -1 : input_range[1].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[2] == -1 ? -1 : input_dims[2] * block_size;
range_min = input_range[2].first * block_size;
range_max = input_range[2].second == -1 ? -1 : input_range[2].second * block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
dim = input_dims[3] == -1 ? -1 : input_dims[3] / block_size / block_size;
range_min = input_range[3].first / block_size / block_size;
range_min = std::max(int64_t(range_min), int64_t(1));
range_max = input_range[3].second == -1 ? -1 : input_range[3].second / block_size / block_size;
output_dims.push_back(dim);
output_range.push_back(std::pair<int64_t, int64_t>(range_min, range_max));
}

output_desc->SetShape(GeShape(output_dims));
output_desc->SetShapeRange(output_range);
return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(DepthToSpace, DepthToSpaceInfer);
VERIFY_FUNC_REG(DepthToSpace, DepthToSpaceVerify);
// -------------------DepthToSpace END-----------------
} // namespace ge
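The static-shape branch of DepthToSpaceInfer is plain block arithmetic: H and W are multiplied by block_size and the channel dim is divided by block_size squared (mirrored for NCHW). A self-contained restatement of the NHWC case; the function name is illustrative only:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // NHWC DepthToSpace output dims: C must divide evenly by block*block.
    std::vector<int64_t> DepthToSpaceDimsNHWC(const std::vector<int64_t> &in, int64_t block) {
      assert(in.size() == 4 && block >= 2 && in[3] % (block * block) == 0);
      return {in[0], in[1] * block, in[2] * block, in[3] / (block * block)};
    }
    // DepthToSpaceDimsNHWC({1, 2, 3, 12}, 2) yields {1, 4, 6, 3}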
@@ -28,7 +28,7 @@ CUST_IMPLEMT_INFERFUNC(TridiagonalMatMul, TridiagonalMatMulInfer) {
Shape rhs;
TensorDesc rhs_desc = op.GetInputDesc(3);
auto rhs_shape = rhs_desc.GetShape().GetDims();
-if (WithRankAtLeast(rhs_desc, 2, rhs, TbeGetName(op).c_str()) != GRAPH_SUCCESS) {
+if (WithRankAtLeast(rhs_desc, 2, rhs, op) != GRAPH_SUCCESS) {
error_msg =
ConcatString("failed to call WithRankatleast function, ", "the rank of input[rhs] must at least be 2, but get ",
rhs_desc.GetShape().GetDimNum(), ".");
@@ -18,8 +18,8 @@
* \file axis_util.h
* \brief get the axis value
*/
-#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
-#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
+#ifndef CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_
+#define CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_

#include <memory>
#include <functional>
@@ -142,4 +142,4 @@ class AxisUtil {
};
} // namespace ge

-#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
+#endif // CUSTOMIZE_OP_PROTO_UTIL_AXIS_UTIL_H_
@@ -1,5 +1,5 @@
-/**
-* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
+/*
+* Copyright (c) Huawei Technologies Co., Ltd 2019-2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,14 +21,10 @@
#include "common_shape_fns.h"
#include <vector>
#include <limits>
#include <algorithm>
#include <utility>
#include <map>
#include "op_log.h"
#include "error_util.h"
#include "graph/utils/op_desc_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "util.h"

namespace ge {
const std::map<std::string, DataType> dtype_maps{{"DT_FLOAT", DT_FLOAT},
@@ -58,6 +54,42 @@ const std::map<std::string, DataType> dtype_maps{{"DT_FLOAT", DT_FLOAT},
{"DT_DUAL", DT_DUAL},
{"DT_UNDEFINED", DT_UNDEFINED}};

+graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
+if (rank > INT32_MAX) {
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
+return GRAPH_FAILED;
+}
+Shape s = tensor.GetShape();
+std::vector<int64_t> dims = s.GetDims();
+// dim.size() convert to be type int64_t can't overflow
+int64_t size = static_cast<int64_t>(dims.size());
+if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) {
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]"));
+return GRAPH_FAILED;
+}
+out = s;
+return GRAPH_SUCCESS;
+}
+
+graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
+const ge::Operator &op) {
+if (rank > INT32_MAX) {
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
+return GRAPH_FAILED;
+}
+
+GeShape s = tensorDesc->GetShape();
+std::vector<int64_t> dims = s.GetDims();
+// dim.size() convert to be type int64_t can't overflow
+int64_t size = static_cast<int64_t>(dims.size());
+if ((dims != UNKNOWN_RANK) && (size < rank)) {
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", size, "] must be at least [", rank, "]"));
+return GRAPH_FAILED;
+}
+out_shape = s;
+return GRAPH_SUCCESS;
+}
+
graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
if (rank > INT32_MAX) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
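Note the two new overloads encode "rank unknown" differently: the Shape version accepts dims == UNKNOWN_SHAPE, while the GeShape version checks dims != UNKNOWN_RANK; both let a shape whose rank is not yet fixed pass through. A caller-side sketch (the descriptor index is hypothetical):

    GeShape rhs_shape;
    // Succeeds when input 0 has rank >= 2, or when its rank is still unknown.
    if (WithRankAtLeast(op_desc->MutableInputDesc(0), 2, rhs_shape, op) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }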
@@ -95,9 +127,9 @@ graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeS
return GRAPH_SUCCESS;
}

-graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
+graphStatus WithRankShape(GeShape &shape, int64_t rank, const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}

@@ -109,7 +141,7 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
return GRAPH_SUCCESS;
}
if (existing != rank) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}

@@ -118,9 +150,9 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name) {
return GRAPH_SUCCESS;
}

-graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
+graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
Shape s = tensor.GetShape();
@@ -133,16 +165,16 @@ graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const c
}

if (existing != rank) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}

-graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name) {
+graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}

@@ -155,16 +187,16 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &o
}

if (existing != rank) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out_shape = s;
return GRAPH_SUCCESS;
}

-graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const char *op_name) {
+graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}

@@ -177,27 +209,27 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out
}

if (existing != rank) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", existing, "] must be [", rank, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", existing, "] must be [", rank, "]"));
return GRAPH_FAILED;
}
out_shape = Shape(s.GetDims());
return GRAPH_SUCCESS;
}

-graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_name) {
+graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op) {
out = value;
if (dim == UNKNOWN_DIM) {
return GRAPH_SUCCESS;
}

if (dim != value) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("dim[", dim, "] should be ", value));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[", dim, "] should be ", value));
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

-graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const char *op_name) {
+graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op) {
// Same shape and unknown rank
if (!RankKnown(s) || !RankKnown(prefix)) {
s_out = s;
@@ -208,7 +240,7 @@ graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape
std::vector<int64_t> dims1 = s.GetDims();
if ((dims1 != UNKNOWN_RANK) && (dims1.size() < rank)) {
std::string err_msg = ConcatString("first shape rank[", dims1.size(), "] must be at least rank[", rank, "]");
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}

@@ -220,7 +252,7 @@ graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape
if (Merge(s.GetDim(i), prefix.GetDim(i), dims[i]) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString(i, "th dim of first shape", DebugString(s.GetDims()),
" is not same as that of prefix shape", DebugString(prefix.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@@ -246,7 +278,7 @@ graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out) {
return GRAPH_FAILED;
}

-graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_name) {
+graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op) {
// Same shape and unknown rank
if (s0.GetDims() == s1.GetDims()) {
out = s0;
@@ -263,7 +295,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
if (s1.GetDimNum() != rank) {
std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape",
DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}

@@ -282,7 +314,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
} else if (d0 != d1) {
std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@@ -299,7 +331,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) {
std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@@ -308,7 +340,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
return GRAPH_SUCCESS;
}

-graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char *op_name) {
+graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const ge::Operator &op) {
// Same shape and unknown rank
if (s0.GetDims() == s1.GetDims()) {
out = s0;
@@ -325,7 +357,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
if (s1.GetDimNum() != rank) {
std::string err_msg = ConcatString("different rank of first shape", DebugString(s0.GetDims()), " and second shape",
DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}

@@ -344,7 +376,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
} else if (d0 != d1) {
std::string err_msg = ConcatString("different ", i, "th dim of first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
@@ -361,7 +393,7 @@ graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char
if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) {
std::string err_msg = ConcatString("merge ", i, "th dim failed, first shape", DebugString(s0.GetDims()),
" and second shape", DebugString(s1.GetDims()));
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), err_msg);
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
}
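All of the Merge overloads reduce to one per-dimension rule, restated here as standalone C++ (kDimUnknown stands in for ge::UNKNOWN_DIM):

    #include <cstdint>

    const int64_t kDimUnknown = -1; // stand-in for ge::UNKNOWN_DIM

    // An unknown dim adopts the other side; two known dims must agree.
    bool MergeDim(int64_t d0, int64_t d1, int64_t &out) {
      if (d0 == kDimUnknown) { out = d1; return true; }
      if (d1 == kDimUnknown) { out = d0; return true; }
      if (d0 != d1) return false;
      out = d0;
      return true;
    }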
@@ -403,7 +435,7 @@ void MergeRange(const std::vector<std::pair<int64_t, int64_t>> &shared_shape_ran
}

graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range,
-ShapeAndRange &out, bool &shape_changed, const char *op_name) {
+ShapeAndRange &out, bool &shape_changed, const ge::Operator &op) {
if (!RankKnown(shared_shape_and_range.shape_)) {
out = {Shape(UNKNOWN_RANK), {}, value_shape_and_range.shape_type_};
return GRAPH_SUCCESS;
@@ -445,7 +477,7 @@ graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, cons
return GRAPH_SUCCESS;
}

-graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const char *op_name) {
+graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op) {
if (!RankKnown(s)) {
out = Shape(ge::UNKNOWN_SHAPE);
return GRAPH_SUCCESS;
@@ -456,8 +488,8 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
}
if (!FastBoundsCheck(dim_index, s.GetDimNum())) {
out = Shape();
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("out of range: replace dim[", dim_index_in,
-"] for shape with rank[", s.GetDimNum(), "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims = s.GetDims();
@@ -466,7 +498,7 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
return GRAPH_SUCCESS;
}

-graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const char *op_name) {
+graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const ge::Operator &op) {
if (!RankKnown(s)) {
out = GeShape(UNKNOWN_RANK);
return GRAPH_SUCCESS;
@@ -477,8 +509,8 @@ graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim,
}
if (!FastBoundsCheck(dim_index, s.GetDimNum())) {
out = GeShape();
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("out of range: replace dim[", dim_index_in,
-"] for shape with rank[", s.GetDimNum(), "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("out of range: replace dim[", dim_index_in, "] for shape with rank[", s.GetDimNum(), "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims = s.GetDims();
@@ -491,7 +523,7 @@ template <typename Ta, typename Tb>
bool FastBoundsCheck(const Ta index, const Tb limit) {
static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
"FastBoundsCheck can only be used on integer types.");
-using UIndex = typename std::make_unsigned<decltype(index + limit)>::type;
+typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;
return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}

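FastBoundsCheck relies on a classic trick: after casting both operands to their common unsigned type, a single comparison rejects negative indices (which wrap to huge unsigned values) and indices at or beyond the limit. A standalone copy that compiles as-is; the name InRange is illustrative:

    #include <type_traits>

    template <typename Ta, typename Tb>
    bool InRange(const Ta index, const Tb limit) {
      static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
                    "InRange can only be used on integer types.");
      typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;
      // One unsigned compare: a negative index wraps past any sane limit.
      return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
    }
    // InRange(-1, 10) == false, InRange(3, 10) == true, InRange(10, 10) == false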
@@ -512,7 +544,7 @@ graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out) {
return GRAPH_SUCCESS;
}

-graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_name) {
+graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op) {
if (dim2 == 0) {
out = dim1;
} else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) {
@@ -520,7 +552,7 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
} else {
if (dim1 < dim2) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]"));
+op, ConcatString("negative dimension caused by subtracting. dim1[", dim1, "], dim2[", dim2, "]"));
return GRAPH_FAILED;
}
out = dim1 - dim2;
@@ -528,10 +560,9 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
return GRAPH_SUCCESS;
}

-graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const char *op_name) {
+graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op) {
if (s.GetDimNum() > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
-ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", s.GetDimNum(), "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
const int64_t rank = static_cast<int64_t>(s.GetDimNum());
@@ -557,7 +588,7 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
start = rank + start;
if (start < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]"));
+op, ConcatString("invalid start[", start - rank, "] to get sub shape with rank[", rank, "]"));
return GRAPH_FAILED;
}
}
@@ -566,21 +597,21 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
end = rank + end;
if (end < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]"));
+op, ConcatString("invalid end[", end - rank, "] to get sub shape with rank[", rank, "]"));
return GRAPH_FAILED;
}
}

// stride > 0 and start > end
if (!((stride <= 0 || start <= end))) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be less than end[",
-end, "] at positive stride[", stride, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]"));
return GRAPH_FAILED;
}
// stride < 0 and start < end
if (!(stride >= 0 || start >= end)) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be greater than end[",
-end, "] at negative stride[", stride, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]"));
return GRAPH_FAILED;
}
std::vector<int64_t> dims;
@@ -593,10 +624,10 @@ graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride,
}

graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64_t stride, GeShape &out_shape,
-const char *op_name) {
+const ge::Operator &op) {
int64_t src_rank = src_shape.GetDimNum();
if (src_rank > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", src_rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", src_rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}

@@ -622,8 +653,7 @@ graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64
start += src_rank;
if (start < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name),
-ConcatString("invalid start[", start - src_rank, "] to get sub shape with rank[", src_rank, "]"));
+op, ConcatString("invalid start[", start - src_rank, "] to get sub shape with rank[", src_rank, "]"));
return GRAPH_FAILED;
}
}
@@ -632,18 +662,18 @@ graphStatus SubShape(const GeShape &src_shape, int64_t start, int64_t end, int64
end += src_rank;
if (end < 0) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid end[", end - src_rank, "] to get sub shape with rank[", src_rank, "]"));
+op, ConcatString("invalid end[", end - src_rank, "] to get sub shape with rank[", src_rank, "]"));
return GRAPH_FAILED;
}
}

if (stride > 0 && start > end) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be less than end[",
-end, "] at positive stride[", stride, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("start[", start, "] should be less than end[", end, "] at positive stride[", stride, "]"));
return GRAPH_FAILED;
} else if (stride < 0 && start < end) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("start[", start, "] should be greater than end[",
-end, "] at negative stride[", stride, "]"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
+op, ConcatString("start[", start, "] should be greater than end[", end, "] at negative stride[", stride, "]"));
return GRAPH_FAILED;
}

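Both SubShape overloads take Python-style indices: a negative start or end is first shifted by the rank, and only then do the positive/negative-stride ordering checks above apply. The normalization step in isolation; the function name is illustrative:

    #include <cstdint>

    // -1 means "last dim", -rank means "first dim"; idx + rank < 0 is rejected upstream.
    int64_t NormalizeSubShapeIndex(int64_t idx, int64_t rank) {
      return (idx < 0) ? idx + rank : idx;
    }
    // rank 4: NormalizeSubShapeIndex(-1, 4) == 3, NormalizeSubShapeIndex(2, 4) == 2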
@@ -698,7 +728,7 @@ graphStatus Concatenate(const GeShape &s1, const GeShape &s2, GeShape &out) {

graphStatus Matrix(int64_t dim1, int64_t dim2, Shape &out) {
std::vector<int64_t> dims;
-dims.reserve(DIM_VALUE2); // The number of dims is 2.
+dims.reserve(2); // The number of dims is 2.
dims.push_back(dim1);
dims.push_back(dim2);
Shape s(dims);
@@ -716,7 +746,7 @@ graphStatus Vector(int64_t dim, Shape &out) {
}

static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_name, int64_t rank,
-std::vector<int64_t> &data, const char *op_name) {
+std::vector<int64_t> &data) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto shape_data_desc = op_desc->MutableInputDesc(dst_name);

@@ -728,13 +758,13 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
DataType data_type = shape_data_desc->GetDataType();
if (dims.size() != static_cast<size_t>(rank)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
+op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
return GRAPH_FAILED;
}
int64_t dim_value = ((rank > 0) && (dims[0] > 0)) ? dims[0] : 1;
data.clear();
if (dims[0] < 0) {
-OP_LOGI(op_name, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]);
+OP_LOGI(op, "Shape rank is %zu, dims[0] value is [%ld]", dims.size(), dims[0]);
data.push_back(UNKNOWN_DIM_NUM);
return GRAPH_SUCCESS;
}
@@ -747,7 +777,7 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else {
-OP_LOGI(TbeGetName(op).c_str(), "Input [%s] is not a const tensor.", dst_name.c_str());
+OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str());
for (int64_t i = 0; i < dim_value; i++) {
data.push_back(UNKNOWN_DIM);
}
@@ -759,14 +789,14 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else {
-OP_LOGI(TbeGetName(op).c_str(), "Input [%s] is not a const tensor.", dst_name.c_str());
+OP_LOGI(op, "Input [%s] is not a const tensor.", dst_name.c_str());
for (int64_t i = 0; i < dim_value; i++) {
data.push_back(UNKNOWN_DIM);
}
}
} else {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
+op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
return GRAPH_FAILED;
}

@@ -774,7 +804,7 @@ static graphStatus GetShapeDataFromShapeTensor(Operator &op, const string &dst_n
}

static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank, std::vector<int64_t> &data,
-const char *op_name) {
+const ge::Operator &op) {
TensorDesc shape_data_desc = tensor.GetTensorDesc();
Shape shape_data_shape = shape_data_desc.GetShape();
std::vector<int64_t> dims = shape_data_shape.GetDims();
@@ -782,84 +812,83 @@ static graphStatus GetShapeDataFromConstData(const Tensor &tensor, int64_t rank,

if (dims.size() != static_cast<size_t>(rank)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
+op, ConcatString("invalid shape data rank[", dims.size(), "], should be [", rank, "]"));
return GRAPH_FAILED;
}
int64_t dim_value = rank > 0 ? dims[0] : 1;
-OP_LOGI(op_name, "data_type = %d, dim_value = %ld", data_type, dim_value);
+OP_LOGI(op, "data_type = %d, dim_value = %ld", data_type, dim_value);
data.clear();
data.reserve(dim_value);
if (data_type == DT_INT32) {
const int32_t *shape_data = reinterpret_cast<const int32_t *>(tensor.GetData());
for (int64_t i = 0; i < dim_value; i++) {
-OP_LOGI(op_name, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast<int64_t>(shape_data[i]));
+OP_LOGI(op, "DT_INT32 i = %ld, shape_data[i] = %ld", i, static_cast<int64_t>(shape_data[i]));
data.push_back(static_cast<int64_t>(shape_data[i]));
}
} else if (data_type == DT_INT64) {
const int64_t *shape_data = reinterpret_cast<const int64_t *>(tensor.GetData());
for (int64_t i = 0; i < dim_value; i++) {
-OP_LOGI(op_name, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]);
+OP_LOGI(op, "DT_INT64 i = %ld, shape_data[i] = %ld", i, shape_data[i]);
data.push_back(shape_data[i]);
}
} else {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
+op, ConcatString("invalid data type[", DTypeStr(data_type), "], should be DT_INT32 or DT_INT64"));
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

-graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const char *op_name) {
+graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op) {
std::vector<int64_t> shape_data;
-GetShapeDataFromConstData(tensor, 1, shape_data, op_name);
+GetShapeDataFromConstData(tensor, 1, shape_data, op);
out = Shape(shape_data);
return GRAPH_SUCCESS;
}

-graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out, const char *op_name) {
+graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out) {
std::vector<int64_t> shape_data;
-if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data, op_name) != GRAPH_SUCCESS) {
+if (GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data) != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
out = GeShape(shape_data);
return GRAPH_SUCCESS;
}

-graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char *op_name) {
+graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op) {
std::vector<int64_t> shape_data;
-GetShapeDataFromConstData(tensor, 0, shape_data, op_name);
+GetShapeDataFromConstData(tensor, 0, shape_data, op);
out = shape_data[0];
return GRAPH_SUCCESS;
}

-graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name) {
+graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}
Shape s = tensor.GetShape();
std::vector<int64_t> dims = s.GetDims();
if (!((dims.size() <= static_cast<size_t>(rank)) || (dims == ge::UNKNOWN_SHAPE))) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
-ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
return GRAPH_FAILED;
}
out = s;
return GRAPH_SUCCESS;
}

-graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name) {
+graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
+const ge::Operator &op) {
if (rank > INT32_MAX) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name), ConcatString("rank[", rank, "] cannot exceed kint32max"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("rank[", rank, "] cannot exceed kint32max"));
return GRAPH_FAILED;
}

GeShape s = tensorDesc->GetShape();
std::vector<int64_t> dims = s.GetDims();
if ((dims != ge::UNKNOWN_RANK) && (dims.size() > static_cast<size_t>(rank))) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
-ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid rank[", dims.size(), "], should be at most ", rank));
return GRAPH_FAILED;
}

@@ -881,20 +910,19 @@ graphStatus UnchangedShape(Operator &op, const string input_name, const string o
}

graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out,
-const char *op_name) {
+const ge::Operator &op) {
if (divisor == 1) {
out = dividend;
} else if ((dividend == ge::UNKNOWN_DIM) || (divisor == ge::UNKNOWN_DIM)) {
out = ge::UNKNOWN_DIM;
} else {
if (divisor <= 0) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
-ConcatString("invalid divisor[", divisor, "], should be positive"));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("invalid divisor[", divisor, "], should be positive"));
return GRAPH_FAILED;
}
if (!((!evenlyDivisible) || (dividend % divisor) == 0)) {
VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]"));
+op, ConcatString("[", dividend, "] cannot be evenly divisible by [", divisor, "]"));
return GRAPH_FAILED;
}
out = dividend / divisor;
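Divide's decision ladder, restated without the GE error plumbing (kUnknownDim stands in for ge::UNKNOWN_DIM):

    #include <cstdint>

    const int64_t kUnknownDim = -1; // stand-in for ge::UNKNOWN_DIM

    bool DivideDim(int64_t dividend, int64_t divisor, bool evenly_divisible, int64_t &out) {
      if (divisor == 1) { out = dividend; return true; }           // trivial divisor
      if (dividend == kUnknownDim || divisor == kUnknownDim) { out = kUnknownDim; return true; }
      if (divisor <= 0) return false;                               // divisor must be positive
      if (evenly_divisible && dividend % divisor != 0) return false;
      out = dividend / divisor;
      return true;
    }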
@@ -970,25 +998,25 @@ bool ValueKnown(const Shape &shape, const size_t &dim_index) {
}

graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape,
-const char *op_name) {
+const ge::Operator &op) {
// Validate ranks
Shape unused_shape;
-if (WithRank(indices, NUM_VALUE2, unused_shape, op_name) != GRAPH_SUCCESS) { // The rank is 2.
+if (WithRank(indices, 2, unused_shape, op) != GRAPH_SUCCESS) { // The rank is 2.
std::string err_msg = ConcatString("failed to call WithRank function, indices has wrong shape",
DebugString(indices.GetShape().GetDims()), ", it should be 2D");
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
-if (WithRank(values, 1, unused_shape, op_name) != GRAPH_SUCCESS) {
+if (WithRank(values, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call WithRank function, values has wrong shape",
DebugString(values.GetShape().GetDims()), ", it should be 1D");
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}
-if (WithRank(shape, 1, unused_shape, op_name) != GRAPH_SUCCESS) {
+if (WithRank(shape, 1, unused_shape, op) != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("failed to call WithRank function, shape has wrong shape",
DebugString(shape.GetShape().GetDims()), ", it should be 1D");
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(string(op_name), err_msg);
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, err_msg);
return GRAPH_FAILED;
}

@@ -998,9 +1026,8 @@ graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &va
if (ValueKnown(indices_shape, 0)) {
if (ValueKnown(values_shape, 0)) {
if (indices_shape.GetDim(0) != values_shape.GetDim(0)) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(
-string(op_name), ConcatString("dim[0] of indices and dim[0] of value do not match, ", indices_shape.GetDim(0),
-" and ", values_shape.GetDim(0)));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[0] of indices and dim[0] of value do not match, ",
+indices_shape.GetDim(0), " and ", values_shape.GetDim(0)));
return GRAPH_FAILED;
}
}
@@ -1011,9 +1038,8 @@ graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &va
if (ValueKnown(indices_shape, 1)) {
if (ValueKnown(sparse_shape, 0)) {
if (indices_shape.GetDim(1) != sparse_shape.GetDim(0)) {
-VECTOR_INFER_SHAPE_INNER_ERR_REPORT(string(op_name),
-ConcatString("dim[1] of indices and dim[0] of sparse do not match, ",
-indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0)));
+VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, ConcatString("dim[1] of indices and dim[0] of sparse do not match, ",
+indices_shape.GetDim(1), " and ", sparse_shape.GetDim(0)));
return GRAPH_FAILED;
}
}
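ValidateSparseTensor enforces the COO layout contract: indices [nnz, ndims], values [nnz], dense shape [ndims], with any dim that is still unknown skipping its comparison. The two cross-checks in isolation, as standalone C++ (-1 marks an unknown dim; the function name is illustrative):

    #include <cstdint>

    bool CooDimsConsistent(int64_t indices_d0, int64_t indices_d1, int64_t values_d0, int64_t shape_d0) {
      if (indices_d0 != -1 && values_d0 != -1 && indices_d0 != values_d0) return false; // nnz mismatch
      if (indices_d1 != -1 && shape_d0 != -1 && indices_d1 != shape_d0) return false;   // ndims mismatch
      return true;
    }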
@@ -1086,45 +1112,38 @@ std::string DTypeStr(DataType dtype) {
}

graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range) {
-AscendString op_name;
-op.GetName(op_name);
auto context = op.GetInferenceContext();
std::vector<AscendString> marks;
context->GetMarks(marks);

if (!marks.empty()) {
-OP_LOGI(TbeGetName(op).c_str(), "Set marks[0] = %s", marks[0].GetString());
+OP_LOGI(op, "Set marks[0] = %s", marks[0].GetString());
bool shape_changed = false;
auto aicpu_resource_context = reinterpret_cast<AicpuResourceContext *>(context->GetResourceContext(marks[0]));
if (aicpu_resource_context == nullptr) {
aicpu_resource_context = new (std::nothrow) AicpuResourceContext();
if (aicpu_resource_context == nullptr) {
-AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
-std::string("new AicpuResourceContext failed."));
+AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("new AicpuResourceContext failed."));
return GRAPH_FAILED;
}
aicpu_resource_context->shape_and_range_.push_back(feed_shape_and_range);
if (context->SetResourceContext(marks[0], aicpu_resource_context) != GRAPH_SUCCESS) {
delete aicpu_resource_context;
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
-std::string("set resource context failed."));
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("set resource context failed."));
return GRAPH_FAILED;
}
shape_changed = true;
} else {
auto &shape_and_range = aicpu_resource_context->shape_and_range_;
if (shape_and_range.empty()) {
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
-std::string("get resource context shape and ranges failed."));
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("get resource context shape and ranges failed."));
return GRAPH_FAILED;
}
-MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed,
-op_name.GetString());
+MergeShapeAndRange(shape_and_range[0], feed_shape_and_range, shape_and_range[0], shape_changed, op);
}
if (shape_changed) {
if (context->AddChangedResourceKey(marks[0]) != GRAPH_SUCCESS) {
-AICPU_INFER_SHAPE_CALL_ERR_REPORT(std::string(op_name.GetString()),
-std::string("add change resource key failed."));
+AICPU_INFER_SHAPE_CALL_ERR_REPORT(op, std::string("add change resource key failed."));
return GRAPH_FAILED;
}
}
@@ -1133,23 +1152,19 @@ graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_r
}

graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context) {
-AscendString op_name;
-op.GetName(op_name);
std::vector<AscendString> marks;
infer_context->GetMarks(marks);
if (!marks.empty()) {
-OP_LOGI(TbeGetName(op).c_str(), "Get marks[0] = %s", marks[0].GetString());
+OP_LOGI(op, "Get marks[0] = %s", marks[0].GetString());
if (infer_context->RegisterReliedOnResourceKey(marks[0]) != GRAPH_SUCCESS) {
-AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
-std::string("register relied on resource key failed."));
+AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("register relied on resource key failed."));
return GRAPH_FAILED;
}
auto aicpu_resource_context = reinterpret_cast<AicpuResourceContext *>(infer_context->GetResourceContext(marks[0]));
if (aicpu_resource_context != nullptr) {
auto &shape_and_range = aicpu_resource_context->shape_and_range_;
if (shape_and_range.empty()) {
-AICPU_INFER_SHAPE_INNER_ERR_REPORT(std::string(op_name.GetString()),
-std::string("get resource context shape and ranges failed."));
+AICPU_INFER_SHAPE_INNER_ERR_REPORT(op, std::string("get resource context shape and ranges failed."));
return GRAPH_FAILED;
}
out.shape_ = shape_and_range[0].shape_;
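SetShapeAndRange and GetShapeAndRange implement a small publish/subscribe protocol over the inference context, keyed by marks[0]; in outline, with the GE calls as used above and the control flow paraphrased:

    // SetShapeAndRange(op, feed):
    //   if GetResourceContext(marks[0]) is null: allocate AicpuResourceContext,
    //     store feed, SetResourceContext(marks[0], ctx), shape_changed = true
    //   else: MergeShapeAndRange(stored[0], feed, stored[0], shape_changed, op)
    //   if shape_changed: AddChangedResourceKey(marks[0])   // notify dependents
    //
    // GetShapeAndRange(op, out, ...):
    //   RegisterReliedOnResourceKey(marks[0])                // subscribe to changes
    //   if a context already exists: out = stored shape_and_range_[0]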
@@ -1,5 +1,5 @@
/**
-* Copyright (c) 2022-2022 Huawei Technologies Co., Ltd. All rights reserved.
+* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,12 +18,11 @@
* \file common_shape_fns.h
* \brief
*/
-#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
-#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
+#ifndef CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
+#define CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_

#include <string>
#include <vector>
#include <utility>
#include "graph/tensor.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
@@ -32,6 +31,7 @@
#include "error_code.h"

namespace ge {

struct ShapeAndRange {
Shape shape_;
std::vector<std::pair<int64_t, int64_t>> shape_range_;
@@ -42,6 +42,25 @@ struct AicpuResourceContext : public ResourceContext {
std::vector<ShapeAndRange> shape_and_range_;
};

+/**
+* Check whether Shape's rank is at least rank
+* @param tensor Input tensor
+* @param rank expect val of Shape
+* @param out Output Shape
+* @return status whether Shape's condition Satisfied
+*/
+graphStatus WithRankAtLeast(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);
+
+/**
+* Check whether Shape's rank is at least rank
+* @param tensor Input tensor
+* @param rank expect val of Shape
+* @param out Output Shape
+* @return status whether Shape's condition Satisfied
+*/
+graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape,
+const ge::Operator &op);
+
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
@@ -67,7 +86,7 @@ graphStatus WithRankAtLeast(const GeTensorDescPtr &tensorDesc, int64_t rank, GeS
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
-graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name);
+graphStatus WithRankShape(GeShape &shape, int64_t rank, const ge::Operator &op);

/**
* Check whether Shape's rank is equal to rank
@@ -76,7 +95,7 @@ graphStatus WithRankShape(GeShape &shape, int64_t rank, const char *op_name);
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
-graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name);
+graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);

/**
* Check whether Shape's rank is equal to rank
@@ -85,7 +104,7 @@ graphStatus WithRank(const TensorDesc &tensor, int64_t rank, Shape &out, const c
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
-graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name);
+graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op);

/**
* Check whether Shape's rank is equal to rank
@@ -94,7 +113,7 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &o
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
-graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const char *op_name);
+graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out_shape, const ge::Operator &op);

/**
* Check whether dim is equal to value
@@ -103,7 +122,7 @@ graphStatus WithRank(const GeTensorDescPtr &tensorDesc, int64_t rank, Shape &out
* @param out Output dim
* @return status whether Dim is equal to value
*/
-graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_name);
+graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const ge::Operator &op);

/**
* Merge two shapes
@@ -113,7 +132,7 @@ graphStatus WithValue(int64_t dim, int64_t value, int64_t &out, const char *op_n
* @param prefix_out prefix out shape val
* @return status whether this operation success
*/
-graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const char *op_name);
+graphStatus MergePrefix(const Shape &s, const Shape &prefix, Shape &s_out, Shape &prefix_out, const ge::Operator &op);

/**
* Merge two dims of Shape
@@ -131,7 +150,7 @@ graphStatus Merge(int64_t dim1, int64_t dim2, int64_t &out);
* @param out merged shape val
* @return status whether this operation success
*/
-graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_name);
+graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const ge::Operator &op);

/**
* Merge two shapes
@@ -140,7 +159,7 @@ graphStatus Merge(const Shape &s0, const Shape &s1, Shape &out, const char *op_n
* @param out merged Geshape val
* @return status whether this operation success
*/
-graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const char *op_name);
+graphStatus Merge(const GeShape &s0, const GeShape &s1, GeShape &out, const ge::Operator &op);

/**
* Merge two shapes
@@ -171,7 +190,7 @@ void MergeRange(const std::vector<std::pair<int64_t, int64_t>> &shared_shape_ran
|
|||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, const ShapeAndRange &value_shape_and_range,
|
||||
ShapeAndRange &out, bool &shape_changed, const char *op_name);
|
||||
ShapeAndRange &out, bool &shape_changed, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Replace one dim in a given shape
|
||||
|
@ -181,7 +200,7 @@ graphStatus MergeShapeAndRange(const ShapeAndRange &shared_shape_and_range, cons
|
|||
* @param out new shape
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const char *op_name);
|
||||
graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Shape &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Replace one dim in a given shape
|
||||
|
@ -191,7 +210,7 @@ graphStatus ReplaceDim(const Shape &s, int64_t dim_index_in, int64_t new_dim, Sh
|
|||
* @param out new shape
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const char *op_name);
|
||||
graphStatus ReplaceDim(const GeShape &s, int64_t dim_index_in, int64_t new_dim, GeShape &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Check if it satisfies 0 <= index < limit
|
||||
|
@ -218,7 +237,7 @@ graphStatus Add(int64_t dim1, int64_t dim2, int64_t &out);
|
|||
* @param out Subtract dim val
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_name);
|
||||
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Get SubShape according to start end index and step size stride
|
||||
|
@ -229,7 +248,7 @@ graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t &out, const char *op_na
|
|||
* @param out sub shape output
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const char *op_name);
|
||||
graphStatus SubShape(const Shape &s, int64_t start, int64_t end, int64_t stride, Shape &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Get SubShape according to start end index and step size stride
|
||||
|
@ -251,7 +270,8 @@ graphStatus SubShape(const GeShape &s, size_t start, size_t end, size_t stride,
|
|||
* @param out sub shape output
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus SubShape(const GeShape &s, int64_t start, int64_t end, int64_t stride, GeShape &out, const char *op_name);
|
||||
graphStatus SubShape(const GeShape &s, int64_t start, int64_t end, int64_t stride, GeShape &out,
|
||||
const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Concatenate two shape
|
||||
|
@ -294,17 +314,16 @@ graphStatus Vector(int64_t dim, Shape &out);
|
|||
* @param out shape
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const char *op_name);
|
||||
graphStatus MakeShapeFromShapeTensor(const Tensor &tensor, Shape &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Make shape from shape tensor
|
||||
* @param op Operator
|
||||
* @param dst_name const string &
|
||||
* @param out GeShape
|
||||
* @param op_name const char *
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out, const char *op_name);
|
||||
graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeShape &out);
|
||||
|
||||
/**
|
||||
* Make dim from scalar tensor
|
||||
|
@ -312,7 +331,7 @@ graphStatus MakeShapeFromShapeTensor(Operator &op, const string &dst_name, GeSha
|
|||
* @param out shape
|
||||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char *op_name);
|
||||
graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Check whether Shape's rank is at most rank
|
||||
|
@ -321,7 +340,7 @@ graphStatus MakeDimForScalarInput(const Tensor &tensor, int64_t &out, const char
|
|||
* @param out output Shape
|
||||
* @return status whether Shape's condition Satisfied
|
||||
*/
|
||||
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const char *op_name);
|
||||
graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Check whether Shape's rank is at most rank
|
||||
|
@ -330,7 +349,7 @@ graphStatus WithRankAtMost(const TensorDesc &tensor, int64_t rank, Shape &out, c
|
|||
* @param out output Shape
|
||||
* @return status whether Shape's condition Satisfied
|
||||
*/
|
||||
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const char *op_name);
|
||||
graphStatus WithRankAtMost(const GeTensorDescPtr &tensorDesc, int64_t rank, GeShape &out_shape, const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* make a empty dim shape
|
||||
|
@ -357,7 +376,7 @@ graphStatus UnchangedShape(Operator &op, const string input_name, const string o
|
|||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t &out,
|
||||
const char *op_name);
|
||||
const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* check shape fully defined or not
|
||||
|
@ -410,7 +429,7 @@ bool ValueKnown(const Shape &shape, const size_t &dim_index);
|
|||
* @return status whether this operation success
|
||||
*/
|
||||
graphStatus ValidateSparseTensor(const TensorDesc &indices, const TensorDesc &values, const TensorDesc &shape,
|
||||
const char *op_name);
|
||||
const ge::Operator &op);
|
||||
|
||||
/**
|
||||
* Fill op_desc with input shape
|
||||
|
@ -448,6 +467,7 @@ std::string DTypeStr(DataType dtype);
|
|||
graphStatus SetShapeAndRange(Operator &op, const ShapeAndRange &feed_shape_and_range);
|
||||
|
||||
graphStatus GetShapeAndRange(Operator &op, ShapeAndRange &out, bool &geted, InferenceContextPtr infer_context);
|
||||
|
||||
} // namespace ge
|
||||
|
||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
|
||||
#endif // CUSTOMIZE_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
|
||||
|
|
|
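Taken together, these signature changes replace the old trailing const char *op_name parameter with the ge::Operator handle, so the shape helpers can resolve the op name themselves when reporting errors. A minimal sketch of the intended call pattern, assuming a hypothetical custom op with a single 2-D input and an output named "y"; IMPLEMT_COMMON_INFERFUNC and the registration plumbing follow the existing op_proto files:

// Sketch only: MyCustomOp and its input/output names are hypothetical.
#include "common_shape_fns.h"

namespace ge {
IMPLEMT_COMMON_INFERFUNC(MyCustomOpInferShape) {
  Shape input_shape;
  // New-style overload: pass the Operator itself rather than TbeGetName(op).c_str().
  if (WithRank(op.GetInputDesc(0), 2, input_shape, op) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  TensorDesc out_desc = op.GetOutputDescByName("y");
  out_desc.SetShape(input_shape);
  return op.UpdateOutputDesc("y", out_desc);
}
}  // namespace ge
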
@ -19,8 +19,8 @@
 * \brief
 */

#ifndef CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#define CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#ifndef CANN_CUSTOMIZE_CONTEXT_UTIL_H_
#define CANN_CUSTOMIZE_CONTEXT_UTIL_H_

#include "runtime/infer_shape_context.h"
#include "runtime/tiling_context.h"
@ -43,4 +43,4 @@ namespace ops {
    return ret; \
  }
}  // namespace ops
#endif  // CANN_OPS_BUILT_IN_CONTEXT_UTIL_H_
#endif  // CANN_CUSTOMIZE_CONTEXT_UTIL_H_

@ -18,8 +18,8 @@
 * \file error_code.h
 * \brief
 */
#ifndef OPS_COMMON_INC_ERROR_CODE_H_
#define OPS_COMMON_INC_ERROR_CODE_H_
#ifndef CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_
#define CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_

namespace ge {
// error code for report purpose.
@ -59,4 +59,4 @@ enum ViewErrorCode {
};
}  // namespace ge

#endif  // OPS_COMMON_INC_ERROR_CODE_H_
#endif  // CUSTOMIZE_OP_PROTO_UTILS_ERROR_CODE_H_

@ -18,13 +18,12 @@
 * \file error_util.h
 * \brief
 */
#ifndef OPS_COMMON_INC_ERROR_UTIL_H_
#define OPS_COMMON_INC_ERROR_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_

#include <sstream>
#include <string>
#include <vector>
#include <utility>
#include "common/util/error_manager/error_manager.h"
#include "error_code.h"
#include "graph/operator.h"
@ -227,4 +226,4 @@ void GeInfershapeErrReport(const std::string &op_name, const std::string &op_typ
void CommonRuntimeErrLog(const std::string &opname, const std::string &description);
}  // namespace ge

#endif  // OPS_COMMON_INC_ERROR_UTIL_H_
#endif  // CUSTOMIZE_OP_PROTO_UTILS_ERROR_UTIL_H_

@ -0,0 +1,225 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file images_ops_shape_fns.cpp
 * \brief
 */
#include "image_ops_shape_fns.h"
#include "op_log.h"
#include "error_util.h"
#include "graph/utils/op_desc_utils.h"

namespace ge {
graphStatus ColorspaceShapeFn(Operator &op, const std::string &output_name) {
  Shape shape;
  graphStatus status = WithRankAtLeast(op.GetInputDesc(0), 1, shape, op);
  if (status != GRAPH_SUCCESS) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input[images] must be 1-D or higher rank.");
    return GRAPH_PARAM_INVALID;
  }
  int64_t dim = op.GetInputDesc(0).GetShape().GetDims().back();
  if (dim != 3) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input[images] last dimension must be size 3.");
    return GRAPH_PARAM_INVALID;
  }
  TensorDesc desc = op.GetOutputDescByName(output_name.c_str());
  desc.SetShape(Shape(shape));
  return op.UpdateOutputDesc(output_name.c_str(), desc);
}

graphStatus ResizeShapeFn(Operator &op, const std::string &input_name, const std::string &size_input_name,
                          const std::string &output_name) {
  if (op.GetInputDesc(0).GetShape().GetDims() == UNKNOWN_RANK) {
    std::vector<int64_t> output_shape(4, UNKNOWN_DIM);
    TensorDesc td = op.GetOutputDescByName(output_name.c_str());
    td.SetShape(Shape(output_shape));
    op.UpdateOutputDesc(output_name.c_str(), td);
    return GRAPH_SUCCESS;
  }
  Shape shape;
  graphStatus status = WithRank(op.GetInputDesc(0), 4, shape, op);
  if (status != GRAPH_SUCCESS) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input[images] must be 4-D.");
    return GRAPH_PARAM_INVALID;
  }
  auto dims = op.GetInputDesc(0).GetShape().GetDims();
  auto channel_dim = dims[3];
  TensorDesc input_td = op.GetInputDesc(0);
  if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
    channel_dim = dims[1];
  }
  return SetOutputToSizedImage(op, dims[0], size_input_name, channel_dim, output_name);
}

graphStatus SetOutputToSizedImage(Operator &op, const int64_t batch_dim, const std::string &size_input_name,
                                  const int64_t channel_dim, const std::string &output_name) {
  Shape size_shape;
  graphStatus status = WithRank(op.GetInputDescByName(size_input_name.c_str()), 1, size_shape, op);
  if (status != GRAPH_SUCCESS) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input size must be 1-D.");
    return GRAPH_PARAM_INVALID;
  }
  auto size_dims = op.GetInputDescByName(size_input_name.c_str()).GetShape().GetDims();
  if (size_dims[0] != 2) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input size must be 1-D tensor of 2 elements.");
    return GRAPH_PARAM_INVALID;
  }

  std::vector<std::string> input_infer_depends = {size_input_name};
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  op_desc->SetOpInferDepends(input_infer_depends);

  DataType data_type = DT_FLOAT;
  // Update DataType when Attr "dtype" is set, used for op ResizeBicubic
  if (op.GetAttr("dtype", data_type) == GRAPH_SUCCESS) {
    if ((data_type != DT_FLOAT) && (data_type != DT_UINT8)) {
      OP_LOGW(op_desc->GetName().c_str(), "Attr dtype should only be DT_FLOAT or DT_UINT8");
    } else {
      OP_LOGI(op_desc->GetName().c_str(), "Update DataType from attr, which is %d", data_type);
    }
  }

  Tensor size_tensor;
  TensorDesc td = op.GetOutputDescByName(output_name.c_str());
  status = op.GetInputConstData(size_input_name.c_str(), size_tensor);
  if (status != GRAPH_SUCCESS) {
    td.SetDataType(data_type);
    std::vector<int64_t> out_shape;
    TensorDesc input_td = op.GetInputDesc(0);
    if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
      out_shape.push_back(batch_dim);
      out_shape.push_back(channel_dim);
      out_shape.push_back(-1);
      out_shape.push_back(-1);
    } else if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NHWC) {
      out_shape.push_back(batch_dim);
      out_shape.push_back(-1);
      out_shape.push_back(-1);
      out_shape.push_back(channel_dim);
    } else {
      std::string error_msg = "This format is not supported";
      AICPU_INFER_SHAPE_CALL_ERR_REPORT(TbeGetName(op), error_msg);
    }
    td.SetShape(Shape(out_shape));
    op.UpdateOutputDesc(output_name.c_str(), td);
    return GRAPH_SUCCESS;
  }

  const int32_t *size_data = reinterpret_cast<const int32_t *>(size_tensor.GetData());

  int64_t size_width = static_cast<int64_t>(size_data[1]);
  int64_t size_height = static_cast<int64_t>(size_data[0]);
  std::vector<int64_t> output_shape;

  TensorDesc input_td = op.GetInputDesc(0);
  if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NCHW) {
    output_shape.push_back(batch_dim);
    output_shape.push_back(channel_dim);
    output_shape.push_back(size_height);
    output_shape.push_back(size_width);
  } else if (static_cast<ge::Format>(ge::GetPrimaryFormat(input_td.GetFormat())) == FORMAT_NHWC) {
    output_shape.push_back(batch_dim);
    output_shape.push_back(size_height);
    output_shape.push_back(size_width);
    output_shape.push_back(channel_dim);
  } else {
    OP_LOGE(TbeGetName(op).c_str(), "This format is not supported");
  }
  td.SetShape(Shape(output_shape));
  return op.UpdateOutputDesc(output_name.c_str(), td);
}

graphStatus EncodeImageShapeFn(Operator &op) {
  Shape unused_shape;
  if (WithRank(op.GetInputDesc(0), 3, unused_shape, op) != GRAPH_SUCCESS) {
    AscendString op_name;
    op.GetName(op_name);
    OP_LOGE(op_name.GetString(), "input rank must be 3.");
    return GRAPH_FAILED;
  }

  Shape output_shape;
  (void)Scalar(output_shape);
  TensorDesc output_tensor = op.GetOutputDescByName("contents");
  output_tensor.SetDataType(DT_STRING);
  output_tensor.SetShape(output_shape);
  return op.UpdateOutputDesc("contents", output_tensor);
}

graphStatus DecodeImageShapeFn(Operator &op) {
  int channels;
  if (op.GetAttr("channels", channels) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Get attr[channels] failed"));
    return GRAPH_FAILED;
  }
  if (channels != 0 && channels != 1 && channels != 3 && channels != 4) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("attr[channels] must be 0, 1, 3 or 4"));
    return GRAPH_FAILED;
  }

  DataType dtype;
  if (op.GetAttr("dtype", dtype) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Get attr[dtype] failed"));
    return GRAPH_FAILED;
  }
  std::vector<int64_t> dims;
  if (channels == 0) {
    dims = {ge::UNKNOWN_DIM, ge::UNKNOWN_DIM, ge::UNKNOWN_DIM};
  } else {
    dims = {ge::UNKNOWN_DIM, ge::UNKNOWN_DIM, channels};
  }

  Shape output_shape(dims);
  TensorDesc output_tensor = op.GetOutputDesc(0);
  output_tensor.SetDataType(dtype);
  output_tensor.SetShape(output_shape);
  if (op.UpdateOutputDesc("image", output_tensor) != GRAPH_SUCCESS) {
    AICPU_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), string("Update OutputDesc[image] failed"));
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

bool DimsAllEqualOrUnknown(std::initializer_list<int64_t> &&inputs, int64_t unknown_dim_val) {
  auto it = inputs.begin();
  for (; it != inputs.end() && (*it == unknown_dim_val); ++it) {
  }

  if (it == inputs.end()) {
    return true;
  }

  for (auto default_dim_val = *(it++); it != inputs.end(); ++it) {
    if (*it != default_dim_val && *it != unknown_dim_val) {
      return false;
    }
  }

  return true;
}

}  // namespace ge

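DimsAllEqualOrUnknown treats unknown dims as wildcards: it skips leading unknowns, takes the first known dim as the reference, and rejects only a later known dim that differs from it. A few hand-picked illustrative calls, assuming UNKNOWN_DIM is the -1 sentinel used elsewhere in this diff:

// Illustrative values only.
DimsAllEqualOrUnknown({4, 4, UNKNOWN_DIM, 4});     // true: every known dim is 4
DimsAllEqualOrUnknown({UNKNOWN_DIM, UNKNOWN_DIM}); // true: no known dim to contradict
DimsAllEqualOrUnknown({4, 5, UNKNOWN_DIM});        // false: 4 and 5 disagree
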
@ -0,0 +1,84 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file images_ops_shape_fns.h
 * \brief
 */
#ifndef CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_

#include <string>

#include "common_shape_fns.h"

namespace ge {
/**
 * ColorspaceShapeFn, infer shape function of colorspace op
 * @param op, Operators that need to reason about shape
 * @param output_name, the name of output
 * @return status whether infer shape success
 */
graphStatus ColorspaceShapeFn(Operator &op, const std::string &output_name);

/**
 * ResizeShapeFn, infer shape function of image resize op
 * @param op, Operators that need to reason about shape
 * @param input_name, the name of input
 * @param size_input_name, the name of the size input
 * @param output_name, the name of output
 * @return status whether infer shape success
 */
graphStatus ResizeShapeFn(Operator &op, const std::string &input_name, const std::string &size_input_name,
                          const std::string &output_name);

/**
 * SetOutputToSizedImage, set output shape of sized image op
 * @param op, Operators that need to set output shape
 * @param batch_dim, the dim of batch
 * @param size_input_name, the name of size input
 * @param channel_dim, the dim of channel
 * @param output_name, the name of output
 * @return status whether set output shape success
 */
graphStatus SetOutputToSizedImage(Operator &op, const int64_t batch_dim, const std::string &size_input_name,
                                  const int64_t channel_dim, const std::string &output_name);

/**
 * EncodeImageShapeFn, infer shape function of EncodeImage op
 * @param op, Operators that need to reason about shape
 * @return status whether infer shape success
 */
graphStatus EncodeImageShapeFn(Operator &op);

/**
 * DecodeImageShapeFn, infer shape function of DecodeImage op
 * @param op, Operators that need to reason about shape
 * @return status whether infer shape success
 */
graphStatus DecodeImageShapeFn(Operator &op);

/**
 * DimsAllEqualOrUnknown, check whether all known dims are equal
 * @param inputs, the list of input dims
 * @param unknown_dim_val, the definition of UNKNOWN_DIM
 * @return true when all dims are equal or unknown
 */
bool DimsAllEqualOrUnknown(std::initializer_list<int64_t> &&inputs, int64_t unknown_dim_val = UNKNOWN_DIM);

}  // namespace ge

#endif  // CUSTOMIZE_OP_PROTO_UTIL_IMAGES_OPS_SHAPE_FNS_H_

@ -0,0 +1,162 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file linalg_ops_shape_fns.cpp
 * \brief
 */
#include "linalg_ops_shape_fns.h"
#include "op_log.h"
#include "common_shape_fns.h"

namespace ge {
constexpr int64_t kRank = 2;
constexpr int64_t kEnd = -2;

graphStatus MakeBatchSquareMatrix(const TensorDesc &tensor, Shape &out, const ge::Operator &op) {
  Shape s;
  if (WithRankAtLeast(tensor, kRank, s, op) == GRAPH_FAILED) {
    OP_LOGE(op, "input tensor's rank must be at least 2.");
    return GRAPH_FAILED;
  }
  size_t existing = s.GetDimNum();
  int64_t dim1 = s.GetDim(existing - 2);
  int64_t dim2 = s.GetDim(existing - 1);

  int64_t out_dim = 0;
  if (Merge(dim1, dim2, out_dim) == GRAPH_FAILED) {
    OP_LOGE(op, "Merge two dimensions failed.");
    return GRAPH_FAILED;
  }

  Shape batch_shape;
  if (SubShape(s, 0, kEnd, 1, batch_shape, op) == GRAPH_FAILED) {
    OP_LOGE(op, "Get SubShape batch_shape failed.");
    return GRAPH_FAILED;
  }
  if (Concatenate(batch_shape, Shape({out_dim, out_dim}), out) == GRAPH_FAILED) {
    OP_LOGE(op, "Concatenate batch_shape and out_dim failed.");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

graphStatus MakeBatchSquareMatrix(const GeTensorDescPtr &tensor_desc, GeShape &out, const ge::Operator &op) {
  GeShape ge_shape;
  if (WithRankAtLeast(tensor_desc, kRank, ge_shape, op) == GRAPH_FAILED) {
    OP_LOGE(op, "Input tensor's rank must be at least 2.");
    return GRAPH_FAILED;
  }
  Shape s(ge_shape.GetDims());
  size_t existing = s.GetDimNum();
  int64_t dim1 = s.GetDim(existing - 2);
  int64_t dim2 = s.GetDim(existing - 1);

  int64_t out_dim = 0;
  if (Merge(dim1, dim2, out_dim) == GRAPH_FAILED) {
    OP_LOGE(op, "Merge two dimensions failed.");
    return GRAPH_FAILED;
  }

  if (RankKnown(ge_shape)) {
    GeShape batch_shape;
    if (SubShape(ge_shape, 0, kEnd, 1, batch_shape, op) == GRAPH_FAILED) {
      OP_LOGE(op, "Get SubShape batch_shape failed.");
      return GRAPH_FAILED;
    }
    if (Concatenate(batch_shape, GeShape({out_dim, out_dim}), out) == GRAPH_FAILED) {
      OP_LOGE(op, "Concatenate batch_shape and out_dim failed.");
      return GRAPH_FAILED;
    }
  } else {
    GeShape unknown_shape(ge::UNKNOWN_SHAPE);
    out = unknown_shape;
  }

  return GRAPH_SUCCESS;
}

graphStatus MatrixSolve(const TensorDesc &tensor1, const TensorDesc &tensor2, bool square, Shape &out,
                        const ge::Operator &op) {
  Shape lhs;
  Shape rhs;
  if (square) {
    if (MakeBatchSquareMatrix(tensor1, lhs, op) == GRAPH_FAILED) {
      OP_LOGE(op, "MatrixSolve first input tensor MakeBatchSquareMatrix failed.");
      return GRAPH_FAILED;
    }
  } else {
    if (WithRankAtLeast(tensor1, kRank, lhs, op) == GRAPH_FAILED) {
      OP_LOGE(op, "MatrixSolve func first input tensor's rank must be at least 2.");
      return GRAPH_FAILED;
    }
  }
  if (WithRankAtLeast(tensor2, kRank, rhs, op) == GRAPH_FAILED) {
    OP_LOGE(op, "MatrixSolve func second input tensor's rank must be at least 2.");
    return GRAPH_FAILED;
  }

  Shape lhs_batch;
  Shape rhs_batch;
  // Make the common batch subshape.
  if (SubShape(lhs, 0, kEnd, 1, lhs_batch, op) == GRAPH_FAILED) {
    OP_LOGE(op, "SubShape lhs_batch in MatrixSolve func failed.");
    return GRAPH_FAILED;
  }
  if (SubShape(rhs, 0, kEnd, 1, rhs_batch, op) == GRAPH_FAILED) {
    OP_LOGE(op, "SubShape rhs_batch in MatrixSolve func failed.");
    return GRAPH_FAILED;
  }
  int64_t lhs_batch_dim;
  // Make sure the batch dimensions match between lhs and rhs.
  if (Merge(lhs_batch.GetDimNum(), rhs_batch.GetDimNum(), lhs_batch_dim) == GRAPH_FAILED) {
    OP_LOGE(op, "Merge dimension lhs_batch and rhs_batch failed.");
    return GRAPH_FAILED;
  }

  int64_t dim_val = 0;
  int64_t lhs_rank = lhs.GetDimNum();
  int64_t rhs_rank = rhs.GetDimNum();
  int64_t dim_lhs = lhs.GetDim(lhs_rank - 2);
  int64_t dim_rhs = rhs.GetDim(rhs_rank - 2);
  // lhs and rhs have the same value for m to be compatible.
  if (Merge(dim_lhs, dim_rhs, dim_val) == GRAPH_FAILED) {
    OP_LOGE(op, "Merge dimension dim_lhs and dim_rhs failed.");
    return GRAPH_FAILED;
  }
  int64_t dim_ret = lhs.GetDim(lhs_rank - 1);
  if (square) {
    if (Merge(dim_val, dim_ret, dim_ret) == GRAPH_FAILED) {
      OP_LOGE(op, "Merge dimension dim_val and dim_ret failed.");
      return GRAPH_FAILED;
    }
  }

  Shape s;
  // Build final shape (batch_shape + n + k) in <out>.
  if (Concatenate(lhs_batch, Shape({dim_ret}), s) == GRAPH_FAILED) {
    OP_LOGE(op, "Concatenate two shapes failed.");
    return GRAPH_FAILED;
  }
  int64_t dims = rhs.GetDim(rhs_rank - 1);
  if (Concatenate(s, Shape({dims}), s) == GRAPH_FAILED) {
    OP_LOGE(op, "Concatenate Shape s and dims failed.");
    return GRAPH_FAILED;
  }
  out = s;
  return GRAPH_SUCCESS;
}
}  // namespace ge

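The shape algebra in MatrixSolve composes the shared batch prefix with n rows taken from lhs and k columns taken from rhs. A hand-worked trace, assuming square = true and hand-picked example shapes:

// lhs: [B, 5, 5]  (batch B, a 5 x 5 square system)
// rhs: [B, 5, 3]  (same batch, m = 5 rows, k = 3 right-hand sides)
// lhs_batch = [B], rhs_batch = [B]   -> batch ranks merge
// dim_lhs = 5, dim_rhs = 5           -> m merges into dim_val = 5
// dim_ret = lhs dim n = 5            -> merged with dim_val when square
// out = [B] + [5] + [3] = [B, 5, 3]
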
@ -0,0 +1,57 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file linalg_ops_shape_fns.h
 * \brief
 */
#ifndef CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_

#include "graph/tensor.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"

namespace ge {

/**
 * Generate a square matrix's Shape
 * @param tensor Input tensor
 * @param out Output Shape
 * @return status whether this operation success
 */
graphStatus MakeBatchSquareMatrix(const TensorDesc &tensor, Shape &out, const ge::Operator &op);
/**
 * Generate a square matrix's Shape
 * @param tensor Input ge tensor desc ptr
 * @param out Output GeShape
 * @return status whether this operation success
 */
graphStatus MakeBatchSquareMatrix(const GeTensorDescPtr &tensor_desc, GeShape &out, const ge::Operator &op);
/**
 * Solving linear equations from matrices common shape func
 * @param tensor1 first input tensor
 * @param tensor2 second input tensor
 * @param square whether matrix is square
 * @param out Output Shape
 * @return status whether this operation success
 */
graphStatus MatrixSolve(const TensorDesc &tensor1, const TensorDesc &tensor2, bool square, Shape &out,
                        const ge::Operator &op);

}  // namespace ge

#endif  // CUSTOMIZE_OP_PROTO_UTIL_LINALG_OPS_SHAPE_FNS_H_

@ -0,0 +1,176 @@
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file lookup_ops_shape_fns.cpp
 * \brief
 */
#include "lookup_ops_shape_fns.h"
#include "common_shape_fns.h"
#include "graph/utils/op_desc_utils.h"
#include "error_util.h"

#include <vector>
#include <limits>

#include "op_log.h"

namespace ge {
graphStatus ValidateTableResourceHandle(Shape keys, std::vector<ShapeAndType> handleData,
                                        ShapeAndType &output_shape_and_type, bool is_lookup, const ge::Operator &op) {
  Shape unknown_shape(ge::UNKNOWN_SHAPE);
  if (handleData.size() != 2) {
    output_shape_and_type.SetShape(unknown_shape);
    output_shape_and_type.SetType(DT_UNDEFINED);
  } else {
    const ShapeAndType &key_shape_and_type = handleData[0];
    const ShapeAndType &value_shape_and_type = handleData[1];
    // here need to check key_dtype and value_dtype
    // but can not get the attr type for key and value
    output_shape_and_type.SetType(value_shape_and_type.GetDataType());
    if (is_lookup) {
      if (RankKnown(key_shape_and_type.GetShape()) && RankKnown(keys)) {
        int keys_rank = keys.GetDims().size();
        int keys_suffix_rank = key_shape_and_type.GetShape().GetDims().size();
        if (keys_rank < keys_suffix_rank) {
          std::string err_msg = OtherErrMsg("Expected keys to have suffix");
          VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
          return GRAPH_FAILED;
        }
        for (int d = 0; d < keys_suffix_rank; ++d) {
          int new_dim = key_shape_and_type.GetShape().GetDim(d);
          if (ReplaceDim(keys, keys_rank - keys_suffix_rank + d, new_dim, keys, op) == GRAPH_FAILED) {
            return GRAPH_FAILED;
          }
        }
        std::vector<int64_t> keys_prefix_vec;
        keys_prefix_vec.reserve(keys_rank - keys_suffix_rank);
        for (int d = 0; d < keys_rank - keys_suffix_rank; ++d) {
          keys_prefix_vec.push_back(keys.GetDim(d));
        }
        Shape keys_prefix = Shape(keys_prefix_vec);
        Shape temp_shape = output_shape_and_type.GetShape();
        if (Concatenate(keys_prefix, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
          return GRAPH_FAILED;
        }
        output_shape_and_type.SetShape(temp_shape);
      } else {
        output_shape_and_type.SetShape(unknown_shape);
      }
    } else {
      Shape temp_shape = output_shape_and_type.GetShape();
      if (Concatenate(keys, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
        return GRAPH_FAILED;
      }
      output_shape_and_type.SetShape(temp_shape);
    }
  }
  return GRAPH_SUCCESS;
}

graphStatus ValidateTableResourceHandle(const Operator &op, Shape &keys, const DataType &key_dtype,
                                        const DataType &value_dtype, const bool &is_lookup,
                                        ShapeAndType &output_shape_and_type) {
  if (op.GetInferenceContext() == nullptr) {
    OP_LOGI(op, "Op inference context is null, return unknown shape");
    output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
    output_shape_and_type.SetType(DT_UNDEFINED);
    return GRAPH_SUCCESS;
  }

  const auto &shapes_and_types = op.GetInferenceContext()->GetInputHandleShapesAndTypes();
  if (shapes_and_types.empty()) {
    OP_LOGI(op, "Context GetInputHandleShapesAndTypes result is empty, return unknown shape");
    output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
    output_shape_and_type.SetType(DT_UNDEFINED);
    return GRAPH_SUCCESS;
  }

  auto handle_data = shapes_and_types[0];
  if (handle_data.size() != 2) {
    OP_LOGI(op, "handle data(shapes_and_types[0]) size is not 2, return unknown shape");
    output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
    output_shape_and_type.SetType(DT_UNDEFINED);
    return GRAPH_SUCCESS;
  }

  const ShapeAndType &key_shape_and_type = handle_data[0];
  const ShapeAndType &value_shape_and_type = handle_data[1];
  if (key_shape_and_type.GetDataType() != key_dtype) {
    std::string err_msg =
      GetInputDTypeErrMsg("key_dtype", ConcatString(key_shape_and_type.GetDataType()), ConcatString(key_dtype));
    VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
    return GRAPH_FAILED;
  }
  if (value_shape_and_type.GetDataType() != value_dtype) {
    OP_LOGW(op, "trying to read value with wrong dtype, expected %d, got %d", value_shape_and_type.GetDataType(),
            value_dtype);
    return GRAPH_FAILED;
  }
  output_shape_and_type.SetType(value_shape_and_type.GetDataType());

  if (is_lookup) {
    if (RankKnown(key_shape_and_type.GetShape()) && RankKnown(keys)) {
      int64_t keys_rank = keys.GetDimNum();
      int64_t key_suffix_rank = key_shape_and_type.GetShape().GetDimNum();
      if (keys_rank < key_suffix_rank) {
        std::string err_msg =
          OtherErrMsg(ConcatString("Expected keys to have suffix ", key_suffix_rank, ", but saw shape ", keys_rank));
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
        return GRAPH_FAILED;
      }
      for (int64_t d = 0; d < key_suffix_rank; ++d) {
        // Ensure the suffix of keys match what's in the Table.
        int64_t dim = key_shape_and_type.GetShape().GetDim(d);
        if (ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, keys, op) == GRAPH_FAILED) {
          std::string err_msg =
            OtherErrMsg(ConcatString("replace dim ", keys_rank - key_suffix_rank + d, " in keys failed"));
          VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
          return GRAPH_FAILED;
        }
      }

      std::vector<int64_t> keys_prefix_vec;
      keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
      for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
        keys_prefix_vec.push_back(keys.GetDim(d));
      }
      Shape keys_prefix(keys_prefix_vec);

      auto temp_shape = output_shape_and_type.GetShape();
      if (Concatenate(keys_prefix, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
        std::string err_msg = OtherErrMsg("concatenate keys_prefix and value shape failed");
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
        return GRAPH_FAILED;
      }
      output_shape_and_type.SetShape(temp_shape);
    } else {
      output_shape_and_type.SetShape(Shape(UNKNOWN_RANK));
    }
  } else {
    auto temp_shape = output_shape_and_type.GetShape();
    if (Concatenate(keys, value_shape_and_type.GetShape(), temp_shape) == GRAPH_FAILED) {
      std::string err_msg = OtherErrMsg("concatenate keys and value shape failed");
      VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op, err_msg);
      return GRAPH_FAILED;
    }
    output_shape_and_type.SetShape(temp_shape);
  }

  return GRAPH_SUCCESS;
}

}  // namespace ge

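A sketch of how a table lookup op's infer function might drive the second overload. The LookupTableFindV2 naming, the "Tin"/"Tout" attr names, and the input/output layout below are assumptions for illustration, not part of this diff:

// Sketch only, assuming attrs "Tin"/"Tout" carry the key/value dtypes.
IMPLEMT_INFERFUNC(LookupTableFindV2, LookupTableFindV2Infer) {
  Shape keys = op.GetInputDesc(1).GetShape();
  DataType key_dtype = DT_UNDEFINED;
  DataType value_dtype = DT_UNDEFINED;
  (void)op.GetAttr("Tin", key_dtype);
  (void)op.GetAttr("Tout", value_dtype);
  ShapeAndType out_st;
  if (ValidateTableResourceHandle(op, keys, key_dtype, value_dtype, true, out_st) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  TensorDesc out_desc = op.GetOutputDescByName("values");
  out_desc.SetShape(out_st.GetShape());
  out_desc.SetDataType(out_st.GetDataType());
  return op.UpdateOutputDesc("values", out_desc);
}
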
@ -0,0 +1,57 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file lookup_ops_shape_fns.h
 * \brief
 */
#ifndef CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_
#define CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_

#include <vector>
#include "graph/tensor.h"
#include "graph/inference_context.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"

namespace ge {
/**
 * Validate table resource handle
 * @param keys keys of the shape
 * @param handleData vector of handle data
 * @param output_shape_and_type shape and type that created
 * @param is_lookup if is lookup
 * @return status whether this operation success
 */
graphStatus ValidateTableResourceHandle(Shape keys, std::vector<ShapeAndType> handleData,
                                        ShapeAndType &output_shape_and_type, bool is_lookup, const ge::Operator &op);

/**
 * Validate table resource handle
 * @param op op context
 * @param keys keys of the shape
 * @param key_dtype expected key data type
 * @param value_dtype expected value data type
 * @param is_lookup if is lookup
 * @param output_shape_and_type shape and type that created
 * @return status whether this operation success
 */
graphStatus ValidateTableResourceHandle(const Operator &op, Shape &keys, const DataType &key_dtype,
                                        const DataType &value_dtype, const bool &is_lookup,
                                        ShapeAndType &output_shape_and_type);
}  // namespace ge

#endif  // CUSTOMIZE_OP_PROTO_UTIL_LOOKUP_OPS_SHAPE_FNS_H_

@ -0,0 +1,317 @@
|
|||
/**
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file op_attr.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_
|
||||
#define CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_
|
||||
|
||||
#include <vector>
|
||||
#include "op_log.h"
|
||||
#include "external/graph/operator.h"
|
||||
#include "graph/utils/op_desc_utils.h"
|
||||
#include "graph/utils/attr_utils.h"
|
||||
|
||||
namespace ops {
|
||||
using namespace ge;
|
||||
|
||||
// attr base struct
|
||||
struct AttrBase {
|
||||
const int32_t attr_idx;
|
||||
const std::string attr_name;
|
||||
AttrBase(const int attr_idx, const std::string &attr_name) : attr_idx(attr_idx), attr_name(attr_name) {}
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const struct AttrBase &attr_info, int32_t &value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
|
||||
return false;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const struct AttrBase &attr_info, int64_t &value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
|
||||
return false;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @param [in] default_value: default_value
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const struct AttrBase &attr_info, int64_t &value, const int64_t default_value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.attr_name.c_str());
|
||||
value = default_value;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const struct AttrBase &attr_info, uint64_t &value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.attr_name.c_str());
|
||||
return false;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.attr_name.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, int32_t &value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
|
||||
return false;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @param [in] default_value: default_value
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, int32_t &value,
|
||||
const int32_t default_value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
|
||||
value = default_value;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, int64_t &value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. return false", attr_info.second.c_str());
|
||||
return false;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %lld", attr_info.second.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @param [in] default_value: default_value
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, int64_t &value,
|
||||
const int64_t default_value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
|
||||
value = default_value;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %lld", attr_info.second.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @param [in] default_value: default_value
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, uint32_t &value,
|
||||
const uint32_t default_value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetInt(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
|
||||
value = default_value;
|
||||
}
|
||||
OP_LOGD("GetAttrValue", "Get the attr of %s is %d", attr_info.second.c_str(), value);
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
* @param [in] default_value: default_value
|
||||
* @return bool: flag of success or not
|
||||
*/
|
||||
template <typename T>
|
||||
bool GetAttrValue(const T ¶s, const std::pair<int64_t, std::string> &attr_info, bool &value,
|
||||
const bool default_value) {
|
||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
|
||||
if (!AttrUtils::GetBool(op_desc, attr_info.second, value)) {
|
||||
OP_LOGW("GetAttrValue", "Get the attr of %s is failed. set the default value", attr_info.second.c_str());
|
||||
value = default_value;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* @brief: read constvalue from paras store into values
|
||||
* @param [in] paras: ge::Operator
|
||||
* @param [in] attr_info: attr info pair(attr_idx, attr_name)
|
||||
* @param [out] value: store value.
|
||||
 * @return bool: flag of success or not
 */
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, bool &value) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
  if (!AttrUtils::GetBool(op_desc, attr_info.second, value)) {
    OP_LOGW("GetAttrValue", "Failed to get the attr %s, return false", attr_info.second.c_str());
    return false;
  }
  return true;
}

/*
 * @brief: read the const value from paras and store it into value
 * @param [in] paras: ge::Operator
 * @param [in] attr_info: attr info pair(attr_idx, attr_name)
 * @param [out] value: stored value.
 * @return bool: flag of success or not
 */
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, vector<int64_t> &value) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
  if (!AttrUtils::GetListInt(op_desc, attr_info.second, value)) {
    OP_LOGW("GetAttrValue", "Failed to get the attr %s, return false", attr_info.second.c_str());
    return false;
  }
  return true;
}

/*
 * @brief: read the const value from paras and store it into value
 * @param [in] paras: ge::Operator
 * @param [in] attr_info: attr info pair(attr_idx, attr_name)
 * @param [out] value: stored value.
 * @return bool: flag of success or not
 */
template <typename T>
bool GetAttrValue(const T &paras, const std::pair<int64_t, std::string> &attr_info, vector<int32_t> &value) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
  if (!AttrUtils::GetListInt(op_desc, attr_info.second, value)) {
    OP_LOGW("GetAttrValue", "Failed to get the attr %s, return false", attr_info.second.c_str());
    return false;
  }
  return true;
}

/*
 * @brief: read the const value from paras and store it into value
 * @param [in] paras: ge::Operator
 * @param [in] attr_info: attr info pair(attr_idx, attr_name)
 * @param [out] value: stored value.
 * @return bool: flag of success or not
 */
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, int32_t &value, const int32_t default_value) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
  if (!AttrUtils::GetInt(op_desc, attr_info.attr_name, value)) {
    OP_LOGW("GetAttrValue", "Failed to get attr %s, using the default value", attr_info.attr_name.c_str());
    value = default_value;
  }
  OP_LOGD("GetAttrValue", "The value of attr %s is %d", attr_info.attr_name.c_str(), value);
  return true;
}

/*
 * @brief: read the const value from paras and store it into value
 * @param [in] paras: ge::Operator
 * @param [in] attr_info: attr info pair(attr_idx, attr_name)
 * @param [out] value: stored value.
 * @return bool: flag of success or not
 */
template <typename T>
bool GetAttrValue(const T &paras, const struct AttrBase &attr_info, float32_t &value) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(paras);
  if (!AttrUtils::GetFloat(op_desc, attr_info.attr_name, value)) {
    OP_LOGW("GetAttrValue", "Failed to get the attr %s, return false", attr_info.attr_name.c_str());
    return false;
  }
  OP_LOGD("GetAttrValue", "The value of attr %s is %f", attr_info.attr_name.c_str(), value);
  return true;
}

} // namespace ops
#endif // CUSTOMIZE_OP_PROTO_UTILS_OP_ATTR_H_
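For orientation, a minimal usage sketch of the overloads above, as they would be called from an infer function; the op type and attr name here are hypothetical, not part of this commit:

// Hedged usage sketch: IMPLEMT_COMMON_INFERFUNC declares `op` as the ge::Operator.
IMPLEMT_COMMON_INFERFUNC(SampleInferShape) {
  bool keep_dims = false;
  if (!ops::GetAttrValue(op, std::pair<int64_t, std::string>(0, "keep_dims"), keep_dims)) {
    return GRAPH_FAILED;  // the helper already logged a warning
  }
  return GRAPH_SUCCESS;
}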
@ -19,8 +19,8 @@
 * \brief common util for ops; only built-in C++ types or classes are allowed in this file
 */

#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_

#include <set>
#include <string>
@ -69,4 +69,4 @@ std::string to_string(const std::set<T> &items) {
}
} // namespace ops

#endif // OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
@ -19,8 +19,8 @@
 * \brief
 */

#ifndef CANN_OPS_BUILT_IN_OPS_CONST_H_
#define CANN_OPS_BUILT_IN_OPS_CONST_H_
#ifndef CANN_CUSTOMIZE_OPS_CONST_H_
#define CANN_CUSTOMIZE_OPS_CONST_H_

#include <vector>
#include "external/graph/operator.h"
@ -283,4 +283,4 @@ bool GetConstIntToShape(T *context, const int64_t const_idx, gert::Shape &const_
  return true;
}
} // namespace ops
#endif // CANN_OPS_BUILT_IN_OPS_CONST_H_
#endif // CANN_CUSTOMIZE_OPS_CONST_H_
@ -18,8 +18,8 @@
 * \file op_log.h
 * \brief
 */
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#ifndef CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H
#define CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H

#include <string>
#include <type_traits>
@ -71,6 +71,12 @@ typename std::enable_if<std::is_same<ge::OpDescPtr, typename std::decay<T>::type
  return op_desc != nullptr ? op_desc->GetName().c_str() : "nil";
}

template <class T>
typename std::enable_if<std::is_same<ge::Operator, typename std::decay<T>::type>::value, const char *>::type get_cstr(
  const T &op) {
  return op.GetName().c_str();
}

template <typename T>
std::string TbeGetName(const T &op) {
  ge::AscendString op_ascend_name;
@ -222,30 +228,35 @@ std::string TbeGetOpType(const T &op) {
#define D_FUSION_PASS_LOGD(fmt, ...)
#endif

#define OP_LOGE_IF(condition, return_value, op_name, fmt, ...) \
  do { \
    static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
    if (condition) { \
      OP_LOGE(get_cstr(op_name), fmt, ##__VA_ARGS__); \
      return return_value; \
    } \
#define OP_LOGE_IF(condition, return_value, op_name, fmt, ...) \
  static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
  do { \
    if (condition) { \
      OP_LOGE(get_cstr(op_name), fmt, ##__VA_ARGS__); \
      return return_value; \
    } \
  } while (0)

#define OP_LOGW_IF(condition, op_name, fmt, ...) \
  do { \
    static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
    if (condition) { \
      OP_LOGW(get_cstr(op_name), fmt, ##__VA_ARGS__); \
    } \
#define OP_LOGW_IF(condition, op_name, fmt, ...) \
  static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
  do { \
    if (condition) { \
      OP_LOGW(get_cstr(op_name), fmt, ##__VA_ARGS__); \
    } \
  } while (0)

#define OP_LOGI_IF_RETURN(condition, return_value, op_name, fmt, ...) \
  do { \
    static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
    if (condition) { \
      OP_LOGI(get_cstr(op_name), fmt, ##__VA_ARGS__); \
      return return_value; \
    } \
#define OP_LOGI_IF_RETURN(condition, return_value, op_name, fmt, ...) \
  static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
  do { \
    if (condition) { \
      OP_LOGI(get_cstr(op_name), fmt, ##__VA_ARGS__); \
      return return_value; \
    } \
  } while (0)

#endif // GE_OP_LOG_H
inline std::ostream &operator<<(std::ostream &os, const ge::Operator &op) {
  os << op.GetName();
  return os;
}

#endif // CUSTOMIZE_OP_PROTO_UTILS_OP_LOG_H
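One consequence of hoisting the static_assert out of the do/while, noted here as a hedged observation: the macro now expands to a declaration plus a statement, so it can no longer be used as the lone body of an unbraced if/else. A minimal call-site sketch (the function is hypothetical):

// Hedged call-site sketch for the rewritten OP_LOGW_IF.
void CheckAttr(const ge::Operator &op, bool attr_missing) {
  OP_LOGW_IF(attr_missing, op, "optional attr not set, falling back to default");
}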
@ -19,8 +19,8 @@
 * \brief
 */

#ifndef CANN_OPS_BUILT_IN_OP_UTIL_H_
#define CANN_OPS_BUILT_IN_OP_UTIL_H_
#ifndef CANN_CUSTOMIZE_OP_UTIL_H_
#define CANN_CUSTOMIZE_OP_UTIL_H_

#include <memory>
#include <utility>
@ -218,4 +218,4 @@ inline bool IsConstTensor(const gert::Tensor *input_tensor) {
  return (input_tensor != nullptr) && (input_tensor->GetAddr() != nullptr);
}
} // namespace ops
#endif // CANN_OPS_BUILT_IN_OP_UTIL_H_
#endif // CANN_CUSTOMIZE_OP_UTIL_H_
@ -0,0 +1,370 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file reduce_infer_util.cc
 * \brief
 */
#include <string>
#include <vector>
#include "util.h"
#include "reduce_infer_util.h"
#include "vector_proto_profiling.h"

namespace reduce_ops {
using namespace std;
using namespace ge;

constexpr int64_t UNKNOWN_DIM_VALUE = -2;

static bool ConvertAxis(std::vector<int64_t> &axis, const int64_t input_len) {
  const int64_t input_length = input_len == 0 ? 1 : input_len;
  // Convert negative reduce axes to their positive equivalents
  for (size_t i = 0; i < axis.size(); ++i) {
    if (axis[i] < -input_length || axis[i] > (input_length - 1)) {
      OP_LOGE("ReduceOps", "reduce verify failed, axis: %ld, input_length: %ld", axis[i], input_length);
      return false;
    }
    if (axis[i] < 0) {
      axis[i] = input_length + axis[i];
    }
  }
  return true;
}

bool DoReduceInfershapeWithAxesKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
                                        GeShape &output_shape) {
  // case0: input is {-2}, set the output to {-2}
  if (input_shape.IsUnknownDimNum()) {
    OP_LOGD("ReduceOps", "do unknown-rank infershape for Reduce, output is {-2}");
    output_shape.SetIsUnknownDimNum();
    return true;
  }

  auto input_shape_len = input_shape.GetDimNum();
  OP_LOGD("ReduceOps", "input shape = %s, axes = %s", to_string(input_shape).c_str(), to_string(reduce_axes).c_str());
  if (!ConvertAxis(reduce_axes, static_cast<int64_t>(input_shape_len))) {
    OP_LOGE("ReduceOps", "ConvertAxis failed, will return false");
    return false;
  }

  // case1: reduce all dims when reduce_axes is empty
  if (reduce_axes.empty()) {
    // return a shape of all 1s when reduce_axes is empty and keep_dims = true
    OP_LOGD("ReduceOps", "do all-reduce infershape for Reduce, output dims are all 1");
    output_shape.SetDimNum(input_shape_len);
    for (size_t i = 0; i < input_shape_len; ++i) {
      output_shape.SetDim(i, 1);
    }
    return true;
  }

  // case2: shape is [x, y, z], axis is [0] --> [1, y, z] when keep_dims is true
  output_shape.SetDimNum(input_shape_len);
  OP_LOGD("ReduceOps", "do norm infershape for Reduce");
  output_shape = input_shape;
  for (size_t i = 0; i < reduce_axes.size(); ++i) {
    int64_t axis = reduce_axes[i];
    output_shape.SetDim(axis, 1);
  }
  OP_LOGD("ReduceOps", "after reduce output shape = %s", to_string(output_shape).c_str());
  return true;
}

bool DoReduceInfershapeWithAxesNoKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
                                          GeShape &output_shape) {
  // case0: reduce all dims when reduce_axes is empty
  if (reduce_axes.empty()) {
    // return a scalar shape when reduce_axes is empty and keep_dims = false
    OP_LOGD("ReduceOps", "reduce_axes is empty, output a scalar");
    output_shape.SetDimNum(0);
    return true;
  }

  // case1: input is {-2}, set the output to {-2}
  if (input_shape.IsUnknownDimNum()) {
    OP_LOGD("ReduceOps", "input is {-2}, set the output to {-2}");
    output_shape.SetIsUnknownDimNum();
    return true;
  }

  auto input_shape_len = input_shape.GetDimNum();
  if (!ConvertAxis(reduce_axes, static_cast<int64_t>(input_shape_len))) {
    OP_LOGE("ReduceOps", "ConvertAxis failed, will return false");
    return false;
  }
  // case2: shape is [x, y, z], axis is [0] --> [y, z] when keep_dims is false
  output_shape.SetDimNum(input_shape_len);
  int64_t output_dim = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(input_shape_len); ++i) {
    if (std::find(reduce_axes.begin(), reduce_axes.end(), i) == reduce_axes.end()) {
      auto input_dim = input_shape.GetDim(i);
      output_shape.SetDim(output_dim, input_dim);
      output_dim++;
    }
  }
  output_shape.SetDimNum(output_dim);

  OP_LOGD("ReduceOps", "after reduce output shape = %s", to_string(output_shape).c_str());
  return true;
}

bool DoReduceInfershapeWithAxes(const GeShape &input_shape, const bool keep_dims, std::vector<int64_t> &reduce_axes,
                                GeShape &output_shape) {
  if (keep_dims) {
    return DoReduceInfershapeWithAxesKeepdims(input_shape, reduce_axes, output_shape);
  }

  return DoReduceInfershapeWithAxesNoKeepdims(input_shape, reduce_axes, output_shape);
}

bool DoReduceInferRangeWithAxes(GeTensorDescPtr &tensordesc_input_x, GeTensorDescPtr &tensordesc_output,
                                std::vector<int64_t> &reduce_axes, bool keep_dims) {
  std::vector<std::pair<int64_t, int64_t>> output_shape_range;
  std::vector<std::pair<int64_t, int64_t>> input_shape_range;
  tensordesc_input_x->GetShapeRange(input_shape_range);
  std::vector<int64_t> input_shape_vec = tensordesc_input_x->MutableShape().GetDims();
  // If the input shape range is empty, MakeUpShapeRange will set the range.
  MakeUpShapeRange(input_shape_vec, input_shape_range);
  if (keep_dims) {
    output_shape_range = input_shape_range;
    for (auto item : reduce_axes) {
      output_shape_range[item] = std::make_pair<int64_t, int64_t>(1, 1);
    }
  } else {
    for (int64_t i = 0; i < static_cast<int64_t>(input_shape_range.size()); ++i) {
      if (std::find(reduce_axes.begin(), reduce_axes.end(), i) == reduce_axes.end()) {
        output_shape_range.push_back(input_shape_range[i]);
      }
    }
  }
  tensordesc_output->SetShapeRange(output_shape_range);
  return true;
}

bool GetConstData(const Operator &op, const int64_t const_input_idx, std::vector<int64_t> &const_values) {
  ge::Tensor const_tensor;
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto input_name = op_desc->GetInputNameByIndex(const_input_idx);
  if (op.GetInputConstData(input_name.c_str(), const_tensor) != ge::GRAPH_SUCCESS) {
    OP_LOGW(TbeGetName(op).c_str(), "constvalue [%s] does not exist.", input_name.c_str());
    return false;
  }

  DataType const_dtype = op_desc->MutableInputDesc(const_input_idx)->GetDataType();
  auto size = const_tensor.GetSize();
  auto data = const_tensor.GetData();
  const_values.reserve(size);
  switch (const_dtype) {
    case ge::DT_INT64: {
      size_t count = size / sizeof(int64_t);
      const int64_t *data_addr = (const int64_t *)(data);
      for (size_t i = 0; i < count; i++) {
        const_values.push_back(*(data_addr + i));
      }
    } break;
    case ge::DT_INT32: {
      size_t count = size / sizeof(int32_t);
      const int32_t *data_addr = (const int32_t *)(data);
      for (size_t i = 0; i < count; i++) {
        const_values.push_back(*(data_addr + i));
      }
    } break;
    default: {
      OP_LOGW(TbeGetName(op).c_str(), "GetConstData of dtype[%s] is not implemented.", to_string(const_dtype).c_str());
      return false;
    } break;
  }

  OP_LOGD(TbeGetName(op).c_str(), "get const value = %s", to_string(const_values).c_str());
  return true;
}

bool DoReduceInferShapeWithoutAxes(const Operator &op, GeTensorDescPtr &tensordesc_input_x,
                                   GeTensorDescPtr &tensordesc_output, const GeShape &axes_shape, bool keep_dims) {
  OP_LOGD(TbeGetName(op).c_str(), "the axes is not const, the output will be dynamic shape");
  const GeShape &input_shape = tensordesc_input_x->MutableShape();
  // case0: input is {}, set the output to {}
  if (input_shape.IsScalar()) {
    OP_LOGD(TbeGetName(op).c_str(), "input is scalar, so output is scalar");
    std::vector<int64_t> output_shape;
    tensordesc_output->SetShape(GeShape(output_shape));
    return true;
  }
  // case1: input is {-2}, set the output to {-2}
  if (input_shape.IsUnknownDimNum()) {
    OP_LOGD(TbeGetName(op).c_str(), "input is {-2}, so output is {-2}");
    std::vector<int64_t> output_shape(1, -2);
    tensordesc_output->SetShape(GeShape(output_shape));
    return true;
  }

  std::vector<int64_t> input_shape_vec = input_shape.GetDims();
  std::vector<int64_t> output_shape;
  std::vector<std::pair<int64_t, int64_t>> output_shape_range;
  std::vector<std::pair<int64_t, int64_t>> input_shape_range;
  tensordesc_input_x->GetShapeRange(input_shape_range);
  // If the input shape range is empty, MakeUpShapeRange will set the range.
  MakeUpShapeRange(input_shape_vec, input_shape_range);
  size_t input_length = input_shape_vec.size();
  if (keep_dims) {
    // case2: all output dims are -1 with range [1, xxx] when keep_dims is true
    for (size_t item = 0; item < input_length; ++item) {
      int64_t range_min_value = 1;
      int64_t range_max_value = input_shape_range[item].second;
      output_shape_range.push_back(std::make_pair(range_min_value, range_max_value));
      if (range_max_value == 1) {
        output_shape.push_back(1);
      } else {
        output_shape.push_back(-1);
      }
    }
  } else {
    // keep_dims is false
    // case3: all output dims are -1 with range (min_range, max_range),
    //        and output dim num = (input dim num - 1), when axes_shape = {} or {1}
    // case4: otherwise the output shape is {-2}
    int64_t output_dim_num = UNKNOWN_DIM_VALUE;
    if (!axes_shape.IsUnknownDimNum() && axes_shape.GetDimNum() == 0) {
      OP_LOGD(TbeGetName(op).c_str(), "the axes is scalar, will reduce one dim of the input shape");
      output_dim_num = input_length - 1;
    }
    if (axes_shape.GetDimNum() == 1 && axes_shape.GetDim(0) == 1) {
      output_dim_num = input_length - 1;
      OP_LOGD(TbeGetName(op).c_str(), "the shape of axes is [1], will reduce one dim of the input shape");
    }
    int64_t range_min_value = input_shape_range[0].first;
    int64_t range_max_value = input_shape_range[0].second;
    for (size_t item = 0; item < input_length; ++item) {
      if (input_shape_range[item].first < range_min_value) {
        range_min_value = input_shape_range[item].first;
      }
      if (input_shape_range[item].second == -1) {
        range_max_value = -1;
      }
      if (range_max_value != -1 && input_shape_range[item].second > range_max_value) {
        range_max_value = input_shape_range[item].second;
      }
    }
    if (output_dim_num == UNKNOWN_DIM_VALUE) {
      output_shape.push_back(UNKNOWN_DIM_VALUE);
    } else {
      for (int64_t item = 0; item < output_dim_num; ++item) {
        output_shape.push_back(-1);
        output_shape_range.push_back(std::make_pair(range_min_value, range_max_value));
      }
    }
  }
  tensordesc_output->SetShape(GeShape(output_shape));
  tensordesc_output->SetShapeRange(output_shape_range);
  return true;
}

bool CommonReduceInferWithInputAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
                                    const int64_t input_axes_idx, bool keep_dims) {
  PROFILING_PROTO_INIT(TbeGetName(op).c_str());
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("invalid OpDesc.")),
        return false);
  auto axes_name = op_desc->GetInputNameByIndex(input_axes_idx);
  op_desc->SetOpInferDepends({axes_name});
  auto tensordesc_input_x = op_desc->MutableInputDesc(input_x_idx);
  auto tensordesc_input_axes = op_desc->MutableInputDesc(input_axes_idx);
  auto tensordesc_output = op_desc->MutableOutputDesc(output_idx);
  CHECK(tensordesc_input_x == nullptr || tensordesc_output == nullptr || tensordesc_input_axes == nullptr,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("get mutable desc failed.")), return false);
  auto input_type = tensordesc_input_x->GetDataType();
  const GeShape &input_shape = tensordesc_input_x->MutableShape();
  const GeShape &axes_shape = tensordesc_input_axes->MutableShape();
  tensordesc_output->SetDataType(input_type);

  if (axes_shape.GetDimNum() == 1 && axes_shape.GetDim(0) == 0) {
    OP_LOGD(TbeGetName(op).c_str(), "axes_shape is [0], set output shape = input shape");
    tensordesc_output->SetShape(input_shape);
    std::vector<std::pair<int64_t, int64_t>> input_shape_range;
    tensordesc_input_x->GetShapeRange(input_shape_range);
    tensordesc_output->SetShapeRange(input_shape_range);
    return true;
  }

  // get the const value from input_axes_idx
  std::vector<int64_t> reduce_axes;
  if (GetConstData(op, input_axes_idx, reduce_axes)) {
    PROFILING_PROTO_AFTER_GET_SHAPE_REG();
    // do infershape with const axes for a static op
    GeShape &output_shape = tensordesc_output->MutableShape();
    CHECK(!DoReduceInfershapeWithAxes(input_shape, keep_dims, reduce_axes, output_shape),
          VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infershape failed.")),
          return false);

    // when the output is dynamic shape, infer the range as well
    if (output_shape.IsUnknownShape()) {
      if (!output_shape.IsUnknownDimNum()) {
        CHECK(!DoReduceInferRangeWithAxes(tensordesc_input_x, tensordesc_output, reduce_axes, keep_dims),
              VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infer range failed.")),
              return false);
      }
      OP_LOGD(TbeGetName(op).c_str(), "infer output range end for dynamic output");
      return true;
    }
    OP_LOGD(TbeGetName(op).c_str(), "the output is not dynamic");
    PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
    PROFILING_PROTO_END();
    return true;
  }

  CHECK(!DoReduceInferShapeWithoutAxes(op, tensordesc_input_x, tensordesc_output, axes_shape, keep_dims),
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("infer reduce range failed.")), return false);
  return true;
}

bool CommonReduceInferWithAttrAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
                                   vector<int64_t> attr_axes, bool keep_dims) {
  PROFILING_PROTO_INIT(TbeGetName(op).c_str());
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  CHECK(op_desc == nullptr, VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("invalid OpDesc.")),
        return false);
  auto tensordesc_input_x = op_desc->MutableInputDesc(input_x_idx);
  auto tensordesc_output = op_desc->MutableOutputDesc(output_idx);
  CHECK(tensordesc_input_x == nullptr || tensordesc_output == nullptr,
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("get mutable desc failed.")), return false);
  auto input_type = tensordesc_input_x->GetDataType();
  const GeShape &input_shape = tensordesc_input_x->MutableShape();
  tensordesc_output->SetDataType(input_type);

  PROFILING_PROTO_AFTER_GET_SHAPE_REG();
  // do infershape with const axes for a static op
  GeShape &output_shape = tensordesc_output->MutableShape();
  CHECK(!DoReduceInfershapeWithAxes(input_shape, keep_dims, attr_axes, output_shape),
        VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infershape failed.")), return false);

  // when the output is dynamic shape, infer the range as well
  if (output_shape.IsUnknownShape()) {
    if (!output_shape.IsUnknownDimNum()) {
      CHECK(!DoReduceInferRangeWithAxes(tensordesc_input_x, tensordesc_output, attr_axes, keep_dims),
            VECTOR_INFER_SHAPE_INNER_ERR_REPORT(TbeGetName(op), OtherErrMsg("do reduce infer range failed.")),
            return false);
    }
    OP_LOGD(TbeGetName(op), "infer output range end for dynamic output");
    return true;
  }
  OP_LOGD(TbeGetName(op), "the output is not dynamic");
  PROFILING_PROTO_AFTER_INFER_SHAPE_REG();
  PROFILING_PROTO_END();
  return true;
}

} // namespace reduce_ops
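To make the two keep_dims paths above concrete, a small hedged sketch; the shapes in the comments follow directly from the code, and the call assumes only the functions defined in this file:

// Hedged sketch of the shape arithmetic implemented above:
//   input [4, 5, 6], reduce_axes = {1}
//   keep_dims = true  -> output [4, 1, 6]
//   keep_dims = false -> output [4, 6]
//   reduce_axes = {}  -> all dims reduced (all 1s, or a scalar without keep_dims)
ge::GeShape input({4, 5, 6});
std::vector<int64_t> axes{1};
ge::GeShape output;
bool ok = reduce_ops::DoReduceInfershapeWithAxes(input, /*keep_dims=*/true, axes, output);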
@ -0,0 +1,119 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file reduce_infer_util.h
 * \brief
 */
#ifndef CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_

#include <memory.h>
#include <string>
#include <vector>
#include <map>
#include <algorithm>

using namespace std;
using namespace ge;
namespace reduce_ops {

/*
 * only do infer shape for reduce with input_shape/axes, when keepdims = true
 * param[in] input_shape: GeShape input shape
 * param[in] reduce_axes: reduce axes list
 * param[out] output_shape: GeShape output shape
 * return bool:
 * true: infer success; false: infer failed
 */
bool DoReduceInfershapeWithAxesKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
                                        GeShape &output_shape);

/*
 * only do infer shape for reduce with input_shape/axes, when keepdims = false
 * param[in] input_shape: GeShape input shape
 * param[in] reduce_axes: reduce axes list
 * param[out] output_shape: GeShape output shape
 * return bool:
 * true: infer success; false: infer failed
 */
bool DoReduceInfershapeWithAxesNoKeepdims(const GeShape &input_shape, std::vector<int64_t> &reduce_axes,
                                          GeShape &output_shape);

/*
 * only do infer shape for reduce with input_shape, axes and keepdims
 * param[in] input_shape: GeShape input shape
 * param[in] keep_dims: bool
 * param[in] reduce_axes: reduce axes list
 * param[out] output_shape: GeShape output shape
 * return bool:
 * true: infer success; false: infer failed
 */
bool DoReduceInfershapeWithAxes(const GeShape &input_shape, const bool keep_dims, std::vector<int64_t> &reduce_axes,
                                GeShape &output_shape);

/*
 * only do infer range for reduce
 * param[in] tensordesc_input_x: GeTensorDescPtr of the input tensor
 * param[in] tensordesc_output: GeTensorDescPtr of the output tensor
 * param[in] reduce_axes: reduce axes list
 * param[in] keep_dims: bool
 * return bool:
 * true: infer success; false: infer failed
 */
bool DoReduceInferRangeWithAxes(GeTensorDescPtr &tensordesc_input_x, GeTensorDescPtr &tensordesc_output,
                                std::vector<int64_t> &reduce_axes, bool keep_dims);

/*
 * get the const value from a const node into the vector const_values
 * param[in] op: the operator from GE
 * param[in] const_input_idx: the input idx of the const node
 * param[out] const_values: the const values
 * return bool:
 * true: got the const value; false: did not get the const value
 */
bool GetConstData(const Operator &op, const int64_t const_input_idx, std::vector<int64_t> &const_values);

/*
 * infer shape and range for reduce, when the axes is not const
 * param[in] op: the operator from GE
 * param[in] tensordesc_input_x: GeTensorDescPtr of the input tensor
 * param[in] tensordesc_output: GeTensorDescPtr of the output tensor
 * param[in] axes_shape: the axes shape
 * param[in] keep_dims: bool
 * return bool:
 * true: infer success; false: infer failed
 */
bool DoReduceInferShapeWithoutAxes(const Operator &op, GeTensorDescPtr &tensordesc_input_x,
                                   GeTensorDescPtr &tensordesc_output, const GeShape &axes_shape, bool keep_dims);

/*
 * reduce infershape function, when axes is an input
 * param[in] op: the operator from GE
 * param[in] input_x_idx: the input tensor idx int64
 * param[in] output_idx: the output tensor idx int64
 * param[in] input_axes_idx: the input const idx
 * param[in] keep_dims: bool
 * return bool:
 * true: infer success; false: infer failed
 */
bool CommonReduceInferWithInputAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
                                    const int64_t input_axes_idx, bool keep_dims);
bool CommonReduceInferWithAttrAxes(const Operator &op, const int64_t input_x_idx, const int64_t output_idx,
                                   vector<int64_t> attr_axes, bool keep_dims);
} // namespace reduce_ops

#endif // CUSTOMIZE_OP_PROTO_UTIL_REDUCE_INFER_UTIL_H_
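And a hedged sketch of how an op proto would typically wire these helpers up; the index layout (x at input 0, axes at input 1, y at output 0) and the op name are assumptions for illustration:

// Hedged sketch: a ReduceSum-style infer function built on the helpers above.
IMPLEMT_COMMON_INFERFUNC(ReduceLikeInferShape) {
  bool keep_dims = false;
  (void)op.GetAttr("keep_dims", keep_dims);  // default to false when absent
  if (!reduce_ops::CommonReduceInferWithInputAxes(op, 0, 0, 1, keep_dims)) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}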
@ -18,8 +18,8 @@
 * \file transfer_shape_according_to_format.h
 * \brief set shape according to original format and current format
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_

#include <memory>
#include <functional>
@ -121,4 +121,4 @@ class ShapeTransferAccordingToFormat {
};
} // namespace ge

#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
@ -18,8 +18,8 @@
 * \file util.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_
#define CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_

#include <memory.h>
#include <string>
@ -641,4 +641,16 @@ const int32_t INDEX_VALUE5 = 5;
const int32_t INDEX_VALUE6 = 6;
const int32_t INDEX_VALUE7 = 7;
const int32_t INDEX_VALUE8 = 8;
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_

template <typename T>
inline std::string VectorToString(const std::vector<T> &values) {
  std::stringstream ss;
  for (auto iter = values.begin(); iter != values.end(); ++iter) {
    ss << *iter;
    if (iter != values.end() - 1) {
      ss << ", ";
    }
  }
  return ss.str();
}
#endif // CUSTOMIZE_OP_PROTO_UTIL_UTIL_H_
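A one-line usage sketch of the helper added above (the op name in the log tag is hypothetical):

// Hedged example: formatting a dim vector for a debug log.
std::vector<int64_t> dims{4, 5, 6};
OP_LOGD("SampleOp", "dims = [%s]", VectorToString(dims).c_str());  // prints: dims = [4, 5, 6]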
@ -18,8 +18,8 @@
 * \file vector_proto_profiling.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#ifndef CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#define CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_

#include <memory>
#include <string>
|
      } \
    } \
  }
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
#endif // CUSTOMIZE_OP_PROTO_UTIL_VECTOR_PROTO_PROFILING_H__H_
@ -20,12 +20,110 @@ import stat
import sys

cust_op_lists = [
    "Cast",
    "Conj",
    "IsNan",
    "SliceGrad",
    "MaskedSelectGrad",
    "GatherDGradV2",
    "acosgrad",
    "acoshgrad",
    "adaptiveavgpool2d",
    "adaptiveavgpool2dgrad",
    "adaptiveavgpool3d",
    "adaptiveavgpool3dgrad",
    "adaptivemaxpool2dgrad",
    "addn",
    "adjusthue",
    "adjustsaturation",
    "affinegridgrad",
    "argmax",
    "argmaxwithvalue",
    "argmin",
    "argminwithvalue",
    "asingrad",
    "asinhgrad",
    "bartlettwindow",
    "betainc",
    "biasadd",
    "biasaddgrad",
    "bincount",
    "blackmanwindow",
    "broadcastto",
    "bucketize",
    "cauchy",
    "checknumerics",
    "cholesky",
    "choleskygrad",
    "choleskyinverse",
    "choleskysolve",
    "combinednonmaxsuppression",
    "complex",
    "complexabs",
    "concat",
    "conj",
    "cos",
    "cumprod",
    "cumulativelogsumexp",
    "dataformatvecpermute",
    "depthtospace",
    "diag",
    "diagpart",
    "div",
    "divnonan",
    "eig",
    "exp",
    "expand",
    "expm1",
    "extractglimpse",
    "eye",
    "filldiagonal",
    "floordiv",
    "fractionalavgpool",
    "fractionalavgpoolgrad",
    "fractionalmaxpool",
    "fractionalmaxpoolgrad",
    "gathernd",
    "gcd",
    "geqrf",
    "hammingwindow",
    "heaviside",
    "histogram",
    "hypot",
    "identityn",
    "im2col",
    "indexfill",
    "isinf",
    "isnan",
    "kldivloss",
    "kldivlossgrad",
    "lcm",
    "leftshift",
    "lessequal",
    "listdiff",
    "log",
    "log1p",
    "lognormalreverse",
    "logspace",
    "lowerbound",
    "lusolve",
    "luunpackgrad",
    "maskedselect",
    "maskedselectgrad",
    "matrixdeterminant",
    "matrixexp",
    "matrixlogarithm",
    "matrixsolve",
    "matrixtriangularsolve",
    "maxpool3dgradwithargmax",
    "maxpool3dwithargmax",
    "mul",
    "mulnonan",
    "multimarginloss",
    "multimarginlossgrad",
    "multinomial",
    "mvlgamma",
    "mvlgammagrad",
    "nextafter",
    "nondeterministicints",
    "gatherdgradv2",
    "isnan",
    "maskedselectgrad",
    "slicegrad"
]
@ -56,7 +154,7 @@ def parse_ini_to_obj(ini_file, aicpu_ops_info):
            info = {}
            op_name = line[1:-1]
            info = {}
            if op_name not in cust_op_lists:
            if op_name.lower() not in cust_op_lists:
                op_name = None
                continue
            aicpu_ops_info[op_name] = info
@ -57,5 +57,52 @@ REG_CUST_OP(GatherDGradV2)
  .OUTPUT(output, TensorType({DT_BOOL, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64,
                              DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  .CUST_OP_END_FACTORY_REG(GatherDGradV2)

REG_CUST_OP(AffineGridGrad)
  .INPUT(y_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
  .INPUT(x_size, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(x_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(align_corners, Bool)
  .CUST_OP_END_FACTORY_REG(AffineGridGrad)

REG_CUST_OP(HammingWindow)
  .INPUT(length, TensorType({DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(periodic, Bool)
  .REQUIRED_ATTR(alpha, Float)
  .REQUIRED_ATTR(beta, Float)
  .REQUIRED_ATTR(dtype, Int)
  .CUST_OP_END_FACTORY_REG(HammingWindow)

REG_CUST_OP(IndexFill)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                        DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(dim, TensorType({DT_INT32}))
  .INPUT(indices, TensorType({DT_INT32}))
  .INPUT(value, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                            DT_UINT32, DT_UINT64, DT_UINT8}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                         DT_UINT32, DT_UINT64, DT_UINT8}))
  .CUST_OP_END_FACTORY_REG(IndexFill)

REG_CUST_OP(Mvlgamma)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
  .REQUIRED_ATTR(p, Int)
  .CUST_OP_END_FACTORY_REG(Mvlgamma)

REG_CUST_OP(MvlgammaGrad)
  .INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT}))
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT}))
  .REQUIRED_ATTR(p, Int)
  .CUST_OP_END_FACTORY_REG(MvlgammaGrad)

REG_CUST_OP(LogSpace)
  .INPUT(start, TensorType({DT_FLOAT, DT_FLOAT16}))
  .INPUT(end, TensorType({DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(steps, Int)
  .REQUIRED_ATTR(base, Int)
  .CUST_OP_END_FACTORY_REG(LogSpace)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ARRAY_OPS_H_
@ -0,0 +1,35 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(ArgMax)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                        DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(dimension, TensorType({DT_INT32, DT_INT64}))
  .ATTR(dtype, Type, DT_INT64)
  .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  .CUST_OP_END_FACTORY_REG(ArgMax)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_ELEWISE_CALCULATION_OPS_H_
@ -0,0 +1,59 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(Geqrf)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(tau, TensorType({DT_DOUBLE, DT_FLOAT}))
  .CUST_OP_END_FACTORY_REG(Geqrf)

REG_CUST_OP(LuUnpack)
  .INPUT(LU_data, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  .INPUT(LU_pivots, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64}))
  .OUTPUT(pivots, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  .OUTPUT(L, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  .OUTPUT(U, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
  .ATTR(unpack_data, Bool, true)
  .ATTR(unpack_pivots, Bool, true)
  .CUST_OP_END_FACTORY_REG(LuUnpack)

REG_CUST_OP(LuUnpackGrad)
  .INPUT(L_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .INPUT(U_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .INPUT(LU_data, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .OUTPUT(L_data_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .OUTPUT(U_data_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .REQUIRED_ATTR(L_grad_flag, Bool)
  .CUST_OP_END_FACTORY_REG(LuUnpackGrad)

REG_CUST_OP(LuSolve)
  .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
  .INPUT(lu_data, TensorType({DT_FLOAT, DT_FLOAT16}))
  .INPUT(lu_pivots, TensorType({DT_INT32}))
  .OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(LuSolve)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_LINALG_OPS_H_
@ -0,0 +1,92 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(CholeskySolve)
  .INPUT(x1, TensorType({DT_DOUBLE, DT_FLOAT}))
  .INPUT(x2, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
  .REQUIRED_ATTR(upper, Bool)
  .CUST_OP_END_FACTORY_REG(CholeskySolve)

REG_CUST_OP(Cauchy)
  .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(size, ListInt)
  .REQUIRED_ATTR(sigma, Float)
  .REQUIRED_ATTR(median, Float)
  .CUST_OP_END_FACTORY_REG(Cauchy)

REG_CUST_OP(CholeskyInverse)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
  .REQUIRED_ATTR(upper, Bool)
  .CUST_OP_END_FACTORY_REG(CholeskyInverse)

REG_CUST_OP(Eig)
  .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(eigen_values, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
  .OUTPUT(eigen_vectors, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
  .REQUIRED_ATTR(compute_v, Bool)
  .CUST_OP_END_FACTORY_REG(Eig)

REG_CUST_OP(Hypot)
  .INPUT(x1, TensorType({DT_DOUBLE, DT_FLOAT}))
  .INPUT(x2, TensorType({DT_DOUBLE, DT_FLOAT}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT}))
  .CUST_OP_END_FACTORY_REG(Hypot)

REG_CUST_OP(MatrixLogarithm)
  .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
  .OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64}))
  .CUST_OP_END_FACTORY_REG(MatrixLogarithm)

REG_CUST_OP(Lcm)
  .INPUT(x1, TensorType({DT_INT32, DT_INT64}))
  .INPUT(x2, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  .CUST_OP_END_FACTORY_REG(Lcm)

REG_CUST_OP(MatrixExp)
  .INPUT(x, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(y, TensorType({DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(MatrixExp)

REG_CUST_OP(Heaviside)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                        DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                             DT_UINT32, DT_UINT64, DT_UINT8}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                         DT_UINT32, DT_UINT64, DT_UINT8}))
  .CUST_OP_END_FACTORY_REG(Heaviside)

REG_CUST_OP(Gcd)
  .INPUT(x1, TensorType({DT_INT32, DT_INT64}))
  .INPUT(x2, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
  .CUST_OP_END_FACTORY_REG(Gcd)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_MATH_OPS_H_
@ -0,0 +1,97 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(AdaptiveAvgPool3dGrad)
  .INPUT(input_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .INPUT(orig_input_shape, TensorType({DT_INT32}))
  .OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .CUST_OP_END_FACTORY_REG(AdaptiveAvgPool3dGrad)

REG_CUST_OP(AdaptiveMaxPool2dGrad)
  .INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(argmax, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(x_grad, TensorType({DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(AdaptiveMaxPool2dGrad)

REG_CUST_OP(AdaptiveAvgPool2D)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(output_size, ListInt)
  .CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2D)

REG_CUST_OP(AdaptiveAvgPool3d)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT8}))
  .REQUIRED_ATTR(output_size, ListInt)
  .CUST_OP_END_FACTORY_REG(AdaptiveAvgPool3d)

REG_CUST_OP(AdaptiveAvgPool2DGrad)
  .INPUT(input_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(orig_input_shape, TensorType({DT_INT64}))
  .OUTPUT(output_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .CUST_OP_END_FACTORY_REG(AdaptiveAvgPool2DGrad)

REG_CUST_OP(MultiMarginLossGrad)
  .INPUT(y_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(target, TensorType({DT_INT64}))
  .INPUT(weight, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(x_grad, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(p, Int)
  .REQUIRED_ATTR(margin, Float)
  .REQUIRED_ATTR(reduction, String)
  .CUST_OP_END_FACTORY_REG(MultiMarginLossGrad)

REG_CUST_OP(MaxPool3DGradWithArgmax)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                        DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(grads, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                            DT_UINT32, DT_UINT64, DT_UINT8}))
  .INPUT(argmax, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16,
                         DT_UINT32, DT_UINT64, DT_UINT8}))
  .REQUIRED_ATTR(ksize, ListInt)
  .REQUIRED_ATTR(strides, ListInt)
  .REQUIRED_ATTR(pads, ListInt)
  .REQUIRED_ATTR(dilation, ListInt)
  .REQUIRED_ATTR(ceil_mode, Bool)
  .REQUIRED_ATTR(data_format, String)
  .REQUIRED_ATTR(argmax_type, String)
  .CUST_OP_END_FACTORY_REG(MaxPool3DGradWithArgmax)

REG_CUST_OP(MultiMarginLoss)
  .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .INPUT(target, TensorType({DT_INT64}))
  .INPUT(weight, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(p, Int)
  .REQUIRED_ATTR(margin, Float)
  .REQUIRED_ATTR(reduction, String)
  .CUST_OP_END_FACTORY_REG(MultiMarginLoss)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_NN_OPS_H_
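For context, a hedged sketch of how a prototype registered this way can be instantiated through the GE operator factory; the node name "pool0" and the attr value are illustrative only:

// Hedged sketch: create an AdaptiveAvgPool2D node from the registered prototype.
ge::Operator pool = ge::OperatorFactory::CreateOperator("pool0", "AdaptiveAvgPool2D");
pool.SetAttr("output_size", std::vector<int64_t>{7, 7});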
@ -0,0 +1,43 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_

#include "graph/operator.h"
#include "graph/operator_reg.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(LogNormalReverse)
  .INPUT(input, TensorType({DT_FLOAT, DT_FLOAT16}))
  .OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(mean, Float)
  .REQUIRED_ATTR(std, Float)
  .CUST_OP_END_FACTORY_REG(LogNormalReverse)

REG_CUST_OP(Dropout2D)
  .INPUT(x, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8,
                        DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
  .OUTPUT(y, TensorType({DT_BOOL, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, DT_INT8,
                         DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}))
  .OUTPUT(mask, TensorType({DT_BOOL}))
  .REQUIRED_ATTR(keep_prob, Float)
  .CUST_OP_END_FACTORY_REG(Dropout2D)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_RANDOM_OPS_H_
@ -0,0 +1,42 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_
#define MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"
#include "transform/graph_ir/custom_op_proto/op_proto_macro.h"

/* clang-format off */

namespace ge {
REG_CUST_OP(BlackmanWindow)
  .INPUT(window_length, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(periodic, Bool)
  .REQUIRED_ATTR(dtype, Type)
  .CUST_OP_END_FACTORY_REG(BlackmanWindow)

REG_CUST_OP(BartlettWindow)
  .INPUT(window_length, TensorType({DT_INT32, DT_INT64}))
  .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16}))
  .REQUIRED_ATTR(periodic, Bool)
  .REQUIRED_ATTR(dtype, Type)
  .CUST_OP_END_FACTORY_REG(BartlettWindow)
} // namespace ge
#endif // MINDSPORE_CCSRC_GRAPH_IR_CUSTOM_OP_PROTO_CUST_SPECTRAL_OPS_H_
@ -17,8 +17,9 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_OP_ADAPTER_MAP_H_
#define MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_OP_ADAPTER_MAP_H_

#include <string>
#include <memory>
#include <string>

#include "utils/hash_map.h"

namespace mindspore {
@ -28,12 +29,15 @@ constexpr const char kNameConst[] = "Const";
constexpr const char kNameParam[] = "parameter";
constexpr const char kNameRandomUniform[] = "RandomUniform";
constexpr const char kNameUniformReal[] = "UniformReal";
constexpr const char kNameLogNormalReverse[] = "LogNormalReverse";
constexpr const char kNameSimpleMean[] = "SimpleMean";
constexpr const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
constexpr const char kNameAllReduce[] = "AllReduce";
constexpr const char kNameBroadcast[] = "Broadcast";
constexpr const char kNameBroadcastTo[] = "BroadcastTo";
constexpr const char kNameBroadcastToD[] = "BroadcastToD";
constexpr const char kNameBlackmanWindow[] = "BlackmanWindow";
constexpr const char kNameBartlettWindow[] = "BartlettWindow";
constexpr const char kNameAllgather[] = "AllGather";
constexpr const char kNameAllToAllv[] = "AllToAllv";
constexpr const char kNameReduceScatter[] = "ReduceScatter";
@ -47,6 +51,7 @@ constexpr const char kNameSquaredDifference[] = "SquaredDifference";
constexpr const char kNamePow[] = "Pow";
constexpr const char kNameBatchMatMul[] = "BatchMatMul";
constexpr const char kNameBatchMatMulV2[] = "BatchMatMulV2";
constexpr const char kNameBincount[] = "Bincount";
constexpr const char kNameStridedSlice[] = "StridedSlice";
constexpr const char kNameStridedSliceGrad[] = "StridedSliceGrad";
constexpr const char kNameExpandDims[] = "ExpandDims";
@ -54,6 +59,7 @@ constexpr const char kNameLog[] = "Log";
constexpr const char kNameLogicalAnd[] = "LogicalAnd";
constexpr const char kNameLogicalNot[] = "LogicalNot";
constexpr const char kNameLogicalOr[] = "LogicalOr";
constexpr const char kNameListDiff[] = "ListDiff";
constexpr const char kNameExp[] = "Exp";
constexpr const char kNameLessEqual[] = "LessEqual";
constexpr const char kNameGreaterEqual[] = "GreaterEqual";
@ -61,6 +67,8 @@ constexpr const char kNameApproximateEqual[] = "ApproximateEqual";
constexpr const char kNameEqual[] = "Equal";
constexpr const char kNameNotEqual[] = "NotEqual";
constexpr const char kNameFlattenGrad[] = "FlattenGrad";
constexpr const char kNameFillDiagonal[] = "FillDiagonal";
constexpr const char kNameEye[] = "Eye";
constexpr const char kNameConvolution[] = "Convolution";
constexpr const char kNameMaxPool3D[] = "MaxPool3D";
constexpr const char kNameMaxPool3DGrad[] = "MaxPool3DGrad";
@ -80,6 +88,7 @@ constexpr const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
constexpr const char kNameMaxPoolGradWithArgmaxV2[] = "MaxPoolGradWithArgmaxV2";
constexpr const char kNameApplyMomentum[] = "ApplyMomentum";
constexpr const char kNameDropoutDoMask[] = "DropoutDoMask";
constexpr const char kNameDropout2D[] = "Dropout2D";
constexpr const char kNameDropOutDoMaskV3[] = "DropOutDoMaskV3";
constexpr const char kNameDropOutDoMaskV3D[] = "DropOutDoMaskV3D";
constexpr const char kNameDropOutGenMaskV4[] = "DropOutGenMaskV4";
@ -441,6 +450,7 @@ constexpr const char kNameViewCopy[] = "ViewCopy";
constexpr const char kNameSend[] = "Send";
constexpr const char kNameReceive[] = "Receive";
constexpr const char kNameIndexAdd[] = "IndexAdd";
constexpr const char kNameIndexFill[] = "IndexFill";
constexpr const char kNameUnique[] = "Unique";
constexpr const char kNameDynamicBroadcastGradientArgs[] = "DynamicBroadcastGradientArgs";
constexpr const char kNameDynamicStitch[] = "DynamicStitch";
@ -253,4 +253,57 @@ ATTR_INPUT_MAP(ViewCopy) = {
ATTR_MAP(ViewCopy) = EMPTY_ATTR_MAP;
OUTPUT_MAP(ViewCopy) = {{0, OUTPUT_DESC(dst)}};
REG_ADPT_DESC(ViewCopy, kNameViewCopy, ADPT_DESC(ViewCopy))

// CheckNumerics
INPUT_MAP(CheckNumerics) = {{1, INPUT_DESC(x)}};
ATTR_MAP(CheckNumerics) = {{"message", ATTR_DESC(message, AnyTraits<std::string>())}};
OUTPUT_MAP(CheckNumerics) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(CheckNumerics, prim::kPrimCheckNumerics->name(), ADPT_DESC(CheckNumerics));

// HammingWindow
CUST_INPUT_MAP(HammingWindow) = {{1, INPUT_DESC(length)}};
CUST_ATTR_MAP(HammingWindow) = {{"periodic", ATTR_DESC(periodic, AnyTraits<bool>())},
                                {"alpha", ATTR_DESC(alpha, AnyTraits<float>())},
                                {"beta", ATTR_DESC(beta, AnyTraits<float>())},
                                {"dtype", ATTR_DESC(dtype, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(HammingWindow) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(HammingWindow, prim::kPrimHammingWindow->name(), CUST_ADPT_DESC(HammingWindow));

// LowerBound
INPUT_MAP(LowerBound) = {{1, INPUT_DESC(sorted_x)}, {2, INPUT_DESC(values)}};
ATTR_MAP(LowerBound) = EMPTY_ATTR_MAP;
OUTPUT_MAP(LowerBound) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LowerBound, prim::kPrimLowerBound->name(), ADPT_DESC(LowerBound));

// ListDiff
INPUT_MAP(ListDiff) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}};
ATTR_MAP(ListDiff) = {{"out_idx", ATTR_DESC(out_idx, AnyTraits<GEType>())}};
OUTPUT_MAP(ListDiff) = {{0, OUTPUT_DESC(out)}, {1, OUTPUT_DESC(idx)}};
REG_ADPT_DESC(ListDiff, kNameListDiff, ADPT_DESC(ListDiff));

// IndexFill
CUST_INPUT_MAP(IndexFill) = {
  {1, INPUT_DESC(x)}, {2, INPUT_DESC(dim)}, {3, INPUT_DESC(indices)}, {4, INPUT_DESC(value)}};
CUST_ATTR_MAP(IndexFill) = EMPTY_ATTR_MAP;
CUST_OUTPUT_MAP(IndexFill) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(IndexFill, kNameIndexFill, CUST_ADPT_DESC(IndexFill));

// Mvlgamma
CUST_INPUT_MAP(Mvlgamma) = {{1, INPUT_DESC(x)}};
CUST_ATTR_MAP(Mvlgamma) = {{"p", ATTR_DESC(p, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(Mvlgamma) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Mvlgamma, prim::kPrimMvlgamma->name(), CUST_ADPT_DESC(Mvlgamma));

// MvlgammaGrad
CUST_INPUT_MAP(MvlgammaGrad) = {{1, INPUT_DESC(y_grad)}, {2, INPUT_DESC(x)}};
CUST_ATTR_MAP(MvlgammaGrad) = {{"p", ATTR_DESC(p, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(MvlgammaGrad) = {{0, OUTPUT_DESC(x_grad)}};
REG_ADPT_DESC(MvlgammaGrad, prim::kPrimMvlgammaGrad->name(), CUST_ADPT_DESC(MvlgammaGrad));

// LogSpace
CUST_INPUT_MAP(LogSpace) = {{1, INPUT_DESC(start)}, {2, INPUT_DESC(end)}};
CUST_ATTR_MAP(LogSpace) = {{"steps", ATTR_DESC(steps, AnyTraits<int64_t>())},
                           {"base", ATTR_DESC(base, AnyTraits<int64_t>())}};
CUST_OUTPUT_MAP(LogSpace) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LogSpace, prim::kPrimLogSpace->name(), CUST_ADPT_DESC(LogSpace));
} // namespace mindspore::transform
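Each block above follows the same adapter-definition pattern: INPUT_MAP / CUST_INPUT_MAP binds 1-based GE input indices to named input descriptors, ATTR_MAP binds primitive attributes through ATTR_DESC with an AnyTraits type tag, OUTPUT_MAP binds 0-based output indices, and REG_ADPT_DESC registers the adapter under either a kName constant or the primitive's name(). A minimal sketch for a hypothetical single-input custom op (the op "Foo", its "scale" attribute, and kPrimFoo are invented, not part of this commit):

// Hypothetical custom-AICPU adapter, same shape as the entries above.
CUST_INPUT_MAP(Foo) = {{1, INPUT_DESC(x)}};                              // GE input 1 -> input "x"
CUST_ATTR_MAP(Foo) = {{"scale", ATTR_DESC(scale, AnyTraits<float>())}};  // primitive attr -> GE attr
CUST_OUTPUT_MAP(Foo) = {{0, OUTPUT_DESC(y)}};                            // GE output 0 -> output "y"
REG_ADPT_DESC(Foo, prim::kPrimFoo->name(), CUST_ADPT_DESC(Foo));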
@ -125,4 +125,28 @@ DECLARE_OP_USE_OUTPUT(AsStrided)

DECLARE_OP_ADAPTER(ViewCopy)
DECLARE_OP_USE_OUTPUT(ViewCopy)

DECLARE_OP_ADAPTER(CheckNumerics)
DECLARE_OP_USE_OUTPUT(CheckNumerics)

DECLARE_CUST_OP_ADAPTER(HammingWindow)
DECLARE_CUST_OP_USE_OUTPUT(HammingWindow)

DECLARE_OP_ADAPTER(LowerBound)
DECLARE_OP_USE_OUTPUT(LowerBound)

DECLARE_OP_ADAPTER(ListDiff)
DECLARE_OP_USE_OUTPUT(ListDiff)

DECLARE_CUST_OP_ADAPTER(IndexFill)
DECLARE_CUST_OP_USE_OUTPUT(IndexFill)

DECLARE_CUST_OP_ADAPTER(Mvlgamma)
DECLARE_CUST_OP_USE_OUTPUT(Mvlgamma)

DECLARE_CUST_OP_ADAPTER(MvlgammaGrad)
DECLARE_CUST_OP_USE_OUTPUT(MvlgammaGrad)

DECLARE_CUST_OP_ADAPTER(LogSpace)
DECLARE_CUST_OP_USE_OUTPUT(LogSpace)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_
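Every definition block in the .cc hunk above needs a matching declaration in this header: DECLARE_OP_ADAPTER / DECLARE_CUST_OP_ADAPTER declares the adapter's input and attribute maps, and DECLARE_OP_USE_OUTPUT / DECLARE_CUST_OP_USE_OUTPUT declares its output map, with the CUST_ variants routing through the custom-AICPU op proto. A sketch for the same invented op "Foo" used above:

// Hypothetical declaration pair for the invented op "Foo"; every real
// entry in this header uses the same two-line pattern.
DECLARE_CUST_OP_ADAPTER(Foo)
DECLARE_CUST_OP_USE_OUTPUT(Foo)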
@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "transform/graph_ir/custom_op_proto/cust_elewise_calculation_ops.h"
#include "ops/ascend_op_name.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
@ -279,6 +280,13 @@ ATTR_MAP(OnesLike) = EMPTY_ATTR_MAP;
OUTPUT_MAP(OnesLike) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(OnesLike, kNameOnesLike, ADPT_DESC(OnesLike))

// ArgMax
CUST_INPUT_MAP(ArgMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
CUST_ATTR_INPUT_MAP(ArgMax) = {{"axis", "dimension"}};
CUST_ATTR_MAP(ArgMax) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
CUST_OUTPUT_MAP(ArgMax) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ArgMax, kNameArgmax, CUST_ADPT_DESC(ArgMax));

// ArgMaxD
INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
@ -291,7 +299,6 @@ INPUT_MAP(ArgMaxV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
ATTR_INPUT_MAP(ArgMaxV2) = {{"axis", "dimension"}};
ATTR_MAP(ArgMaxV2) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
OUTPUT_MAP(ArgMaxV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ArgMax, kNameArgmax, ADPT_DESC(ArgMaxV2))
REG_ADPT_DESC(ArgMaxV2, kNameArgMaxV2, ADPT_DESC(ArgMaxV2))

// ArgMaxWithValue
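The notable line in the new ArgMax block is CUST_ATTR_INPUT_MAP(ArgMax) = {{"axis", "dimension"}}: it feeds a value that the MindSpore primitive carries as the attribute "axis" to the custom AICPU kernel as the tensor input "dimension" (input index 2 in CUST_INPUT_MAP). The hunk header @ -291,7 +299,6 shows one line leaving the ArgMaxV2 block, presumably the old REG_ADPT_DESC(ArgMax, kNameArgmax, ...) registration that the CUST_ADPT_DESC(ArgMax) entry now supersedes; the flattened view does not preserve the +/- markers. A sketch of the attr-to-input convention (the op "Bar" is invented for illustration):

// Hypothetical op "Bar": the primitive attribute "axis" is materialized
// as tensor input 2 ("dim") for the AICPU kernel.
CUST_INPUT_MAP(Bar) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dim)}};
CUST_ATTR_INPUT_MAP(Bar) = {{"axis", "dim"}};  // attr name -> input name
CUST_OUTPUT_MAP(Bar) = {{0, OUTPUT_DESC(y)}};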
@ -19,6 +19,7 @@

#include "inc/ops/bitwise_ops.h"
#include "inc/ops/elewise_calculation_ops.h"
#include "transform/graph_ir/custom_op_proto/cust_elewise_calculation_ops.h"
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "utils/hash_map.h"
@ -47,6 +48,9 @@ DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)

DECLARE_CUST_OP_ADAPTER(ArgMax)
DECLARE_CUST_OP_USE_OUTPUT(ArgMax)

DECLARE_OP_ADAPTER(ArgMaxD)
DECLARE_OP_USE_OUTPUT(ArgMaxD)
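As with the ArgMax entries just above, each DECLARE_CUST_OP_ADAPTER / DECLARE_CUST_OP_USE_OUTPUT pair in a header must line up by name with a CUST_*_MAP definition block and a REG_ADPT_DESC call in the corresponding .cc file; a mismatched name presumably fails once the macros expand. A condensed, hypothetical end-to-end sketch (the op "Baz" and kNameBaz are invented):

// Hypothetical op "Baz", showing how header and source line up.
// --- *_ops_declare.h ---
// DECLARE_CUST_OP_ADAPTER(Baz)
// DECLARE_CUST_OP_USE_OUTPUT(Baz)
// --- *_ops_declare.cc ---
// CUST_INPUT_MAP(Baz) = {{1, INPUT_DESC(x)}};
// CUST_ATTR_MAP(Baz) = EMPTY_ATTR_MAP;
// CUST_OUTPUT_MAP(Baz) = {{0, OUTPUT_DESC(y)}};
// REG_ADPT_DESC(Baz, kNameBaz, CUST_ADPT_DESC(Baz));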
Some files were not shown because too many files have changed in this diff.