diff --git a/build.sh b/build.sh
index 5fa6113cc75..efabee441f5 100755
--- a/build.sh
+++ b/build.sh
@@ -46,7 +46,11 @@ update_submodule()
   git submodule update --init metadef
   cd "${BASEPATH}"
   if [[ "X$ENABLE_AKG" = "Xon" ]]; then
-    git submodule update --init --recursive akg
+    if [[ "X$ENABLE_D" == "Xon" ]]; then
+      git submodule update --init akg
+    else
+      GIT_LFS_SKIP_SMUDGE=1 git submodule update --init akg
+    fi
   fi
 }

diff --git a/cmake/external_libs/tvm_predict.cmake b/cmake/external_libs/tvm_predict.cmake
deleted file mode 100644
index 299f1fb09f1..00000000000
--- a/cmake/external_libs/tvm_predict.cmake
+++ /dev/null
@@ -1,18 +0,0 @@
-if(ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/incubator-tvm/repository/archive/v0.6.0.tar.gz")
-    set(MD5 "7b22965745cf1c6208a4e367fb86a585")
-else()
-    set(REQ_URL
-        "https://github.com/apache/incubator-tvm/release/download/v0.6.0/apache-tvm-src-v0.6.0-incubating.tar.gz")
-    set(MD5 "2d77a005f0046d937b99c67de82f6438")
-endif()
-set(incubator_tvm_predict_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
-set(incubator_tvm_predict_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
-mindspore_add_pkg(incubator_tvm_predict
-        VER 0.6.0
-        HEAD_ONLY ./
-        URL ${REQ_URL}
-        MD5 ${MD5}
-        PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch)
-include_directories(${incubator_tvm_predict_INC})
-add_library(mindspore::incubator_tvm_predict ALIAS incubator_tvm_predict)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
index 7d0e12c5f24..4be12d6391b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
@@ -215,7 +215,7 @@ class OpInfoExtractor {
       }
     }
     if (op_attr->type().empty()) {
-      MS_LOG(WARNING) << "Unknown type, ignore attr: " << name;
+      MS_LOG(DEBUG) << "Unknown type, ignore attr: " << name;
       continue;
     }
     op_info->add_attrs_ptr(op_attr);
diff --git a/mindspore/ccsrc/backend/optimizer/pass/custom_op_reg_info_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/custom_op_reg_info_to_attr.cc
index 2aa6d3399f2..65eaa3ff701 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/custom_op_reg_info_to_attr.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/custom_op_reg_info_to_attr.cc
@@ -97,10 +97,15 @@ void AddMissingAttrs(const CNodePtr &cnode, kernel::OpImplyType imply_type,
   bool need_update = false;
   for (const auto &attr : all_attrs) {
     auto attr_name = attr->name();
-    if (missing_attrs.find(attr_name) == missing_attrs.end() || attr->param_type() != "required") {
+    if (missing_attrs.find(attr_name) == missing_attrs.end()) {
       continue;
     }
+    // If attr's param_type is required, it should have a default value.
+    // If attr has a default value, we should parse it no matter whether its param_type is required or not.
     auto default_value = attr->default_value();
+    if (default_value.empty() && attr->param_type() != "required") {
+      continue;
+    }
     if (default_value.empty()) {
       MS_LOG(EXCEPTION) << "attr [" << attr_name << "] in the registration information of op [" << op_name
                        << "] does not have a value." << trace::DumpSourceLines(cnode);
@@ -129,8 +134,8 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
   auto func_type = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
-  // Only AKG needs to process attr, TBE will process later in the json creating phase.
-  if (kCustomTypeAkg.find(func_type) == kCustomTypeAkg.end()) {
+  // AKG/AICPU need to process attr, TBE will process later in the json creating phase.
+  if (kCustomTypeAkg.find(func_type) == kCustomTypeAkg.end() && func_type != kCustomTypeAICPU) {
     return nullptr;
   }
   // Early return if current node does not have attr
@@ -149,7 +154,9 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
   if (missing_attrs.empty()) {
     return nullptr;
   }
-  AddMissingAttrs(cnode, kernel::OpImplyType::kAKG, missing_attrs);
+  kernel::OpImplyType imply_type =
+    func_type == kCustomTypeAICPU ? kernel::OpImplyType::kAICPU : kernel::OpImplyType::kAKG;
+  AddMissingAttrs(cnode, imply_type, missing_attrs);
   return node;
 }

diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index b1a4229cc52..95de4340581 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -92,7 +92,8 @@ def op_info_register(op_info):

 def custom_info_register(*reg_info):
     r"""
-    A decorator which is used to bind the registration information to the `func` parameter of `Custom` op.
+    A decorator which is used to bind the registration information to the `func` parameter of
+    :class:`mindspore.ops.Custom`.

     Note:
         The 'reg_info' will be added into oplib.
@@ -645,7 +646,7 @@ class TBERegOp(RegOp):

 class CustomRegOp(RegOp):
     r"""
-    Class for `Custom` operator info register.
+    Class used for generating the registration information for the `func` parameter of :class:`mindspore.ops.Custom`.

     Args:
         op_name (str): kernel name. No need to set this value as `Custom` operator will generate a unique name
diff --git a/mindspore/ops/operations/custom_ops.py b/mindspore/ops/operations/custom_ops.py
index a7f523b63dd..85139342ecb 100644
--- a/mindspore/ops/operations/custom_ops.py
+++ b/mindspore/ops/operations/custom_ops.py
@@ -136,7 +136,7 @@ class Custom(ops.PrimitiveWithInfer):
             represents the registration information of `func` in a specific target. You need to invoke `CustomRegOp`
             or the subclass of `RegOp` to generate the reg info for `func`. Then you can invoke
             `custom_info_register` to bind the reg info to `func` or just pass the reg info to `reg_info` parameter.
-            The `reg_info` parameter takes higher priority then `custom_info_register` and the reg info in a
+            The `reg_info` parameter takes higher priority than `custom_info_register` and the reg info in a
             specific target will be registered only once.
             If reg info is not set, then we will infer the data types and formats from the inputs of `Custom`
             operator.
@@ -161,7 +160,6 @@ class Custom(ops.PrimitiveWithInfer):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
-        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> from mindspore.ops import CustomRegOp, custom_info_register, DataType
        >>> from mindspore.common import dtype as mstype
@@ -169,7 +168,7 @@ class Custom(ops.PrimitiveWithInfer):
        >>>
        >>> # Example, func_type = "akg"
        >>> def outer_product(a, b):
-       ...     c = output_tensor((a.shape[0], b.shape[1]), 'float32')
+       ...     c = output_tensor(a.shape, a.dtype)
        ...     for i0 in range(a.shape[0]):
        ...         for i1 in range(b.shape[1]):
       ...             c[i0, i1] = 0.0
@@ -202,19 +201,18 @@ class Custom(ops.PrimitiveWithInfer):
        ... def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
        ...     import te.lang.cce
        ...     from te import tvm
-       ...     from topi import generic
        ...     from topi.cce import util
        ...
        ...     shape = input_x.get("shape")
        ...     dtype = input_x.get("dtype").lower()
        ...
        ...     shape = util.shape_refine(shape)
-       ...     data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+       ...     data = tvm.placeholder(shape, name="data", dtype=dtype)
        ...
        ...     with tvm.target.cce():
        ...         res0 = te.lang.cce.vmul(data, data)
        ...         res = te.lang.cce.vadds(res0, bias)
-       ...         sch = generic.auto_schedule(res)
+       ...         sch = te.lang.cce.auto_schedule(res)
        ...
        ...     config = {"print_ir": False,
        ...               "name": kernel_name,
@@ -225,8 +223,8 @@ class Custom(ops.PrimitiveWithInfer):
        >>> class TbeNet(Cell):
        ...     def __init__(self):
        ...         super(TbeNet, self).__init__()
-       ...         self.square_with_bias = ops.Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
-       ...                                            func_type="tbe")
+       ...         self.square_with_bias = ops.Custom(square_with_bias, out_shape=lambda x, _: x, \
+       ...                                            out_dtype=lambda x, _: x, func_type="tbe")
        ...     def construct(self, x):
        ...         res = self.square_with_bias(x, 1.0)
        ...         return res
@@ -258,9 +256,9 @@ class Custom(ops.PrimitiveWithInfer):
        >>>
        >>> # Example, func_type = "aot"
        >>> class AOTSingleOutputNet(Cell):
-       ...     def __init__(self, func, out_shapes, out_types, reg=None):
+       ...     def __init__(self, out_shapes, out_types):
        ...         super(AOTSingleOutputNet, self).__init__()
-       ...         self.program = ops.Custom("./reorganize.so:CustomReorganize", (2, 3), mstype.float32, "aot")
+       ...         self.program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot")
        ...     def construct(self, x, y):
        ...         return self.program(x, y)
        >>>
@@ -269,9 +267,9 @@ class Custom(ops.PrimitiveWithInfer):
        ...     return (x1 + x2), (x1 - x2)
        >>>
        >>> class PyFuncNet(Cell):
-       ...     def __init__(self, fn, out_shapes, out_types):
-       ...         super().__init__()
-       ...         self.func = ops.Custom(func_multi_output, ((2, 3), (2, 3)), (ms.float32, ms.float32), "pyfunc")
+       ...     def __init__(self):
+       ...         super(PyFuncNet, self).__init__()
+       ...         self.func = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
        ...     def construct(self, x1, x2):
       ...         return self.func(x1, x2)
    """
diff --git a/tests/st/ops/graph_kernel/custom/test_custom_akg.py b/tests/st/ops/graph_kernel/custom/test_custom_akg.py
index 609523800c4..b861e1950f8 100644
--- a/tests/st/ops/graph_kernel/custom/test_custom_akg.py
+++ b/tests/st/ops/graph_kernel/custom/test_custom_akg.py
@@ -15,16 +15,14 @@

 import pytest
 import numpy as np
-import mindspore as ms
 from mindspore import context, Tensor
-from mindspore.common import dtype as mstype
 from mindspore.nn import Cell
 import mindspore.ops as ops
 from mindspore.ops import DataType, CustomRegOp, custom_info_register


 def outer_product(a, b):
-    c = output_tensor((a.shape[0], b.shape[1]), 'float32')
+    c = output_tensor(a.shape, a.dtype)

     for i0 in range(a.shape[0]):
         for i1 in range(b.shape[1]):
@@ -35,8 +33,8 @@ def outer_product(a, b):


 def cube(a):
-    c = output_tensor((a.shape[0], a.shape[1]), 'float32')
-    b = allocate((a.shape[0], a.shape[1]), 'float32', 'local')
+    c = output_tensor(a.shape, a.dtype)
+    b = allocate(a.shape, a.dtype, 'local')

     for i0 in range(a.shape[0]):
         for i1 in range(a.shape[1]):
@@ -49,10 +47,10 @@ def cube(a):
 class TestHybridTwoInputs(Cell):
     """Net definition"""

-    def __init__(self, func, shapes, types):
+    def __init__(self, func, out_shape, out_dtype):
         super(TestHybridTwoInputs, self).__init__()
-        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
+        self.program = ops.Custom(func, out_shape=out_shape, out_dtype=out_dtype, func_type="akg")

     def construct(self, x, y):
         return self.program(x, y)
@@ -61,10 +59,10 @@ class TestHybridTwoInputs(Cell):
 class TestHybridOneInput(Cell):
     """Net definition"""

-    def __init__(self, func, shapes, types):
+    def __init__(self, func, out_shape, out_dtype):
         super(TestHybridOneInput, self).__init__()
-        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
+        self.program = ops.Custom(func, out_shape=out_shape, out_dtype=out_dtype, func_type="akg")

     def construct(self, x):
         return self.program(x)
@@ -96,7 +94,7 @@ def hybrid_outer_product():
     input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
+    test = TestHybridTwoInputs(outer_product, lambda x, _: x, lambda x, _: x)
     output = test(Tensor(input_x), Tensor(input_y))
     expect = np.matmul(input_x, input_y)
     compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
@@ -109,7 +107,7 @@ def hybrid_outer_product_autodiff():
     input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
+    test = TestHybridTwoInputs(outer_product, lambda x, _: x, lambda x, _: x)
     net = MatMulNN()
     dx, dy = ops.GradOperation(sens_param=True, get_all=True)(test)(Tensor(input_x), Tensor(input_y), Tensor(sens))
     edx, edy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(input_x), Tensor(input_y), Tensor(sens))
@@ -123,7 +121,7 @@ def hybrid_pow_autodiff():
     input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridOneInput(cube, (4, 4), (ms.float32))
+    test = TestHybridOneInput(cube, lambda x: x, lambda x: x)
     net = PowNN()
     dx = ops.GradOperation(sens_param=True)(test)(Tensor(input_x), Tensor(sens))
     edx = ops.GradOperation(sens_param=True)(net)(Tensor(input_x), Tensor(sens))
@@ -222,9 +220,9 @@ def v_add(inputs, attrs):
 class TestIRbuilder(Cell):
     """Net definition"""
definition""" - def __init__(self, shape): + def __init__(self): super(TestIRbuilder, self).__init__() - self.program = ops.Custom(v_add, out_shape=shape, out_dtype=mstype.float16, func_type="akg") + self.program = ops.Custom(v_add, out_shape=lambda x: x[0], out_dtype=lambda x: x[0], func_type="akg") def construct(self, x, y): return self.program([x, y]) @@ -235,7 +233,7 @@ def irbuilder_case(): input_x = np.random.normal(0, 1, shape).astype(np.float16) input_y = np.random.normal(0, 1, shape).astype(np.float16) - test = TestIRbuilder(shape) + test = TestIRbuilder() output = test(Tensor(input_x), Tensor(input_y)) compare_res = np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001) if not compare_res: diff --git a/tests/st/ops/graph_kernel/custom/test_custom_tbe.py b/tests/st/ops/graph_kernel/custom/test_custom_tbe.py index ad86efbc207..9c834c14813 100644 --- a/tests/st/ops/graph_kernel/custom/test_custom_tbe.py +++ b/tests/st/ops/graph_kernel/custom/test_custom_tbe.py @@ -16,39 +16,35 @@ import pytest import numpy as np from mindspore import context, Tensor -from mindspore.common import dtype as mstype from mindspore.nn import Cell import mindspore.ops as ops from mindspore.ops import TBERegOp, DataType, CustomRegOp, custom_info_register from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like -square_with_bias_op_info = CustomRegOp() \ - .attr("bias", "required", "float") \ - .input(0, "x") \ - .output(0, "y") \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .target("Ascend") \ - .get_op_info() - -@custom_info_register(square_with_bias_op_info) +@custom_info_register(CustomRegOp() \ + .attr("bias", "required", "float") \ + .input(0, "x") \ + .output(0, "y") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .target("Ascend") \ + .get_op_info()) def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"): import te.lang.cce from te import tvm - from topi import generic from topi.cce import util shape = input_x.get("shape") dtype = input_x.get("dtype").lower() shape = util.shape_refine(shape) - data = tvm.placeholder(shape, name="data", dtype=dtype.lower()) + data = tvm.placeholder(shape, name="data", dtype=dtype) with tvm.target.cce(): res0 = te.lang.cce.vmul(data, data) res = te.lang.cce.vadds(res0, bias) - sch = generic.auto_schedule(res) + sch = te.lang.cce.auto_schedule(res) config = {"print_ir": False, "name": kernel_name, @@ -57,18 +53,15 @@ def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias" te.lang.cce.cce_build_code(sch, config) -square_with_bias_v2_op_info = CustomRegOp() \ - .attr("bias", "required", "float") \ - .input(0, "input_x") \ - .output(0, "output1") \ - .output(1, "output2") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .target("Ascend") \ - .get_op_info() - - -@custom_info_register(square_with_bias_v2_op_info) +@custom_info_register(CustomRegOp() \ + .attr("bias", "required", "float") \ + .input(0, "input_x") \ + .output(0, "output1") \ + .output(1, "output2") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("Ascend") \ + .get_op_info()) def square_with_bias_v2(input_x, output1, 
 def square_with_bias_v2(input_x, output1, output2, bias=0.0, kernel_name="square_with_bias_v2"):
     import te.lang.cce
     from te import tvm
@@ -76,7 +69,7 @@ def square_with_bias_v2(input_x, output1, output2, bias=0.0, kernel_name="square

     shape = input_x.get("shape")
     dtype = input_x.get("dtype").lower()
-    data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+    data = tvm.placeholder(shape, name="data", dtype=dtype)

     res0 = te.lang.cce.vmul(data, data)
     res1 = te.lang.cce.vadds(res0, bias)
@@ -113,7 +106,7 @@ def add_n_with_bias(inputs, output, bias, kernel_name="add_n_with_bias"):
     for i, d in enumerate(inputs):
         shape = d.get("shape")
         dtype = d.get("dtype").lower()
-        data.append(tvm.placeholder(shape, name="input_" + str(i), dtype=dtype.lower()))
+        data.append(tvm.placeholder(shape, name="input_" + str(i), dtype=dtype))

     res = data[0]
     for i in range(1, len(data)):
@@ -138,14 +131,13 @@ class Net1(Cell):
     def __init__(self):
         super(Net1, self).__init__()
         # TBE dsl with attr
-        self.square_with_bias = ops.Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32,
-                                           func_type="tbe")
+        self.square_with_bias = ops.Custom(square_with_bias, lambda x, _: x, lambda x, _: x, func_type="tbe")
         # TBE dsl with multiple inputs and attr
-        self.add_n_with_bias = ops.Custom(add_n_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, func_type="tbe",
+        self.add_n_with_bias = ops.Custom(add_n_with_bias, lambda x, _: x[0], lambda x, _: x[0], func_type="tbe",
                                           reg_info=add_n_with_bias_op_info)
         # TBE dsl with multiple outputs and attr
-        self.square_with_bias_v2 = ops.Custom(square_with_bias_v2, out_shape=([2, 3], [2, 3]),
-                                              out_dtype=(mstype.float32, mstype.float32), func_type="tbe")
+        self.square_with_bias_v2 = ops.Custom(square_with_bias_v2, lambda x, _: (x, x), lambda x, _: (x, x),
+                                              func_type="tbe")
         self.neg = ops.Neg()

     def construct(self, x):
@@ -224,7 +216,7 @@ class Net2(Cell):

     def __init__(self, bprop_func):
         super(Net2, self).__init__()
-        self.square_with_bias = ops.Custom(square_with_bias, out_shape=[3], out_dtype=mstype.float32, bprop=bprop_func,
+        self.square_with_bias = ops.Custom(square_with_bias, lambda x, _: x, lambda x, _: x, bprop=bprop_func,
                                            func_type="tbe")

     def construct(self, x):
@@ -276,28 +268,25 @@ def test_net2_pynative_mode():
     grad_case(bprop)


-square_with_bias_grad_info = CustomRegOp() \
-    .input(0, "x") \
-    .input(1, "dout") \
-    .output(0, "y") \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .target("Ascend") \
-    .get_op_info()
-
-
-@custom_info_register(square_with_bias_grad_info)
+@custom_info_register(CustomRegOp() \
+                      .input(0, "x") \
+                      .input(1, "dout") \
+                      .output(0, "y") \
+                      .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                      .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                      .target("Ascend") \
+                      .get_op_info())
 def square_with_bias_grad(input_x, dout, output_y, kernel_name="square_with_bias_grad"):
     import te.lang.cce
     from te import tvm

     shape1 = input_x.get("shape")
     dtype1 = input_x.get("dtype").lower()
-    data1 = tvm.placeholder(shape1, name="data1", dtype=dtype1.lower())
+    data1 = tvm.placeholder(shape1, name="data1", dtype=dtype1)

     shape2 = dout.get("shape")
     dtype2 = dout.get("dtype").lower()
-    data2 = tvm.placeholder(shape2, name="data2", dtype=dtype2.lower())
+    data2 = tvm.placeholder(shape2, name="data2", dtype=dtype2)

     res0 = te.lang.cce.vmuls(data1, 2.0)
     res = te.lang.cce.vmul(res0, data2)
@@ -312,7 +301,7 @@ def square_with_bias_grad(input_x, dout, output_y, kernel_name="square_with_bias


 def bprop1():
-    op = ops.Custom(square_with_bias_grad, [3], mstype.float32, func_type="tbe")
+    op = ops.Custom(square_with_bias_grad, lambda x, _: x, lambda x, _: x, func_type="tbe")

     def custom_bprop(data, axis, out, dout):
         dx = op(data, dout)
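
Note on the recurring change above: every `ops.Custom` call site moves from hard-coded `out_shape`/`out_dtype` values to callables that derive the output shape and dtype from the inputs. A minimal sketch of that pattern follows (illustration only, not part of the patch; the net name and the `my_kernel` parameter are placeholders standing in for a two-input AKG kernel such as outer_product in test_custom_akg.py):

import mindspore.ops as ops
from mindspore.nn import Cell


class CallableShapeNet(Cell):
    def __init__(self, my_kernel):
        super(CallableShapeNet, self).__init__()
        # out_shape/out_dtype receive the shapes/dtypes of the two inputs and return
        # the output shape/dtype, so nothing is hard-coded at the call site.
        self.program = ops.Custom(my_kernel, out_shape=lambda x, _: x, out_dtype=lambda x, _: x, func_type="akg")

    def construct(self, x, y):
        return self.program(x, y)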