diff --git a/graphengine b/graphengine index 45ca7863ac6..1350673d51b 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 45ca7863ac6410c8e2f83168481ddc6b43bcea33 +Subproject commit 1350673d51b3f8535bc217a7780e6a0b52ff9a41 diff --git a/mindspore/ops/_op_impl/tbe/matmul.py b/mindspore/ops/_op_impl/tbe/matmul.py index c29378f7217..e773191ae88 100644 --- a/mindspore/ops/_op_impl/tbe/matmul.py +++ b/mindspore/ops/_op_impl/tbe/matmul.py @@ -23,16 +23,26 @@ matmul_op_info = TBERegOp("MatMul") \ .compute_cost(10) \ .kernel_name("matmul") \ .partial_flag(True) \ - .attr("transpose_a", "required", "bool", "all") \ - .attr("transpose_b", "required", "bool", "all") \ + .attr("transpose_x1", "required", "bool", "all") \ + .attr("transpose_x2", "required", "bool", "all") \ + .attr("offset_x", "optional", "int", "all") \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ - .input(2, "x3", False, "optional", "all") \ + .input(2, "bias", False, "optional", "all") \ + .input(3, "offset_w", False, "optional", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ - .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default, + DataType.I32_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default, + DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default, + DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default, + DataType.F32_Default) \ + .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default, + DataType.I32_NHWC) \ .get_op_info()