forked from mindspore-Ecosystem/mindspore
!214 modify GPU operator information registration
Merge pull request !214 from Maoweiyong/modify-gpu-op-register
This commit is contained in:
commit
fb19655ea6
|
@ -30,7 +30,7 @@ Note:
|
|||
|
||||
from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||
from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType
|
||||
from .op_info_register import op_info_register, AkgRegOp, AiCPURegOp, TBERegOp, DataType
|
||||
from .primitive import constexpr
|
||||
from .._c_expression import signature_rw, signature_kind
|
||||
|
||||
|
@ -40,6 +40,6 @@ __primitive__ = [
|
|||
]
|
||||
|
||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||
"op_info_register", "AiCPURegOp", "TBERegOp", "DataType",
|
||||
"op_info_register", "AkgRegOp", "AiCPURegOp", "TBERegOp", "DataType",
|
||||
"constexpr"]
|
||||
__all__.extend(__primitive__)
|
||||
|
|
|
@ -13,45 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Cast op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Cast",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
{
|
||||
"name": "dst_type",
|
||||
"param_type": "required",
|
||||
"type": "str"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float32"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
cast_op_info = AkgRegOp("Cast") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.attr("dst_type", "required", "str") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(cast_op_info)
|
||||
def _cast_akg():
|
||||
"""Cast AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Equal op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Equal",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"bool", "bool"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
equal_op_info = AkgRegOp("Equal") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.input(1, "y") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(equal_op_info)
|
||||
def _equal_akg():
|
||||
"""Equal AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,40 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""HSigmoid op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "HSigmoid",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
hsigmoid_op_info = AkgRegOp("HSigmoid") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoidgrad_op_info)
|
||||
def _hsigmoid_akg():
|
||||
"""HSigmoid AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""HSigmoidGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "HSigmoidGrad",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y_grad"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
hsigmoidgrad_op_info = AkgRegOp("HSigmoidGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "y_grad") \
|
||||
.input(1, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoidgrad_op_info)
|
||||
def _hsigmoid_grad_akg():
|
||||
"""HSigmoidGrad AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,40 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""HSwish op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "HSwish",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
hswish_op_info = AkgRegOp("HSwish") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoidgrad_op_info)
|
||||
def _hswish_akg():
|
||||
"""HSwish AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""HSwishGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "HSwishGrad",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y_grad"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
hswishgrad_op_info = AkgRegOp("HSwishGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "y_grad") \
|
||||
.input(1, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoidgrad_op_info)
|
||||
def _hswish_grad_akg():
|
||||
"""HSwishGrad AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,40 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""SimpleMean op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SimpleMean",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
mean_op_info = AkgRegOp("SimpleMean") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mean_op_info)
|
||||
def _simple_mean_akg():
|
||||
"""SimpleMean AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,45 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""SimpleMeanGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SimpleMeanGrad",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
{
|
||||
"name": "input_shape",
|
||||
"param_type": "required",
|
||||
"type": "listInt"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "HEAD"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
mean_grad_op_info = AkgRegOp("SimpleMeanGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "HEAD") \
|
||||
.output(0, "output") \
|
||||
.attr("input_shape", "required", "listInt") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mean_grad_op_info)
|
||||
def _simple_mean_grad_akg():
|
||||
"""SimpleMeanGrad AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Mul op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Mul",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
mul_op_info = AkgRegOp("Mul") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.input(1, "y") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mul_op_info)
|
||||
def _mul_akg():
|
||||
"""Mul AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,40 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""ReLU6 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReLU6",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
relu_op_info = AkgRegOp("ReLU6") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(relu_op_info)
|
||||
def _relu6_akg():
|
||||
"""ReLU6 AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""ReLU6Grad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReLU6Grad",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y_grad"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
relu_grad_op_info = AkgRegOp("ReLU6Grad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "y_grad") \
|
||||
.input(1, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(relu_grad_op_info)
|
||||
def _relu6_grad_akg():
|
||||
"""ReLU6Grad AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,45 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Squeeze op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Squeeze",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
squeeze_op_info = AkgRegOp("SqueezeGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.attr("axis", "optional", "listInt") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(squeeze_op_info)
|
||||
def _squeeze_akg():
|
||||
"""Squeeze AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,50 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""SqueezeGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "SqueezeGrad",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
{
|
||||
"name": "x_shape",
|
||||
"param_type": "required",
|
||||
"type": "listInt"
|
||||
},
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "optional",
|
||||
"type": "listInt"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "y_grad"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
squeeze_grad_op_info = AkgRegOp("SqueezeGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "y_grad") \
|
||||
.output(0, "output") \
|
||||
.attr("x_shape", "required", "listInt") \
|
||||
.attr("axis", "optional", "listInt") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(squeeze_grad_op_info)
|
||||
def _squeeze_grad_akg():
|
||||
"""SqueezeGrad AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -13,45 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Tile op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Tile",
|
||||
"imply_type": "AutoDiff",
|
||||
"fusion_type": "OPAQUE",
|
||||
"processor": "cuda",
|
||||
"attr": [
|
||||
{
|
||||
"name": "multiples",
|
||||
"param_type": "required",
|
||||
"type": "listInt"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "x"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float32", "float16"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "DefaultFormat"
|
||||
],
|
||||
"name": "output"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
tile_op_info = AkgRegOp("Tile") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.attr("multiples", "required", "listInt") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tile_op_info)
|
||||
def _tile_akg():
|
||||
"""Tile AutoDiff register"""
|
||||
return
|
||||
|
|
|
@ -205,6 +205,64 @@ class RegOp():
|
|||
return op_info
|
||||
|
||||
|
||||
class AkgRegOp(RegOp):
|
||||
"""Class for Akg op info register"""
|
||||
|
||||
def __init__(self, op_name):
|
||||
super(AkgRegOp, self).__init__(op_name)
|
||||
self.imply_type = "AutoDiff"
|
||||
self.processor = "cuda"
|
||||
|
||||
def input(self, index=None, name=None, **kwargs):
|
||||
"""
|
||||
Register Akg op input information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the input. Default: None.
|
||||
name (str): Name of the input. Default: None.
|
||||
kwargs (dict): Other information for the input.
|
||||
"""
|
||||
param_list = [index, name]
|
||||
key_list = ["index", "name"]
|
||||
fn_list = [self._is_int, self._is_string]
|
||||
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.inputs.append(input_dict)
|
||||
return self
|
||||
|
||||
def output(self, index=None, name=None, **kwargs):
|
||||
"""
|
||||
Register Akg op output information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the output. Default: None.
|
||||
name (str): Name of the output. Default: None.
|
||||
kwargs (dict): Other information for the output.
|
||||
"""
|
||||
param_list = [index, name]
|
||||
key_list = ["index", "name"]
|
||||
fn_list = [self._is_int, self._is_string]
|
||||
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.outputs.append(output_dict)
|
||||
return self
|
||||
|
||||
def attr(self, name=None, param_type=None, value_type=None, **kwargs):
|
||||
"""
|
||||
Register Akg op attribute information.
|
||||
|
||||
Args:
|
||||
name (str): Name of the attribute. Default: None.
|
||||
param_type (str): Param type of the attribute. Default: None.
|
||||
value_type (str): Value type of the attribute. Default: None.
|
||||
kwargs (dict): Other information for the attribute.
|
||||
"""
|
||||
param_list = [name, param_type, value_type]
|
||||
key_list = ["name", "param_type", "type"]
|
||||
fn_list = [self._is_string]
|
||||
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.attr_.append(attr_dict)
|
||||
return self
|
||||
|
||||
|
||||
class AiCPURegOp(RegOp):
|
||||
"""Class for AiCPU op info register"""
|
||||
|
||||
|
|
Loading…
Reference in New Issue