forked from mindspore-Ecosystem/mindspore
modify Gelu、FastGelu to GeLU and FastGeLU
This commit is contained in:
parent
408159e301
commit
30a27b2adb
|
@ -150,10 +150,10 @@
|
||||||
{"op_name": "Conv2DBackpropInput", "inputs": [{"index": 0, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "input_sizes", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_list", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "groups", "param_type": "optional", "type": "int", "value": "all"}, {"name": "format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "FracZ"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "conv2d_backprop_input_d.so", "compute_cost": 10, "kernel_name": "conv2d_backprop_input_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "Conv2DBackpropInput", "inputs": [{"index": 0, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "input_sizes", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_list", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "groups", "param_type": "optional", "type": "int", "value": "all"}, {"name": "format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "FracZ"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "conv2d_backprop_input_d.so", "compute_cost": 10, "kernel_name": "conv2d_backprop_input_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "ConfusionMulGrad", "inputs": [{"index": 0, "name": "input0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "confusion_mul_grad.so", "compute_cost": 10, "kernel_name": "confusion_mul_grad", "partial_flag": false, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "ConfusionMulGrad", "inputs": [{"index": 0, "name": "input0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "confusion_mul_grad.so", "compute_cost": 10, "kernel_name": "confusion_mul_grad", "partial_flag": false, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "DropoutDoMask", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mask", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "keep_prob", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "drop_out_do_mask.so", "compute_cost": 10, "kernel_name": "drop_out_do_mask", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "dynamicFormat"}
|
{"op_name": "DropoutDoMask", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mask", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "keep_prob", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "drop_out_do_mask.so", "compute_cost": 10, "kernel_name": "drop_out_do_mask", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "dynamicFormat"}
|
||||||
{"op_name": "Gelu", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu.so", "compute_cost": 10, "kernel_name": "gelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "formatAgnostic"}
|
{"op_name": "GeLU", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu.so", "compute_cost": 10, "kernel_name": "gelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "formatAgnostic"}
|
||||||
{"op_name": "GeluGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu_grad.so", "compute_cost": 10, "kernel_name": "gelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "GeLUGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu_grad.so", "compute_cost": 10, "kernel_name": "gelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "FastGelu", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fast_gelu.so", "compute_cost": 10, "kernel_name": "fast_gelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "formatAgnostic"}
|
{"op_name": "FastGeLU", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fast_gelu.so", "compute_cost": 10, "kernel_name": "fast_gelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": "formatAgnostic"}
|
||||||
{"op_name": "FastGeluGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fast_gelu_grad.so", "compute_cost": 10, "kernel_name": "fast_gelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "FastGeLUGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fast_gelu_grad.so", "compute_cost": 10, "kernel_name": "fast_gelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "MaxPool", "inputs": [{"index": 0, "name": "input_data", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}, {"name": "format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool.so", "compute_cost": 10, "kernel_name": "max_pool", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "MaxPool", "inputs": [{"index": 0, "name": "input_data", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}, {"name": "format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool.so", "compute_cost": 10, "kernel_name": "max_pool", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "MaxPoolGrad", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad.so", "compute_cost": 10, "kernel_name": "max_pool_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "MaxPoolGrad", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad.so", "compute_cost": 10, "kernel_name": "max_pool_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
{"op_name": "MaxPoolGradWithArgmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "argmax", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad_with_argmax.so", "compute_cost": 10, "kernel_name": "max_pool_grad_with_argmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
{"op_name": "MaxPoolGradWithArgmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "argmax", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "kernel_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_mode", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad_with_argmax.so", "compute_cost": 10, "kernel_name": "max_pool_grad_with_argmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "dynamic_shape": false, "need_check_supported": false, "op_pattern": ""}
|
||||||
|
|
|
@ -22,7 +22,7 @@ HALF = 0.5
|
||||||
|
|
||||||
|
|
||||||
def expand_gelu(expand_info):
|
def expand_gelu(expand_info):
|
||||||
"""Gelu expander"""
|
"""GeLU expander"""
|
||||||
# cal formula are:
|
# cal formula are:
|
||||||
# gelu(x) is 0.5 * x * (1.0 + tanh(y))
|
# gelu(x) is 0.5 * x * (1.0 + tanh(y))
|
||||||
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||||
|
|
|
@ -23,7 +23,7 @@ HALF = 0.5
|
||||||
|
|
||||||
|
|
||||||
def expand_gelugrad(expand_info):
|
def expand_gelugrad(expand_info):
|
||||||
"""GeluGrad expander"""
|
"""GeLUGrad expander"""
|
||||||
# cal formula are:
|
# cal formula are:
|
||||||
# gelu_grad(dy, x) is dy * y'
|
# gelu_grad(dy, x) is dy * y'
|
||||||
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
||||||
|
|
|
@ -128,7 +128,7 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
operate_type_ = FLOOR;
|
operate_type_ = FLOOR;
|
||||||
} else if (kernel_name == prim::kPrimReciprocal->name()) {
|
} else if (kernel_name == prim::kPrimReciprocal->name()) {
|
||||||
operate_type_ = RECIPROCAL;
|
operate_type_ = RECIPROCAL;
|
||||||
} else if (kernel_name == prim::kPrimGelu->name()) {
|
} else if (kernel_name == prim::kPrimGeLU->name()) {
|
||||||
operate_type_ = GELU;
|
operate_type_ = GELU;
|
||||||
} else if (kernel_name == prim::kPrimAsin->name()) {
|
} else if (kernel_name == prim::kPrimAsin->name()) {
|
||||||
operate_type_ = ASIN;
|
operate_type_ = ASIN;
|
||||||
|
|
|
@ -66,7 +66,7 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
|
|
|
@ -147,7 +147,7 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
operate_type_ = TANHGRAD;
|
operate_type_ = TANHGRAD;
|
||||||
} else if (kernel_name == "SqrtGrad") {
|
} else if (kernel_name == "SqrtGrad") {
|
||||||
operate_type_ = SQRTGRAD;
|
operate_type_ = SQRTGRAD;
|
||||||
} else if (kernel_name == "GeluGrad") {
|
} else if (kernel_name == "GeLUGrad") {
|
||||||
operate_type_ = GELUGRAD;
|
operate_type_ = GELUGRAD;
|
||||||
} else if (kernel_name == "AsinGrad") {
|
} else if (kernel_name == "AsinGrad") {
|
||||||
operate_type_ = ASINGRAD;
|
operate_type_ = ASINGRAD;
|
||||||
|
|
|
@ -88,7 +88,7 @@ MS_REG_CPU_KERNEL(
|
||||||
TanhGrad,
|
TanhGrad,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
EltWiseGradCPUKernel);
|
EltWiseGradCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(GeluGrad,
|
MS_REG_CPU_KERNEL(GeLUGrad,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
|
|
@ -18,14 +18,14 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
MS_REG_GPU_KERNEL_ONE(GeluGrad,
|
MS_REG_GPU_KERNEL_ONE(GeLUGrad,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
GeLUGpuGradKernel, float)
|
GeLUGpuGradKernel, float)
|
||||||
MS_REG_GPU_KERNEL_ONE(GeluGrad,
|
MS_REG_GPU_KERNEL_ONE(GeLUGrad,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_GPU_KERNEL_ONE(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
GeluGpuKernel, float)
|
GeluGpuKernel, float)
|
||||||
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_GPU_KERNEL_ONE(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
GeluGpuKernel, half)
|
GeluGpuKernel, half)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -701,7 +701,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
|
||||||
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||||
std::unordered_set<PrimitivePtr> expand_ops = {
|
std::unordered_set<PrimitivePtr> expand_ops = {
|
||||||
prim::kPrimSquare,
|
prim::kPrimSquare,
|
||||||
prim::kPrimGeluGrad,
|
prim::kPrimGeLUGrad,
|
||||||
#if ENABLE_D
|
#if ENABLE_D
|
||||||
prim::kPrimTile,
|
prim::kPrimTile,
|
||||||
prim::kPrimSqrtGrad,
|
prim::kPrimSqrtGrad,
|
||||||
|
@ -709,7 +709,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||||
#elif ENABLE_GPU
|
#elif ENABLE_GPU
|
||||||
prim::kPrimBiasAdd,
|
prim::kPrimBiasAdd,
|
||||||
prim::kPrimBiasAddGrad,
|
prim::kPrimBiasAddGrad,
|
||||||
prim::kPrimGelu,
|
prim::kPrimGeLU,
|
||||||
prim::kPrimFusedAdam,
|
prim::kPrimFusedAdam,
|
||||||
prim::kPrimFusedAdamWeightDecay,
|
prim::kPrimFusedAdamWeightDecay,
|
||||||
prim::kPrimReduceMean,
|
prim::kPrimReduceMean,
|
||||||
|
|
|
@ -77,7 +77,7 @@ class RegisterAction {
|
||||||
|
|
||||||
// operator register
|
// operator register
|
||||||
REGISTER(MatMulInfo);
|
REGISTER(MatMulInfo);
|
||||||
REGISTER(GeluInfo);
|
REGISTER(GeLUInfo);
|
||||||
REGISTER(VirtualDatasetInfo);
|
REGISTER(VirtualDatasetInfo);
|
||||||
REGISTER(BatchParallelInfo);
|
REGISTER(BatchParallelInfo);
|
||||||
REGISTER(TanhInfo);
|
REGISTER(TanhInfo);
|
||||||
|
|
|
@ -82,12 +82,12 @@ class ActivationOther : public Activation {
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GeluInfo : public ActivationOther {
|
class GeLUInfo : public ActivationOther {
|
||||||
public:
|
public:
|
||||||
GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
GeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
const PrimitiveAttrs &attrs)
|
const PrimitiveAttrs &attrs)
|
||||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
|
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
|
||||||
~GeluInfo() override = default;
|
~GeLUInfo() override = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TanhInfo : public ActivationOther {
|
class TanhInfo : public ActivationOther {
|
||||||
|
|
|
@ -187,7 +187,7 @@ constexpr char CONCAT[] = "Concat";
|
||||||
constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
|
constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
|
||||||
constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
|
constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
|
||||||
constexpr char MATMUL[] = "MatMul";
|
constexpr char MATMUL[] = "MatMul";
|
||||||
constexpr char GELU[] = "Gelu";
|
constexpr char GELU[] = "GeLU";
|
||||||
constexpr char TANH[] = "Tanh";
|
constexpr char TANH[] = "Tanh";
|
||||||
constexpr char RECEIVE[] = "Receive";
|
constexpr char RECEIVE[] = "Receive";
|
||||||
constexpr char SEND[] = "Send";
|
constexpr char SEND[] = "Send";
|
||||||
|
|
|
@ -459,7 +459,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
||||||
op_prim->EndRecordAddAttr();
|
op_prim->EndRecordAddAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
|
void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
|
||||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||||
PrimitivePtr op_prim = op_run_info->py_primitive;
|
PrimitivePtr op_prim = op_run_info->py_primitive;
|
||||||
|
@ -479,7 +478,6 @@ void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
|
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
|
||||||
if (utils::isa<VectorRef>(base_ref)) {
|
if (utils::isa<VectorRef>(base_ref)) {
|
||||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||||
|
|
|
@ -101,27 +101,27 @@ ATTR_MAP(TanhGrad) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}};
|
OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}};
|
||||||
REG_ADPT_DESC(TanhGrad, prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad))
|
REG_ADPT_DESC(TanhGrad, prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad))
|
||||||
|
|
||||||
// Gelu
|
// GeLU
|
||||||
INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Gelu) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Gelu) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(Gelu, prim::kPrimGelu->name(), ADPT_DESC(Gelu))
|
REG_ADPT_DESC(Gelu, prim::kPrimGeLU->name(), ADPT_DESC(Gelu))
|
||||||
|
|
||||||
// GeluGrad
|
// GeLUGrad
|
||||||
INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}};
|
INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}};
|
||||||
ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
|
ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
|
OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
|
||||||
REG_ADPT_DESC(GeluGrad, prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad))
|
REG_ADPT_DESC(GeluGrad, prim::kPrimGeLUGrad->name(), ADPT_DESC(GeluGrad))
|
||||||
|
|
||||||
// FastGelu
|
// FastGeLU
|
||||||
INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
|
ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(FastGelu, prim::kPrimFastGelu->name(), ADPT_DESC(FastGelu))
|
REG_ADPT_DESC(FastGelu, prim::kPrimFastGeLU->name(), ADPT_DESC(FastGelu))
|
||||||
|
|
||||||
// FastGeluGrad
|
// FastGeLUGrad
|
||||||
INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
|
INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
|
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
|
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
|
||||||
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeluGrad->name(), ADPT_DESC(FastGeluGrad))
|
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeLUGrad->name(), ADPT_DESC(FastGeluGrad))
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
|
|
|
@ -65,13 +65,13 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplFastGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplFastGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplFastGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplFastGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
|
|
@ -43,6 +43,10 @@ constexpr auto kScalarUsub = "ScalarUsub";
|
||||||
constexpr auto kStack = "Stack";
|
constexpr auto kStack = "Stack";
|
||||||
constexpr auto kUnstack = "Unstack";
|
constexpr auto kUnstack = "Unstack";
|
||||||
constexpr auto kTupleGetItem = "TupleGetItem";
|
constexpr auto kTupleGetItem = "TupleGetItem";
|
||||||
|
constexpr auto kGeLU = "GeLU";
|
||||||
|
constexpr auto kGeLUGrad = "GeLUGrad";
|
||||||
|
constexpr auto kFastGeLU = "FastGeLU";
|
||||||
|
constexpr auto kFastGeLUGrad = "FastGeLUGrad";
|
||||||
|
|
||||||
// Here list all primitives used in backend or some special primitives used by core.
|
// Here list all primitives used in backend or some special primitives used by core.
|
||||||
// Arithmetic
|
// Arithmetic
|
||||||
|
@ -257,11 +261,10 @@ inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
|
||||||
inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
|
inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
|
||||||
inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
|
inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
|
||||||
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
||||||
inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>("Gelu");
|
inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>(kGeLU);
|
||||||
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
inline const PrimitivePtr kPrimGeLUGrad = std::make_shared<Primitive>(kGeLUGrad);
|
||||||
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
inline const PrimitivePtr kPrimFastGeLU = std::make_shared<Primitive>(kFastGeLU);
|
||||||
inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
|
inline const PrimitivePtr kPrimFastGeLUGrad = std::make_shared<Primitive>(kFastGeLUGrad);
|
||||||
inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
|
|
||||||
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||||
inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
||||||
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
||||||
|
|
|
@ -0,0 +1,444 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/optimizer/graph/primitive_adjust_pass.h"
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include "ops/abs.h"
|
||||||
|
#include "ops/batch_norm.h"
|
||||||
|
#include "ops/elu.h"
|
||||||
|
#include "ops/depthwise_conv2d.h"
|
||||||
|
#include "ops/fused_batch_norm.h"
|
||||||
|
#include "ops/fusion/activation.h"
|
||||||
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
#include "ops/fusion/adder_fusion.h"
|
||||||
|
#include "ops/fusion/arg_max_fusion.h"
|
||||||
|
#include "ops/fusion/arg_min_fusion.h"
|
||||||
|
#include "ops/fusion/avg_pool_fusion.h"
|
||||||
|
#include "ops/fusion/conv2d_backprop_filter_fusion.h"
|
||||||
|
#include "ops/fusion/conv2d_backprop_input_fusion.h"
|
||||||
|
#include "ops/fusion/conv2d_fusion.h"
|
||||||
|
#include "ops/fusion/conv2d_transpose_fusion.h"
|
||||||
|
#include "ops/fusion/div_fusion.h"
|
||||||
|
#include "ops/fusion/exp_fusion.h"
|
||||||
|
#include "ops/fusion/l2_normalize_fusion.h"
|
||||||
|
#include "ops/fusion/layer_norm_fusion.h"
|
||||||
|
#include "ops/fusion/max_pool_fusion.h"
|
||||||
|
#include "ops/fusion/mul_fusion.h"
|
||||||
|
#include "ops/fusion/pad_fusion.h"
|
||||||
|
#include "ops/fusion/prelu_fusion.h"
|
||||||
|
#include "ops/fusion/reduce_fusion.h"
|
||||||
|
#include "ops/fusion/scale_fusion.h"
|
||||||
|
#include "ops/fusion/sub_fusion.h"
|
||||||
|
#include "ops/fusion/tile_fusion.h"
|
||||||
|
#include "ops/fusion/topk_fusion.h"
|
||||||
|
#include "ops/gather.h"
|
||||||
|
#include "ops/gelu.h"
|
||||||
|
#include "ops/leaky_relu.h"
|
||||||
|
#include "ops/mat_mul.h"
|
||||||
|
#include "ops/reduce_all.h"
|
||||||
|
#include "ops/reduce_asum.h"
|
||||||
|
#include "ops/reduce_max.h"
|
||||||
|
#include "ops/reduce_mean.h"
|
||||||
|
#include "ops/reduce_min.h"
|
||||||
|
#include "ops/reduce_prod.h"
|
||||||
|
#include "ops/reduce_sum.h"
|
||||||
|
#include "ops/reduce_sum_square.h"
|
||||||
|
#include "ops/relu.h"
|
||||||
|
#include "ops/relu6.h"
|
||||||
|
#include "ops/resize.h"
|
||||||
|
#include "ops/resize_bilinear.h"
|
||||||
|
#include "ops/sigmoid.h"
|
||||||
|
#include "ops/tanh.h"
|
||||||
|
|
||||||
|
using mindspore::ops::kNameAbs;
|
||||||
|
using mindspore::ops::kNameAdd;
|
||||||
|
using mindspore::ops::kNameAdder;
|
||||||
|
using mindspore::ops::kNameArgMax;
|
||||||
|
using mindspore::ops::kNameArgMin;
|
||||||
|
using mindspore::ops::kNameAvgPool;
|
||||||
|
using mindspore::ops::kNameBatchNorm;
|
||||||
|
using mindspore::ops::kNameConv2D;
|
||||||
|
using mindspore::ops::kNameConv2DBackpropFilter;
|
||||||
|
using mindspore::ops::kNameConv2DBackpropInput;
|
||||||
|
using mindspore::ops::kNameConv2dTranspose;
|
||||||
|
using mindspore::ops::kNameDepthWiseConv2D;
|
||||||
|
using mindspore::ops::kNameDiv;
|
||||||
|
using mindspore::ops::kNameElu;
|
||||||
|
using mindspore::ops::kNameExp;
|
||||||
|
using mindspore::ops::kNameGeLU;
|
||||||
|
using mindspore::ops::kNameL2Normalize;
|
||||||
|
using mindspore::ops::kNameLayerNorm;
|
||||||
|
using mindspore::ops::kNameLeakyRelu;
|
||||||
|
using mindspore::ops::kNameMaxPool;
|
||||||
|
using mindspore::ops::kNameMul;
|
||||||
|
using mindspore::ops::kNamePad;
|
||||||
|
using mindspore::ops::kNamePReLU;
|
||||||
|
using mindspore::ops::kNameReduceAll;
|
||||||
|
using mindspore::ops::kNameReduceASum;
|
||||||
|
using mindspore::ops::kNameReduceMax;
|
||||||
|
using mindspore::ops::kNameReduceMean;
|
||||||
|
using mindspore::ops::kNameReduceMin;
|
||||||
|
using mindspore::ops::kNameReduceProd;
|
||||||
|
using mindspore::ops::kNameReduceSum;
|
||||||
|
using mindspore::ops::kNameReduceSumSquare;
|
||||||
|
using mindspore::ops::kNameReLU;
|
||||||
|
using mindspore::ops::kNameReLU6;
|
||||||
|
using mindspore::ops::kNameResizeBilinear;
|
||||||
|
using mindspore::ops::kNameScale;
|
||||||
|
using mindspore::ops::kNameSigmoid;
|
||||||
|
using mindspore::ops::kNameSub;
|
||||||
|
using mindspore::ops::kNameTanh;
|
||||||
|
using mindspore::ops::kNameTile;
|
||||||
|
using mindspore::ops::kNameTopK;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr auto kNameArgMaxWithValue = "ArgMaxWithValue";
|
||||||
|
constexpr auto kNameArgMinWithValue = "ArgMinWithValue";
|
||||||
|
constexpr auto kNameBatchMatMul = "BatchMatMul";
|
||||||
|
constexpr auto kNameGatherV2 = "GatherV2";
|
||||||
|
constexpr auto kNameTensorAdd = "TensorAdd";
|
||||||
|
std::map<std::string, mindspore::ActivationType> activation_map = {
|
||||||
|
{ops::kNameAbs, mindspore::ABS}, {ops::kNameElu, mindspore::ELU},
|
||||||
|
{ops::kNameGeLU, mindspore::GELU}, {ops::kNameLeakyRelu, mindspore::LEAKY_RELU},
|
||||||
|
{ops::kNameReLU, mindspore::RELU}, {ops::kNameReLU6, mindspore::RELU6},
|
||||||
|
{ops::kNameSigmoid, mindspore::SIGMOID}, {ops::kNameTanh, mindspore::TANH}};
|
||||||
|
|
||||||
|
std::map<std::string, mindspore::ReduceMode> reduce_map = {
|
||||||
|
{ops::kNameReduceAll, mindspore::Reduce_All}, {ops::kNameReduceASum, mindspore::Reduce_ASum},
|
||||||
|
{ops::kNameReduceMax, mindspore::Reduce_Max}, {ops::kNameReduceMean, mindspore::Reduce_Mean},
|
||||||
|
{ops::kNameReduceMin, mindspore::Reduce_Min}, {ops::kNameReduceProd, mindspore::Reduce_Prod},
|
||||||
|
{ops::kNameReduceSum, mindspore::Reduce_Sum}, {ops::kNameReduceSumSquare, mindspore::Reduce_Sum_Square}};
|
||||||
|
|
||||||
|
int AttrAdjust(const PrimitivePtr &prim, const std::string &name, const std::vector<int> &position) {
|
||||||
|
if (prim->GetAttr(name) == nullptr) {
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
auto value_ptr = prim->GetAttr(name);
|
||||||
|
if (utils::isa<ValueSequeuePtr>(value_ptr)) {
|
||||||
|
if (value_ptr->cast<ValueSequeuePtr>()->value().front()->type()->number_type() != kNumberTypeInt64) {
|
||||||
|
MS_LOG(ERROR) << "the func is to adjust attr which is array, please check the attr.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
} else if (value_ptr->type()->number_type() != kNumberTypeInt64) {
|
||||||
|
MS_LOG(ERROR) << "the func is to adjust attr which is array, please check the attr.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto origin_value = CastToInt(prim->GetAttr(name));
|
||||||
|
std::vector<int64_t> new_value;
|
||||||
|
if (name == ops::kKernelSize && origin_value.size() == 1) {
|
||||||
|
new_value.push_back(origin_value[0]);
|
||||||
|
new_value.push_back(origin_value[0]);
|
||||||
|
} else {
|
||||||
|
for (auto index : position) {
|
||||||
|
if (index >= static_cast<int>(origin_value.size())) {
|
||||||
|
MS_LOG(ERROR) << "index is out of range.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
new_value.push_back(static_cast<int64_t>(origin_value[index]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prim->AddAttr(name, MakeValue(new_value));
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
int MoveAttrMapCommon(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<T>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapActivation(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::Activation>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
auto iter = activation_map.find(src_prim->name());
|
||||||
|
if (iter == activation_map.end()) {
|
||||||
|
MS_LOG(ERROR) << "activation mode is unsupport.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
dst_prim->set_activation_type(iter->second);
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapReduce(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::ReduceFusion>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
auto iter = reduce_map.find(src_prim->name());
|
||||||
|
if (iter == reduce_map.end()) {
|
||||||
|
MS_LOG(ERROR) << "reduce mode is unsupport.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
dst_prim->set_mode(iter->second);
|
||||||
|
dst_prim->set_coeff(1.0f);
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapConv2D(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::Conv2DFusion>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
auto status = AttrAdjust(dst_prim, ops::kStride, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust stride failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AttrAdjust(dst_prim, ops::kDilation, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust dilation failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AttrAdjust(dst_prim, ops::kKernelSize, {0, 1});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust kernel size failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
int64_t group = 1;
|
||||||
|
if (dst_prim->GetAttr(ops::kGroup) != nullptr) {
|
||||||
|
group = dst_prim->get_group();
|
||||||
|
}
|
||||||
|
if (group > 1) {
|
||||||
|
dst_prim->AddAttr(ops::kIsDepthWise, MakeValue<bool>(true));
|
||||||
|
}
|
||||||
|
dst_prim->set_group(group);
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrPool(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
PrimitivePtr dst_prim;
|
||||||
|
if (src_prim->name() == kNameAvgPool) {
|
||||||
|
dst_prim = std::make_shared<ops::AvgPoolFusion>();
|
||||||
|
} else if (src_prim->name() == kNameMaxPool) {
|
||||||
|
dst_prim = std::make_shared<ops::MaxPoolFusion>();
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "unsupport pooling type.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
auto status = AttrAdjust(dst_prim, ops::kKernelSize, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust ksize failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AttrAdjust(dst_prim, ops::kStrides, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust strides failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
if (dst_prim->GetAttr(ops::kPadding) != nullptr) {
|
||||||
|
dst_prim->AddAttr(ops::kPadMode, dst_prim->GetAttr(ops::kPadding));
|
||||||
|
}
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapAdder(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::AdderFusion>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
auto status = AttrAdjust(dst_prim, ops::kStride, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust stride failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AttrAdjust(dst_prim, ops::kDilation, {2, 3});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust dilation failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AttrAdjust(dst_prim, ops::kKernelSize, {0, 1});
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust kernel size failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapLayerNorm(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::LayerNormFusion>();
|
||||||
|
MS_ASSERT(dst_prim != nullptr);
|
||||||
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
dst_prim->set_elementwise_affine(true);
|
||||||
|
if (dst_prim->GetAttr(ops::kEpsilon) == nullptr) {
|
||||||
|
dst_prim->set_epsilon(1e-7);
|
||||||
|
}
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MoveAttrMapResize(const ValueNodePtr &value_node) {
|
||||||
|
MS_ASSERT(value_node != nullptr);
|
||||||
|
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (src_prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value node is invalid.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto dst_prim = std::make_shared<ops::Resize>();
|
||||||
|
auto size = GetValue<std::vector<int64_t>>(src_prim->GetAttr(ops::kSize));
|
||||||
|
dst_prim->set_new_height(size[0]);
|
||||||
|
dst_prim->set_new_width(size[1]);
|
||||||
|
if (dst_prim->GetAttr(ops::kAlignCorners) != nullptr && GetValue<bool>(dst_prim->GetAttr(ops::kAlignCorners))) {
|
||||||
|
dst_prim->set_coordinate_transform_mode(mindspore::ALIGN_CORNERS);
|
||||||
|
}
|
||||||
|
if (src_prim->name() == kNameResizeBilinear) {
|
||||||
|
dst_prim->set_method(ResizeMethod::LINEAR);
|
||||||
|
} else if (src_prim->name() == "ResizeNearestNeighbor") {
|
||||||
|
dst_prim->set_method(ResizeMethod::NEAREST);
|
||||||
|
}
|
||||||
|
value_node->set_value(dst_prim);
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool PrimitiveAdjustPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
if (this->fmk_type_ != lite::converter::FmkType_MS) {
|
||||||
|
MS_LOG(INFO) << "The framework type of model should be mindir.";
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
MS_ASSERT(graph != nullptr);
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
int status = lite::RET_OK;
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_ASSERT(cnode->size() > 0);
|
||||||
|
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
|
if (value_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "cnode first input is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
MS_ASSERT(prim != nullptr);
|
||||||
|
auto name = prim->name();
|
||||||
|
auto adjust_func = PrimitiveAdjustRegistry::GetInstance()->GetPrimitiveCreator(name);
|
||||||
|
if (adjust_func == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "dont't need to adjust.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
status = adjust_func(value_node);
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "convert primitive failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameAbs, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameAdd, MoveAttrMapCommon<ops::AddFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameAdder, MoveAttrMapAdder)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameArgMax, MoveAttrMapCommon<ops::ArgMaxFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameArgMaxWithValue, MoveAttrMapCommon<ops::ArgMaxFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameArgMin, MoveAttrMapCommon<ops::ArgMinFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameArgMinWithValue, MoveAttrMapCommon<ops::ArgMinFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameAvgPool, MoveAttrPool)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon<ops::MatMul>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameDepthWiseConv2D, MoveAttrMapConv2D)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameGatherV2, MoveAttrMapCommon<ops::Gather>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameL2Normalize, MoveAttrMapCommon<ops::L2NormalizeFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameLayerNorm, MoveAttrMapLayerNorm)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameLeakyRelu, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameMaxPool, MoveAttrPool)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameMul, MoveAttrMapCommon<ops::MulFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNamePad, MoveAttrMapCommon<ops::PadFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNamePReLU, MoveAttrMapCommon<ops::PReLUFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceAll, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceASum, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceMax, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceMean, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceMin, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceProd, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceSum, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReduceSumSquare, MoveAttrMapReduce)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReLU, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameReLU6, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameResizeBilinear, MoveAttrMapResize)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameScale, MoveAttrMapCommon<ops::ScaleFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameSigmoid, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameSub, MoveAttrMapCommon<ops::SubFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameTanh, MoveAttrMapActivation)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameTensorAdd, MoveAttrMapCommon<ops::AddFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameTile, MoveAttrMapCommon<ops::TileFusion>)
|
||||||
|
REGIST_PRIMITIVE_ADJUST(kNameTopK, MoveAttrMapCommon<ops::TopKFusion>)
|
||||||
|
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -31,6 +31,7 @@ from ... import context
|
||||||
|
|
||||||
env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
|
env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.BiasAdd)
|
@bprop_getters.register(P.BiasAdd)
|
||||||
def get_bprop_bias_add(self):
|
def get_bprop_bias_add(self):
|
||||||
"""Grad definition for `BiasAdd` operation."""
|
"""Grad definition for `BiasAdd` operation."""
|
||||||
|
@ -313,7 +314,7 @@ def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
|
||||||
|
|
||||||
for h in range(h_output):
|
for h in range(h_output):
|
||||||
for w in range(w_output):
|
for w in range(w_output):
|
||||||
curr_input = assist_input_matrix[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize]
|
curr_input = assist_input_matrix[h * h_stride: h * h_stride + h_ksize, w * w_stride: w * w_stride + w_ksize]
|
||||||
curr_sum = np.sum(curr_input)
|
curr_sum = np.sum(curr_input)
|
||||||
if curr_sum > 0:
|
if curr_sum > 0:
|
||||||
output[:, :, h, w] = 1. / curr_sum
|
output[:, :, h, w] = 1. / curr_sum
|
||||||
|
@ -681,10 +682,10 @@ def get_bprop_tanh_grad(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Gelu)
|
@bprop_getters.register(P.GeLU)
|
||||||
def get_bprop_gelu(self):
|
def get_bprop_gelu(self):
|
||||||
"""Grad definition for `Gelu` operation."""
|
"""Grad definition for `GeLU` operation."""
|
||||||
input_grad = G.GeluGrad()
|
input_grad = G.GeLUGrad()
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
dx = input_grad(dout, x, out)
|
dx = input_grad(dout, x, out)
|
||||||
|
@ -693,10 +694,34 @@ def get_bprop_gelu(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.FastGelu)
|
@bprop_getters.register(P.Gelu)
|
||||||
|
def get_bprop_gelu_2(self):
|
||||||
|
"""Grad definition for `GeLU` operation."""
|
||||||
|
input_grad = G.GeLUGrad()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = input_grad(dout, x, out)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.FastGeLU)
|
||||||
def get_bprop_fast_gelu(self):
|
def get_bprop_fast_gelu(self):
|
||||||
"""Grad definition for `FastGelu` operation."""
|
"""Grad definition for `FastGeLU` operation."""
|
||||||
input_grad = G.FastGeluGrad()
|
input_grad = G.FastGeLUGrad()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = input_grad(dout, x)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.FastGelu)
|
||||||
|
def get_bprop_fast_gelu_2(self):
|
||||||
|
"""Grad definition for `FastGeLU` operation."""
|
||||||
|
input_grad = G.FastGeLUGrad()
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
dx = input_grad(dout, x)
|
dx = input_grad(dout, x)
|
||||||
|
@ -713,6 +738,7 @@ def get_bprop_fused_batch_norm(self):
|
||||||
if self.target == "CPU":
|
if self.target == "CPU":
|
||||||
input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
|
input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
|
||||||
target_cpu = True
|
target_cpu = True
|
||||||
|
|
||||||
def bprop(x, scale, b, mean, variance, out, dout):
|
def bprop(x, scale, b, mean, variance, out, dout):
|
||||||
saved_mean = out[3]
|
saved_mean = out[3]
|
||||||
saved_variance = out[4]
|
saved_variance = out[4]
|
||||||
|
@ -897,6 +923,7 @@ def _range_op(start, limit, delta, dtype):
|
||||||
output_tensor = Tensor(list(range(start, limit, delta)), dtype)
|
output_tensor = Tensor(list(range(start, limit, delta)), dtype)
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _get_1d_shape(in_shape):
|
def _get_1d_shape(in_shape):
|
||||||
"""helper function for Grad TopK"""
|
"""helper function for Grad TopK"""
|
||||||
|
@ -905,6 +932,7 @@ def _get_1d_shape(in_shape):
|
||||||
out_shape *= i
|
out_shape *= i
|
||||||
return (out_shape,)
|
return (out_shape,)
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.TopK)
|
@bprop_getters.register(P.TopK)
|
||||||
def get_bprop_top_kv2(self):
|
def get_bprop_top_kv2(self):
|
||||||
"""Grad definition for `TopK` operation."""
|
"""Grad definition for `TopK` operation."""
|
||||||
|
@ -915,7 +943,6 @@ def get_bprop_top_kv2(self):
|
||||||
dtype = P.DType()
|
dtype = P.DType()
|
||||||
|
|
||||||
def bprop(input_x, k, out, dout):
|
def bprop(input_x, k, out, dout):
|
||||||
|
|
||||||
in_shape = shape_op(input_x)
|
in_shape = shape_op(input_x)
|
||||||
in_lastdim = in_shape[-1]
|
in_lastdim = in_shape[-1]
|
||||||
|
|
||||||
|
@ -976,6 +1003,7 @@ def get_bprop_rnnt_loss(self):
|
||||||
def bprop(acts, labels, act_lens, label_lens, out, dout):
|
def bprop(acts, labels, act_lens, label_lens, out, dout):
|
||||||
grad = out[1]
|
grad = out[1]
|
||||||
return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
|
return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -1064,6 +1092,7 @@ def get_bprop_dynamic_rnn(self):
|
||||||
dh_prev = expand_dims(dh_prev, 0)
|
dh_prev = expand_dims(dh_prev, 0)
|
||||||
dc_prev = expand_dims(dc_prev, 0)
|
dc_prev = expand_dims(dc_prev, 0)
|
||||||
return dx, dw, db, (0), dh_prev, dc_prev
|
return dx, dw, db, (0), dh_prev, dc_prev
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -1082,6 +1111,7 @@ def get_bprop_dynamic_gru_v2(self):
|
||||||
out_h, dy, dout_h[-1], update,
|
out_h, dy, dout_h[-1], update,
|
||||||
reset, new, hidden_new, None, None)
|
reset, new, hidden_new, None, None)
|
||||||
return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
|
return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -1181,6 +1211,7 @@ def get_bprop_binary_cross_entropy(self):
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.KLDivLoss)
|
@bprop_getters.register(P.KLDivLoss)
|
||||||
def get_bprop_kl_div_loss(self):
|
def get_bprop_kl_div_loss(self):
|
||||||
"""Grad definition for `KLDivLoss` operation."""
|
"""Grad definition for `KLDivLoss` operation."""
|
||||||
|
@ -1239,6 +1270,7 @@ def get_bprop_basic_lstm_cell(self):
|
||||||
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
|
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
|
||||||
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
||||||
return dxt, dht, dct_1, dw, db
|
return dxt, dht, dct_1, dw, db
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""FastGelu op"""
|
"""FastGeLU op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
fast_gelu_op_info = TBERegOp("FastGelu") \
|
fast_gelu_op_info = TBERegOp("FastGeLU") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("fast_gelu.so") \
|
.binfile_name("fast_gelu.so") \
|
||||||
|
@ -33,5 +33,5 @@ fast_gelu_op_info = TBERegOp("FastGelu") \
|
||||||
|
|
||||||
@op_info_register(fast_gelu_op_info)
|
@op_info_register(fast_gelu_op_info)
|
||||||
def _fast_gelu_tbe():
|
def _fast_gelu_tbe():
|
||||||
"""FastGelu TBE register"""
|
"""FastGeLU TBE register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""FastGeluGrad op"""
|
"""FastGeLUGrad op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
|
fast_gelu_grad_op_info = TBERegOp("FastGeLUGrad") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("fast_gelu_grad.so") \
|
.binfile_name("fast_gelu_grad.so") \
|
||||||
|
@ -37,5 +37,5 @@ fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
|
||||||
|
|
||||||
@op_info_register(fast_gelu_grad_op_info)
|
@op_info_register(fast_gelu_grad_op_info)
|
||||||
def _fast_gelu_grad_tbe():
|
def _fast_gelu_grad_tbe():
|
||||||
"""FastGeluGrad TBE register"""
|
"""FastGeLUGrad TBE register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Gelu op"""
|
"""GeLU op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
gelu_op_info = TBERegOp("Gelu") \
|
gelu_op_info = TBERegOp("GeLU") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("gelu.so") \
|
.binfile_name("gelu.so") \
|
||||||
|
@ -33,5 +33,5 @@ gelu_op_info = TBERegOp("Gelu") \
|
||||||
|
|
||||||
@op_info_register(gelu_op_info)
|
@op_info_register(gelu_op_info)
|
||||||
def _gelu_tbe():
|
def _gelu_tbe():
|
||||||
"""Gelu TBE register"""
|
"""GeLU TBE register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -13,10 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""GeluGrad op"""
|
"""GeLUGrad op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
gelu_grad_op_info = TBERegOp("GeluGrad") \
|
gelu_grad_op_info = TBERegOp("GeLUGrad") \
|
||||||
.fusion_type("ELEMWISE") \
|
.fusion_type("ELEMWISE") \
|
||||||
.async_flag(False) \
|
.async_flag(False) \
|
||||||
.binfile_name("gelu_grad.so") \
|
.binfile_name("gelu_grad.so") \
|
||||||
|
@ -38,5 +38,5 @@ gelu_grad_op_info = TBERegOp("GeluGrad") \
|
||||||
|
|
||||||
@op_info_register(gelu_grad_op_info)
|
@op_info_register(gelu_grad_op_info)
|
||||||
def _gelu_grad_tbe():
|
def _gelu_grad_tbe():
|
||||||
"""GeluGrad TBE register"""
|
"""GeLUGrad TBE register"""
|
||||||
return
|
return
|
||||||
|
|
|
@ -43,7 +43,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
|
||||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||||
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
|
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
|
||||||
|
|
||||||
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
|
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
|
||||||
|
BitwiseAnd, BitwiseOr,
|
||||||
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
|
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
|
||||||
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
|
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
|
||||||
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
|
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
|
||||||
|
@ -65,7 +66,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
|
||||||
DepthwiseConv2dNative,
|
DepthwiseConv2dNative,
|
||||||
DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
|
DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
|
||||||
FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
|
FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
|
||||||
Gelu, FastGelu, Elu,
|
GeLU, Gelu, FastGeLU, FastGelu, Elu,
|
||||||
|
|
||||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
|
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
|
||||||
LogSoftmax,
|
LogSoftmax,
|
||||||
MaxPool, DataFormatDimMap,
|
MaxPool, DataFormatDimMap,
|
||||||
|
@ -93,7 +95,8 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
|
||||||
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle,
|
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle,
|
||||||
ProdForceSeA)
|
ProdForceSeA)
|
||||||
from .sparse_ops import SparseToDense
|
from .sparse_ops import SparseToDense
|
||||||
from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
|
from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx,
|
||||||
|
SubAndFilter,
|
||||||
MapUniform, DynamicAssign, PadAndShift)
|
MapUniform, DynamicAssign, PadAndShift)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -174,7 +177,9 @@ __all__ = [
|
||||||
'Unstack',
|
'Unstack',
|
||||||
'Tile',
|
'Tile',
|
||||||
'BiasAdd',
|
'BiasAdd',
|
||||||
|
'GeLU',
|
||||||
'Gelu',
|
'Gelu',
|
||||||
|
'FastGeLU',
|
||||||
'FastGelu',
|
'FastGelu',
|
||||||
'Minimum',
|
'Minimum',
|
||||||
'Maximum',
|
'Maximum',
|
||||||
|
|
|
@ -790,12 +790,12 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer):
|
||||||
return (batch_mean, batch_variance)
|
return (batch_mean, batch_variance)
|
||||||
|
|
||||||
|
|
||||||
class GeluGrad(PrimitiveWithInfer):
|
class GeLUGrad(PrimitiveWithInfer):
|
||||||
"""Gradients of Gelu operation."""
|
"""Gradients of GeLU operation."""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize GeluGrad"""
|
"""Initialize GeLUGrad"""
|
||||||
|
|
||||||
def infer_shape(self, y_backprop_shape, x_shape, y_shape):
|
def infer_shape(self, y_backprop_shape, x_shape, y_shape):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
@ -808,12 +808,12 @@ class GeluGrad(PrimitiveWithInfer):
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class FastGeluGrad(PrimitiveWithInfer):
|
class FastGeLUGrad(PrimitiveWithInfer):
|
||||||
"""Gradients of FastGelu operation."""
|
"""Gradients of FastGeLU operation."""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init FastGeluGrad"""
|
"""init FastGeLUGrad"""
|
||||||
|
|
||||||
def infer_shape(self, y_backprop_shape, x_shape):
|
def infer_shape(self, y_backprop_shape, x_shape):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
|
@ -805,6 +805,7 @@ class Gather(PrimitiveWithCheck):
|
||||||
[ 4. 54.]
|
[ 4. 54.]
|
||||||
[ 2. 55.]]
|
[ 2. 55.]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize index_select"""
|
"""Initialize index_select"""
|
||||||
|
@ -826,7 +827,8 @@ class GatherV2(PrimitiveWithCheck):
|
||||||
Same as operator Gather. GatherV2 will be deprecated in the future.
|
Same as operator Gather. GatherV2 will be deprecated in the future.
|
||||||
Please use Gather instead.
|
Please use Gather instead.
|
||||||
"""
|
"""
|
||||||
#deprecate_new_name = "Gather"
|
|
||||||
|
# deprecate_new_name = "Gather"
|
||||||
|
|
||||||
@deprecated("1.1", "Gather", True)
|
@deprecated("1.1", "Gather", True)
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
@ -2270,6 +2272,29 @@ def _get_stack_shape(x_shape, x_type, axis, prim_name):
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
|
class Pack(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Same as operator Stack. Pack will be deprecated in the future.
|
||||||
|
Please use Stack instead.
|
||||||
|
"""
|
||||||
|
@deprecated("1.1", "Stack", True)
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, axis=0):
|
||||||
|
"""Initialize Pack"""
|
||||||
|
validator.check_value_type("axis", axis, [int], self.name)
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
|
def __infer__(self, value):
|
||||||
|
x_shape = value['shape']
|
||||||
|
x_type = value['dtype']
|
||||||
|
self.add_prim_attr('num', len(x_shape))
|
||||||
|
all_shape = _get_stack_shape(x_shape, x_type, self.axis, self.name)
|
||||||
|
out = {'shape': all_shape,
|
||||||
|
'dtype': x_type[0],
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Stack(PrimitiveWithInfer):
|
class Stack(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Stacks a list of tensors in specified axis.
|
Stacks a list of tensors in specified axis.
|
||||||
|
@ -2324,26 +2349,45 @@ class Stack(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def Pack(axis=0):
|
|
||||||
|
class Unpack(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Packs a list of tensors in specified axis.
|
Same as operator Unstack. Unpack will be deprecated in the future.
|
||||||
|
Please use Unstack instead.
|
||||||
The usage of Pack is deprecated. Please use Stack.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logger.warning("WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.")
|
@deprecated("1.1", "Unstack", True)
|
||||||
return Stack(axis)
|
@prim_attr_register
|
||||||
|
def __init__(self, axis=0):
|
||||||
|
"""Initialize Unpack"""
|
||||||
|
validator.check_value_type("axis", axis, [int], self.name)
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
|
def __infer__(self, x):
|
||||||
def Unpack(axis=0):
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
||||||
"""
|
x_shape = list(x['shape'])
|
||||||
Unpacks tensor in specified axis.
|
dim = len(x_shape)
|
||||||
|
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||||
The usage of Unpack is deprecated. Please use Unstack.
|
if self.axis < 0:
|
||||||
|
self.axis = self.axis + dim
|
||||||
"""
|
output_num = x_shape[self.axis]
|
||||||
logger.warning("WARN_DEPRECATED: The usage of Unpack is deprecated. Please use Unstack.")
|
validator.check_value_type("num", output_num, [int], self.name)
|
||||||
return Unstack(axis)
|
validator.check_positive_int(output_num, "output_num", self.name)
|
||||||
|
self.add_prim_attr('num', output_num)
|
||||||
|
output_valid_check = x_shape[self.axis] - output_num
|
||||||
|
validator.check_int(output_valid_check, 0, Rel.EQ,
|
||||||
|
"The dimension which to unstack divides output_num", self.name)
|
||||||
|
out_shapes = []
|
||||||
|
out_dtypes = []
|
||||||
|
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
|
||||||
|
for _ in range(output_num):
|
||||||
|
out_shapes.append(tuple(out_shape))
|
||||||
|
out_dtypes.append(x['dtype'])
|
||||||
|
out_shapes = tuple(out_shapes)
|
||||||
|
out_dtypes = tuple(out_dtypes)
|
||||||
|
out = {'shape': out_shapes,
|
||||||
|
'dtype': out_dtypes,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Unstack(PrimitiveWithInfer):
|
class Unstack(PrimitiveWithInfer):
|
||||||
|
|
|
@ -20,12 +20,14 @@ import operator
|
||||||
from functools import reduce, partial
|
from functools import reduce, partial
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore._checkparam import _check_3d_int_or_tuple
|
from mindspore._checkparam import _check_3d_int_or_tuple
|
||||||
|
from mindspore import log as logger
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ... import context
|
from ... import context
|
||||||
from .. import signature as sig
|
from .. import signature as sig
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
from ...common._decorator import deprecated
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
||||||
|
|
||||||
|
|
||||||
|
@ -3245,6 +3247,25 @@ class OneHot(PrimitiveWithInfer):
|
||||||
|
|
||||||
|
|
||||||
class Gelu(PrimitiveWithInfer):
|
class Gelu(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Same as operator GeLU. Gelu will be deprecated in the future.
|
||||||
|
Please use GeLU instead.
|
||||||
|
"""
|
||||||
|
@deprecated("1.1", "GeLU", True)
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize Gelu"""
|
||||||
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, input_x):
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
def infer_dtype(self, input_x):
|
||||||
|
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
|
class GeLU(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Gaussian Error Linear Units activation function.
|
Gaussian Error Linear Units activation function.
|
||||||
|
|
||||||
|
@ -3252,7 +3273,7 @@ class Gelu(PrimitiveWithInfer):
|
||||||
And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
|
And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
|
||||||
<https://arxiv.org/abs/1810.04805>`_.
|
<https://arxiv.org/abs/1810.04805>`_.
|
||||||
|
|
||||||
Gelu is defined as follows:
|
GeLU is defined as follows:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
|
\text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
|
||||||
|
@ -3260,7 +3281,7 @@ class Gelu(PrimitiveWithInfer):
|
||||||
where :math:`erf` is the "Gauss error function" .
|
where :math:`erf` is the "Gauss error function" .
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - Input to compute the Gelu with data type of float16 or float32.
|
- **input_x** (Tensor) - Input to compute the GeLU with data type of float16 or float32.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, with the same type and shape as input.
|
Tensor, with the same type and shape as input.
|
||||||
|
@ -3273,7 +3294,7 @@ class Gelu(PrimitiveWithInfer):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
|
>>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
|
||||||
>>> gelu = ops.Gelu()
|
>>> gelu = ops.GeLU()
|
||||||
>>> result = gelu(tensor)
|
>>> result = gelu(tensor)
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
[0.841192 1.9545976 2.9963627]
|
[0.841192 1.9545976 2.9963627]
|
||||||
|
@ -3293,10 +3314,29 @@ class Gelu(PrimitiveWithInfer):
|
||||||
|
|
||||||
|
|
||||||
class FastGelu(PrimitiveWithInfer):
|
class FastGelu(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Same as operator FastGeLU. FastGelu will be deprecated in the future.
|
||||||
|
Please use FastGeLU instead.
|
||||||
|
"""
|
||||||
|
@deprecated("1.1", "FastGeLU", True)
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init FastGelu"""
|
||||||
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, input_x):
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
def infer_dtype(self, input_x):
|
||||||
|
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
|
class FastGeLU(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Fast Gaussian Error Linear Units activation function.
|
Fast Gaussian Error Linear Units activation function.
|
||||||
|
|
||||||
FastGelu is defined as follows:
|
FastGeLU is defined as follows:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),
|
\text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),
|
||||||
|
@ -3304,7 +3344,7 @@ class FastGelu(PrimitiveWithInfer):
|
||||||
where :math:`x` is the element of the input.
|
where :math:`x` is the element of the input.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - Input to compute the FastGelu with data type of float16 or float32.
|
- **input_x** (Tensor) - Input to compute the FastGeLU with data type of float16 or float32.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, with the same type and shape as input.
|
Tensor, with the same type and shape as input.
|
||||||
|
@ -3317,7 +3357,7 @@ class FastGelu(PrimitiveWithInfer):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> tensor = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
>>> tensor = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
||||||
>>> fast_gelu = P.FastGelu()
|
>>> fast_gelu = P.FastGeLU()
|
||||||
>>> output = fast_gelu(tensor)
|
>>> output = fast_gelu(tensor)
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
|
[[-1.5420423e-01 3.9955849e+00 -9.7664278e-06]
|
||||||
|
|
|
@ -61,7 +61,7 @@ class MEGeluLargeIn(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MEGeluLargeIn, self).__init__()
|
super(MEGeluLargeIn, self).__init__()
|
||||||
self.matmul = P.MatMul()
|
self.matmul = P.MatMul()
|
||||||
self.fast_gelu = P.Gelu()
|
self.fast_gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x1, x2):
|
def construct(self, x1, x2):
|
||||||
x = self.matmul(x1, x2)
|
x = self.matmul(x1, x2)
|
||||||
|
|
|
@ -61,7 +61,7 @@ class MEGeluLargeIn(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MEGeluLargeIn, self).__init__()
|
super(MEGeluLargeIn, self).__init__()
|
||||||
self.matmul = P.MatMul()
|
self.matmul = P.MatMul()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x1, x2):
|
def construct(self, x1, x2):
|
||||||
x = self.matmul(x1, x2)
|
x = self.matmul(x1, x2)
|
||||||
|
|
|
@ -28,7 +28,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
class GeluNet(nn.Cell):
|
class GeluNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluNet, self).__init__()
|
super(GeluNet, self).__init__()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
|
|
@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
class GeluNet(nn.Cell):
|
class GeluNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluNet, self).__init__()
|
super(GeluNet, self).__init__()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
|
|
@ -28,7 +28,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
class GeluNet(nn.Cell):
|
class GeluNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluNet, self).__init__()
|
super(GeluNet, self).__init__()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
|
|
@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
class GeluNet(nn.Cell):
|
class GeluNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluNet, self).__init__()
|
super(GeluNet, self).__init__()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
|
|
@ -25,7 +25,7 @@ import mindspore.ops.operations._grad_ops as G
|
||||||
class GeluNet(Cell):
|
class GeluNet(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluNet, self).__init__()
|
super(GeluNet, self).__init__()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
@ -34,7 +34,7 @@ class GeluNet(Cell):
|
||||||
class GeluGradNet(Cell):
|
class GeluGradNet(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GeluGradNet, self).__init__()
|
super(GeluGradNet, self).__init__()
|
||||||
self.gelu_grad = G.GeluGrad()
|
self.gelu_grad = G.GeLUGrad()
|
||||||
|
|
||||||
def construct(self, dy, x, y):
|
def construct(self, dy, x, y):
|
||||||
return self.gelu_grad(dy, x, y)
|
return self.gelu_grad(dy, x, y)
|
||||||
|
|
|
@ -26,18 +26,18 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
|
||||||
class GeluInfo;
|
class GeLUInfo;
|
||||||
using GeluInfoPtr = std::shared_ptr<GeluInfo>;
|
using GeLUInfoPtr = std::shared_ptr<GeLUInfo>;
|
||||||
GeluInfoPtr gelu;
|
GeLUInfoPtr gelu;
|
||||||
|
|
||||||
class TestGeluInfo : public UT::Common {
|
class TestGeLUInfo : public UT::Common {
|
||||||
public:
|
public:
|
||||||
TestGeluInfo() {}
|
TestGeLUInfo() {}
|
||||||
void SetUp();
|
void SetUp();
|
||||||
void TearDown() {}
|
void TearDown() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
void TestGeluInfo::SetUp() {
|
void TestGeLUInfo::SetUp() {
|
||||||
RankList dev_list;
|
RankList dev_list;
|
||||||
|
|
||||||
for (int32_t i = 0; i < 130; i++) {
|
for (int32_t i = 0; i < 130; i++) {
|
||||||
|
@ -59,10 +59,10 @@ void TestGeluInfo::SetUp() {
|
||||||
Shapes inputs_shape = {{2, 4, 8, 16}};
|
Shapes inputs_shape = {{2, 4, 8, 16}};
|
||||||
Shapes outputs_shape = {{2, 4, 8, 16}};
|
Shapes outputs_shape = {{2, 4, 8, 16}};
|
||||||
|
|
||||||
gelu = std::make_shared<GeluInfo>("gelu_info", inputs_shape, outputs_shape, attr);
|
gelu = std::make_shared<GeLUInfo>("gelu_info", inputs_shape, outputs_shape, attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, InferDevMatrixShape1) {
|
TEST_F(TestGeLUInfo, InferDevMatrixShape1) {
|
||||||
Strategys inputs = {{2, 4, 1, 16}};
|
Strategys inputs = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ TEST_F(TestGeluInfo, InferDevMatrixShape1) {
|
||||||
ASSERT_EQ(dev_matrix_shape, expect);
|
ASSERT_EQ(dev_matrix_shape, expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, InferSliceShape1) {
|
TEST_F(TestGeLUInfo, InferSliceShape1) {
|
||||||
Strategys str = {{2, 4, 1, 16}};
|
Strategys str = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, str);
|
StrategyPtr strategy = NewStrategy(0, str);
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ TEST_F(TestGeluInfo, InferSliceShape1) {
|
||||||
ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
|
ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, GetTensorLayout1) {
|
TEST_F(TestGeLUInfo, GetTensorLayout1) {
|
||||||
Strategys str = {{2, 4, 1, 16}};
|
Strategys str = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, str);
|
StrategyPtr strategy = NewStrategy(0, str);
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ TEST_F(TestGeluInfo, GetTensorLayout1) {
|
||||||
ASSERT_EQ(output_tensor_map.array(), output_expect);
|
ASSERT_EQ(output_tensor_map.array(), output_expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, GetForwardOp1) {
|
TEST_F(TestGeLUInfo, GetForwardOp1) {
|
||||||
Strategys inputs = {{2, 4, 1, 16}};
|
Strategys inputs = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ TEST_F(TestGeluInfo, GetForwardOp1) {
|
||||||
ASSERT_EQ(size, 0);
|
ASSERT_EQ(size, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, GetMirrorOPs1) {
|
TEST_F(TestGeLUInfo, GetMirrorOPs1) {
|
||||||
Strategys inputs = {{2, 4, 1, 16}};
|
Strategys inputs = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ TEST_F(TestGeluInfo, GetMirrorOPs1) {
|
||||||
ASSERT_EQ(size, 0);
|
ASSERT_EQ(size, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, CheckStrategy1) {
|
TEST_F(TestGeLUInfo, CheckStrategy1) {
|
||||||
// Success: {{2,4,1,16}}
|
// Success: {{2,4,1,16}}
|
||||||
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
|
Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
@ -147,7 +147,7 @@ TEST_F(TestGeluInfo, CheckStrategy1) {
|
||||||
ASSERT_EQ(ret, FAILED);
|
ASSERT_EQ(ret, FAILED);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, CheckStrategy2) {
|
TEST_F(TestGeLUInfo, CheckStrategy2) {
|
||||||
// Success: {{2,4,1,16}}
|
// Success: {{2,4,1,16}}
|
||||||
Strategys inputs = {{2, 4, 8}};
|
Strategys inputs = {{2, 4, 8}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
@ -156,7 +156,7 @@ TEST_F(TestGeluInfo, CheckStrategy2) {
|
||||||
ASSERT_EQ(ret, FAILED);
|
ASSERT_EQ(ret, FAILED);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestGeluInfo, CheckStrategy3) {
|
TEST_F(TestGeLUInfo, CheckStrategy3) {
|
||||||
// Success: {{2,4,1,16}}
|
// Success: {{2,4,1,16}}
|
||||||
Strategys inputs = {{2, 4, 1, 16}};
|
Strategys inputs = {{2, 4, 1, 16}};
|
||||||
StrategyPtr strategy = NewStrategy(0, inputs);
|
StrategyPtr strategy = NewStrategy(0, inputs);
|
||||||
|
|
|
@ -1632,12 +1632,12 @@ test_case_nn_ops = [
|
||||||
'block': G.BiasAddGrad(),
|
'block': G.BiasAddGrad(),
|
||||||
'desc_inputs': [[1, 3, 3, 3]],
|
'desc_inputs': [[1, 3, 3, 3]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
('Gelu', {
|
('GeLU', {
|
||||||
'block': P.Gelu(),
|
'block': P.GeLU(),
|
||||||
'desc_inputs': [[1, 3, 4, 4]],
|
'desc_inputs': [[1, 3, 4, 4]],
|
||||||
'desc_bprop': [[1, 3, 4, 4]]}),
|
'desc_bprop': [[1, 3, 4, 4]]}),
|
||||||
('GeluGrad', {
|
('GeLUGrad', {
|
||||||
'block': G.GeluGrad(),
|
'block': G.GeLUGrad(),
|
||||||
'desc_inputs': [[2, 2], [2, 2], [2, 2]],
|
'desc_inputs': [[2, 2], [2, 2], [2, 2]],
|
||||||
'desc_bprop': [[2, 2]],
|
'desc_bprop': [[2, 2]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
|
|
|
@ -51,7 +51,7 @@ def test_softmax_cross_entropy_loss_auto_parallel():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul = P.MatMul(transpose_b=True)
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.matmul(x, y)
|
out = self.matmul(x, y)
|
||||||
|
|
|
@ -61,7 +61,7 @@ def test_virtual_dataset_3_input():
|
||||||
self.virtual_dataset = _VirtualDataset()
|
self.virtual_dataset = _VirtualDataset()
|
||||||
self.matmul1 = P.MatMul()
|
self.matmul1 = P.MatMul()
|
||||||
self.matmul2 = P.MatMul()
|
self.matmul2 = P.MatMul()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
self.bn1 = bn_with_initialize(2048)
|
self.bn1 = bn_with_initialize(2048)
|
||||||
|
|
||||||
def construct(self, x, y, b):
|
def construct(self, x, y, b):
|
||||||
|
|
|
@ -27,7 +27,7 @@ class VirtualDatasetNet(nn.Cell):
|
||||||
self.virtual_dataset = _VirtualDataset()
|
self.virtual_dataset = _VirtualDataset()
|
||||||
self.matmul1 = P.MatMul()
|
self.matmul1 = P.MatMul()
|
||||||
self.matmul2 = P.MatMul()
|
self.matmul2 = P.MatMul()
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x, y, z):
|
def construct(self, x, y, z):
|
||||||
x, y, z = self.virtual_dataset(x, y, z)
|
x, y, z = self.virtual_dataset(x, y, z)
|
||||||
|
|
|
@ -163,7 +163,7 @@ def test_activations():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul1 = P.MatMul().shard(strategy1)
|
self.matmul1 = P.MatMul().shard(strategy1)
|
||||||
self.matmul2 = P.MatMul().shard(strategy2)
|
self.matmul2 = P.MatMul().shard(strategy2)
|
||||||
self.gelu = P.Gelu().shard(strategy3)
|
self.gelu = P.GeLU().shard(strategy3)
|
||||||
self.tanh = P.Tanh().shard(strategy3)
|
self.tanh = P.Tanh().shard(strategy3)
|
||||||
self.softmax = P.Softmax().shard(strategy3)
|
self.softmax = P.Softmax().shard(strategy3)
|
||||||
self.logsoftmax = P.LogSoftmax().shard(strategy3)
|
self.logsoftmax = P.LogSoftmax().shard(strategy3)
|
||||||
|
@ -192,7 +192,7 @@ def test_activations_repeated_calculation():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul1 = P.MatMul().shard(strategy1)
|
self.matmul1 = P.MatMul().shard(strategy1)
|
||||||
self.matmul2 = P.MatMul().shard(strategy2)
|
self.matmul2 = P.MatMul().shard(strategy2)
|
||||||
self.gelu = P.Gelu().shard(strategy3)
|
self.gelu = P.GeLU().shard(strategy3)
|
||||||
self.tanh = P.Tanh().shard(strategy4)
|
self.tanh = P.Tanh().shard(strategy4)
|
||||||
self.softmax = P.Softmax().shard(strategy5)
|
self.softmax = P.Softmax().shard(strategy5)
|
||||||
self.logsoftmax = P.LogSoftmax().shard(strategy6)
|
self.logsoftmax = P.LogSoftmax().shard(strategy6)
|
||||||
|
@ -224,7 +224,7 @@ def test_activations_axis_tuple():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul1 = P.MatMul().shard(strategy1)
|
self.matmul1 = P.MatMul().shard(strategy1)
|
||||||
self.matmul2 = P.MatMul().shard(strategy2)
|
self.matmul2 = P.MatMul().shard(strategy2)
|
||||||
self.gelu = P.Gelu().shard(strategy3)
|
self.gelu = P.GeLU().shard(strategy3)
|
||||||
self.tanh = P.Tanh().shard(strategy4)
|
self.tanh = P.Tanh().shard(strategy4)
|
||||||
self.softmax = P.Softmax(axis=(0, 1)).shard(strategy5)
|
self.softmax = P.Softmax(axis=(0, 1)).shard(strategy5)
|
||||||
self.logsoftmax = P.LogSoftmax().shard(strategy6)
|
self.logsoftmax = P.LogSoftmax().shard(strategy6)
|
||||||
|
|
|
@ -52,7 +52,7 @@ def test_linear():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc_nobias = P.MatMul(transpose_b=True).shard(strategy0)
|
self.fc_nobias = P.MatMul(transpose_b=True).shard(strategy0)
|
||||||
self.add = P.Add().shard(strategy1)
|
self.add = P.Add().shard(strategy1)
|
||||||
self.gelu = P.Gelu().shard(strategy2)
|
self.gelu = P.GeLU().shard(strategy2)
|
||||||
|
|
||||||
def construct(self, x, y, bias):
|
def construct(self, x, y, bias):
|
||||||
out = self.fc_nobias(x, y)
|
out = self.fc_nobias(x, y)
|
||||||
|
|
|
@ -57,7 +57,7 @@ class Net(nn.Cell):
|
||||||
def __init__(self, strategy1, strategy2):
|
def __init__(self, strategy1, strategy2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul = P.MatMul().shard(strategy1)
|
self.matmul = P.MatMul().shard(strategy1)
|
||||||
self.gelu = P.Gelu().shard(strategy2)
|
self.gelu = P.GeLU().shard(strategy2)
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.matmul(x, y)
|
out = self.matmul(x, y)
|
||||||
|
|
|
@ -57,7 +57,7 @@ def test_softmax_cross_entropy_loss():
|
||||||
def __init__(self, strategy1, strategy2):
|
def __init__(self, strategy1, strategy2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul = P.MatMul(transpose_b=True).shard(strategy1)
|
self.matmul = P.MatMul(transpose_b=True).shard(strategy1)
|
||||||
self.gelu = P.Gelu().shard(strategy2)
|
self.gelu = P.GeLU().shard(strategy2)
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.matmul(x, y)
|
out = self.matmul(x, y)
|
||||||
|
@ -82,7 +82,7 @@ def test_softmax_cross_entropy_loss_repeated_calculation():
|
||||||
def __init__(self, strategy1, strategy2):
|
def __init__(self, strategy1, strategy2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul = P.MatMul(transpose_b=True).shard(strategy1)
|
self.matmul = P.MatMul(transpose_b=True).shard(strategy1)
|
||||||
self.gelu = P.Gelu().shard(strategy2)
|
self.gelu = P.GeLU().shard(strategy2)
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.matmul(x, y)
|
out = self.matmul(x, y)
|
||||||
|
@ -107,7 +107,7 @@ def test_softmax_cross_entropy_loss_auto_batch_parallel():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul = P.MatMul(transpose_b=True)
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
self.gelu = P.Gelu()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
out = self.matmul(x, y)
|
out = self.matmul(x, y)
|
||||||
|
|
|
@ -57,7 +57,7 @@ def test_virtual_dataset_3_input():
|
||||||
self.virtual_dataset = _VirtualDataset().shard(strategy0)
|
self.virtual_dataset = _VirtualDataset().shard(strategy0)
|
||||||
self.matmul1 = P.MatMul().shard(strategy1)
|
self.matmul1 = P.MatMul().shard(strategy1)
|
||||||
self.matmul2 = P.MatMul().shard(strategy2)
|
self.matmul2 = P.MatMul().shard(strategy2)
|
||||||
self.gelu = P.Gelu().shard(strategy3)
|
self.gelu = P.GeLU().shard(strategy3)
|
||||||
|
|
||||||
def construct(self, x, y, b):
|
def construct(self, x, y, b):
|
||||||
x, y, b = self.virtual_dataset(x, y, b)
|
x, y, b = self.virtual_dataset(x, y, b)
|
||||||
|
@ -86,7 +86,7 @@ def test_virtualdataset_cell_3_inputs():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.matmul1 = P.MatMul().shard(strategy1)
|
self.matmul1 = P.MatMul().shard(strategy1)
|
||||||
self.matmul2 = P.MatMul().shard(strategy2)
|
self.matmul2 = P.MatMul().shard(strategy2)
|
||||||
self.gelu = P.Gelu().shard(strategy3)
|
self.gelu = P.GeLU().shard(strategy3)
|
||||||
|
|
||||||
def construct(self, x, y, b):
|
def construct(self, x, y, b):
|
||||||
out = self.gelu(self.matmul1(x, y))
|
out = self.gelu(self.matmul1(x, y))
|
||||||
|
|
Loading…
Reference in New Issue