diff --git a/akg b/akg
index b0f0f9e3ef5..5f5eeb31ffd 160000
--- a/akg
+++ b/akg
@@ -1 +1 @@
-Subproject commit b0f0f9e3ef5e6aaa788802f35f1daca657b5642a
+Subproject commit 5f5eeb31ffdf5a1dc973e7f904dc88ad7581bc5d
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index d1bead0b0af..66c36d15c9a 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -1018,5 +1018,29 @@ int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, cons
   auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
   return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
 }
+
+void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes) {
+  if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(indexes);
+  auto input_names = primitive->GetAttr(kAttrInputNames);
+  auto attr_names = primitive->GetAttr(kAttrAttrNames);
+  if (input_names == nullptr || attr_names == nullptr) {
+    return;
+  }
+  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
+  auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
+  if (input_names_vec.size() >= attr_names_vec.size()) {
+    size_t offset = input_names_vec.size() - attr_names_vec.size();
+    for (size_t i = offset; i < input_names_vec.size(); ++i) {
+      if (input_names_vec[i] != attr_names_vec[i - offset]) {
+        MS_LOG(EXCEPTION) << primitive->name() << " found mismatching attr name " << input_names_vec[i]
+                          << " in input_names and " << attr_names_vec[i - offset] << " in attr_names";
+      }
+      indexes->insert(i);
+    }
+  }
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h
index 75da9fe93d1..c405bd4800f 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.h
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.h
@@ -229,6 +229,9 @@ std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_gra
 
 // Get total used number of node's output
 int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
+
+// Get custom operator attr input indexes
+void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/pass/custom_op_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/custom_op_const_input_to_attr.cc
index c4ce1870fdd..c1eeb74c7d6 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/custom_op_const_input_to_attr.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/custom_op_const_input_to_attr.cc
@@ -15,13 +15,10 @@
  */
 #include "backend/optimizer/pass/custom_op_const_input_to_attr.h"
-#include <string>
 #include <unordered_set>
-#include <vector>
 #include "backend/optimizer/common/helper.h"
-#include "utils/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 
 namespace mindspore {
@@ -41,20 +38,8 @@ const AnfNodePtr CustomOpConstInputToAttr::Process(const FuncGraphPtr &, const A
   }
   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
-  auto input_names = primitive->GetAttr(kAttrInputNames);
-  auto attr_names = primitive->GetAttr("attr_names");
-  if (input_names == nullptr || attr_names == nullptr) {
-    return nullptr;
-  }
-  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
-  auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
   std::unordered_set<size_t> attr_indices;
-  if (input_names_vec.size() >= attr_names_vec.size()) {
-    size_t offset = input_names_vec.size() - attr_names_vec.size();
-    for (size_t i = offset; i < input_names_vec.size(); ++i) {
-      attr_indices.insert(i);
-    }
-  }
+  GetCustomOpAttrIndex(primitive, &attr_indices);
   if (attr_indices.empty()) {
     return nullptr;
   }
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index ab59bb15d23..310a9236f6e 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -470,9 +470,9 @@ std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
   auto primitive = AnfAlgo::GetCNodePrimitive(node);
   if (primitive != nullptr) {
     if (primitive->name() == "Custom") {
-      auto func_name = primitive->GetAttr("func_name");
-      if (func_name) {
-        return GetValue<std::string>(func_name);
+      auto uniq_name = primitive->GetAttr("uniq_name");
+      if (uniq_name) {
+        return GetValue<std::string>(uniq_name);
       }
     }
     return primitive->name();
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 8757073e93d..9aa7451d906 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -503,7 +503,18 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector
   MS_EXCEPTION_IF_NULL(op_prim);
   // Checking whether attr conversion is needed.
   opt::ConstInputToAttrInfoRegister reg;
-  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
+  bool reg_exist = false;
+  if (op_run_info->op_name == prim::kPrimCustom->name()) {
+    // Custom op needs to set reg dynamically
+    std::unordered_set<size_t> attr_indexes;
+    opt::GetCustomOpAttrIndex(op_prim, &attr_indexes);
+    if (!attr_indexes.empty()) {
+      reg_exist = true;
+      reg.SetConstInputToAttr(attr_indexes);
+    }
+  } else {
+    reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
+  }
   if (op_run_info->is_dynamic_shape &&
       dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
     MS_LOG(DEBUG) << "current node is dynamic shape: " << op_run_info->op_name;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 4ed2693ac9e..4f3fa252d34 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -336,6 +336,7 @@ constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
+constexpr auto kAttrAttrNames = "attr_names";
 constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
 constexpr auto kIsBackendCast = "is_backed_cast";
 constexpr auto kAttrOutputNames = "output_names";
diff --git a/mindspore/ops/operations/custom_ops.py b/mindspore/ops/operations/custom_ops.py
index b506e41d747..4d51898e6e9 100644
--- a/mindspore/ops/operations/custom_ops.py
+++ b/mindspore/ops/operations/custom_ops.py
@@ -91,7 +91,7 @@ class CustomRegOp(RegOp):
                 Please note that target and the `func_type` of `Custom` op have some constraints.
                 If func_type is "akg", target can be one of ["Ascend", "GPU"].
                 If func_type is "tbe", target can only be "Ascend".
-                If func_type is "aot", target can only be "GPU".
+                If func_type is "aot", target can only be ["GPU", "CPU"].
                 If func_type is "py_func", target can only be "CPU".
                 Default: None.
""" @@ -264,7 +264,7 @@ class Custom(ops.PrimitiveWithInfer): >>> >>> class Net(Cell): ... def __init__(self): - ... super(Net1, self).__init__() + ... super(Net, self).__init__() ... self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \ ... func_type="tbe") ... @@ -296,6 +296,7 @@ class Custom(ops.PrimitiveWithInfer): """ registered_func = {} + attr_dict = {} # Save input_names and attr_names for func. def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None): ops.PrimitiveWithInfer.__init__(self, "Custom") @@ -303,16 +304,23 @@ class Custom(ops.PrimitiveWithInfer): self.supported_targets = ["Ascend", "GPU", "CPU"] self.func = func self.func_name = "" + self.uniq_name = "" + self.imply_path = "" if callable(self.func): # Get the original function if func is decorated if "__wrapped__" in self.func.__dict__: self.func = self.func.__dict__["__wrapped__"] + self.imply_path = os.path.realpath(inspect.getfile(self.func)) self.func_name = self.func.__name__ + self.uniq_name = self.name + "_" + self.func_name + "_" + str(id(self.func)) elif isinstance(self.func, str): self.func_name = self.func + self.uniq_name = self.name + "_" + self.func_name else: raise TypeError("func should be of type function or str, but got {}".format(type(self.func))) self.add_prim_attr("func_name", self.func_name) + self.add_prim_attr("uniq_name", self.uniq_name) + self.add_prim_attr("imply_path", self.imply_path) self.out_shape = out_shape self.out_dtype = out_dtype self.bprop = bprop @@ -333,6 +341,7 @@ class Custom(ops.PrimitiveWithInfer): else: self.func_type = "hybrid" self.add_prim_attr("func_type", self.func_type) + self.update_attr() def infer_shape(self, *args): if callable(self.out_shape): @@ -353,7 +362,6 @@ class Custom(ops.PrimitiveWithInfer): if reg_info is None and hasattr(self.func, "reg_info"): reg_info = getattr(self.func, "reg_info") reg_info_list = self.get_expanded_list(reg_info) - already_add_attr = False for reg_info in reg_info_list: if not isinstance(reg_info, (str, dict)): continue @@ -366,17 +374,10 @@ class Custom(ops.PrimitiveWithInfer): # Register reg_info = self.reformat_reg_info(reg_info, target) reg_info_str = json.dumps(reg_info) - if isinstance(self.func, str): - imply_path = self.func - else: - imply_path = os.path.realpath(inspect.getfile(self.func)) op_lib = Oplib() - if not op_lib.reg_op(reg_info_str, imply_path): - raise ValueError('Invalid reg info {}: {}\n'.format(imply_path, reg_info_str)) - # Add inputs name to attr - if not already_add_attr: - self.add_inputs_name_to_attr(reg_info) - already_add_attr = True + if not op_lib.reg_op(reg_info_str, self.imply_path): + raise ValueError('Invalid reg info {}: {}\n'.format(self.imply_path, reg_info_str)) + self.save_attr(reg_info) self.save_register_status(target) def get_expanded_list(self, data): @@ -405,7 +406,7 @@ class Custom(ops.PrimitiveWithInfer): registered_targets = getattr(self.func, "registered_targets", []) registered_targets.append(target) setattr(self.func, "registered_targets", registered_targets) - elif isinstance(self, str): + elif isinstance(self.func, str): if isinstance(Custom.registered_func.get(self.func), list): Custom.registered_func[self.func].append(target) else: @@ -415,7 +416,7 @@ class Custom(ops.PrimitiveWithInfer): """Reformat registration information.""" if not isinstance(reg_info, dict): raise TypeError("reg_info should be of type dict, but got {}".format(type(reg_info))) - reg_info["op_name"] = self.func_name + reg_info["op_name"] = 
self.uniq_name reg_info["imply_type"] = self.get_imply_type(reg_info, target) # Supplement necessary info for TBE if these information is missing in reg_info if reg_info["imply_type"] == "TBE": @@ -468,13 +469,19 @@ class Custom(ops.PrimitiveWithInfer): func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "py_func": target} return func_type_to_imply_type.get(self.func_type, "AKG") - def add_inputs_name_to_attr(self, reg_info): - """Save inputs name to primitive's attr.""" + def save_attr(self, reg_info): + """Save input_names and attr_names of current func.""" if not isinstance(reg_info, dict): return tensor_inputs = reg_info.get("inputs", []) attr = reg_info.get("attr", []) + if not isinstance(tensor_inputs, (list, tuple)): + tensor_inputs = [tensor_inputs] + if not isinstance(attr, (list, tuple)): + attr = [attr] + # input_names include tensor input names and attr input names input_names = [] + # attr_names only includes attr input names attr_names = [] for item in tensor_inputs: if isinstance(item, dict) and item.get("name") is not None: @@ -483,6 +490,46 @@ class Custom(ops.PrimitiveWithInfer): if isinstance(item, dict) and item.get("name") is not None: input_names.append(item["name"]) attr_names.append(item["name"]) - # input_names include tensor input names and attr names - self.add_prim_attr("input_names", input_names) - self.add_prim_attr("attr_names", attr_names) + cur_attr = {"input_names": input_names, "attr_names": attr_names} + # If func does not have attr, save current attr. + # Else, check if current attr is same as previous saved one. + prev_input_names = input_names + prev_attr_names = attr_names + if callable(self.func): + func_attr = getattr(self.func, "func_attr", None) + if not isinstance(func_attr, dict): + setattr(self.func, "func_attr", cur_attr) + else: + prev_input_names = func_attr.get("input_names") + prev_attr_names = func_attr.get("attr_names") + elif isinstance(self.func, str): + func_attr = Custom.attr_dict.get(self.func) + if not isinstance(func_attr, dict): + Custom.attr_dict[self.func] = cur_attr + else: + prev_input_names = func_attr.get("input_names") + prev_attr_names = func_attr.get("attr_names") + if not isinstance(prev_input_names, list): + raise TypeError("func {}: previous saved input_names should be a list, but got {}" + .format(self.func, type(prev_input_names))) + if len(input_names) != len(prev_input_names): + raise ValueError("func {}: input_names's length {} is different from previous saved one {}" + .format(self.func, len(input_names), len(prev_input_names))) + if attr_names != prev_attr_names: + raise ValueError("func {}: attr_names {} is different from previous saved one {}" + .format(self.func, attr_names, prev_attr_names)) + + def update_attr(self): + """Add input_names and attr_names to primitive's attr.""" + func_attr = {} + if callable(self.func): + func_attr = getattr(self.func, "func_attr", None) + elif isinstance(self.func, str): + func_attr = Custom.attr_dict.get(self.func) + if isinstance(func_attr, dict): + input_names = func_attr.get("input_names") + attr_names = func_attr.get("attr_names") + if input_names: + self.add_prim_attr("input_names", input_names) + if attr_names: + self.add_prim_attr("attr_names", attr_names) diff --git a/tests/st/ops/graph_kernel/custom/test_custom_akg.py b/tests/st/ops/graph_kernel/custom/test_custom_akg.py index ffb80271cfa..155d36bf92a 100644 --- a/tests/st/ops/graph_kernel/custom/test_custom_akg.py +++ b/tests/st/ops/graph_kernel/custom/test_custom_akg.py @@ -13,6 +13,7 @@ # 
limitations under the License. # ============================================================================ +import pytest import numpy as np from mindspore import context, Tensor from mindspore.common import dtype as mstype @@ -53,6 +54,7 @@ def outer_product(a, b): class TestHybrid(Cell): """Net definition""" + def __init__(self): super(TestHybrid, self).__init__() @@ -65,12 +67,7 @@ class TestHybrid(Cell): return self.program(x, y) -def test_hybrid(): - """ - Feature: ALL To ALL - Description: hybrid test cases. - Expectation: the result match with numpy result - """ +def hybrid_case(): input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32) input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32) @@ -82,24 +79,58 @@ def test_hybrid(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def test_hybrid_ascend(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_hybrid_ascend_graph_mode(): """ - Feature: ALL To ALL - Description: hybrid ascend test cases. + Feature: test case for Custom op with func_type="akg" + Description: ascend test case, akg dsl using hybrid grammar in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - test_hybrid() + hybrid_case() -def test_hybrid_gpu(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_hybrid_ascend_pynative_mode(): """ - Feature: ALL To ALL - Description: hybrid gpu test cases. + Feature: test case for Custom op with func_type="akg" + Description: ascend test case, akg dsl using hybrid grammar in PYNATIVE_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + hybrid_case() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_hybrid_gpu_graph_mode(): + """ + Feature: test case for Custom op with func_type="akg" + Description: gpu test case, akg dsl using hybrid grammar in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - test_hybrid() + hybrid_case() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_hybrid_gpu_pynative_mode(): + """ + Feature: test case for Custom op with func_type="akg" + Description: gpu test case, akg dsl using hybrid grammar in PYNATIVE_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + hybrid_case() v_add_ascend_info = CustomRegOp() \ @@ -135,6 +166,7 @@ def v_add(inputs, attrs): class TestIRbuilder(Cell): """Net definition""" + def __init__(self, shape): super(TestIRbuilder, self).__init__() self.program = Custom(v_add, out_shape=shape, out_dtype=mstype.float16, func_type="akg") @@ -143,12 +175,7 @@ class TestIRbuilder(Cell): return self.program([x, y]) -def test_irbuider(): - """ - Feature: ALL To ALL - Description: irbuider test cases. 
- Expectation: the result match with numpy result - """ +def irbuider_case(): shape = (4, 5) input_x = np.random.normal(0, 1, shape).astype(np.float16) input_y = np.random.normal(0, 1, shape).astype(np.float16) @@ -160,21 +187,55 @@ def test_irbuider(): raise ValueError("Precision error, compare result: {}".format(compare_res)) -def test_irbuider_ascend(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_irbuider_ascend_graph_mode(): """ - Feature: ALL To ALL - Description: irbuider ascend test cases. + Feature: test case for Custom op with func_type="akg" + Description: ascend test case, akg dsl using irbuider grammar in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - test_irbuider() + irbuider_case() -def test_irbuider_gpu(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_irbuider_ascend_pynative_mode(): """ - Feature: ALL To ALL - Description: irbuider gpu test cases. + Feature: test case for Custom op with func_type="akg" + Description: ascend test case, akg dsl using irbuider grammar in PYNATIVE_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + irbuider_case() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_irbuider_gpu_graph_mode(): + """ + Feature: test case for Custom op with func_type="akg" + Description: gpu test case, akg dsl using irbuider grammar in GRAPH_MODE. Expectation: the result match with numpy result """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - test_irbuider() + irbuider_case() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_irbuider_gpu_pynative_mode(): + """ + Feature: test case for Custom op with func_type="akg" + Description: gpu test case, akg dsl using irbuider grammar in PYNATIVE_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + irbuider_case() diff --git a/tests/st/ops/graph_kernel/custom/test_custom_tbe.py b/tests/st/ops/graph_kernel/custom/test_custom_tbe.py index d582f93956e..cd7c1160ca9 100644 --- a/tests/st/ops/graph_kernel/custom/test_custom_tbe.py +++ b/tests/st/ops/graph_kernel/custom/test_custom_tbe.py @@ -21,6 +21,7 @@ from mindspore.nn import Cell from mindspore.ops import operations as P from mindspore.ops.op_info_register import TBERegOp, DataType from mindspore.ops.operations.custom_ops import Custom, CustomRegOp, custom_op_info_register +from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like square_with_bias_op_info = CustomRegOp() \ .fusion_type("OPAQUE") \ @@ -158,17 +159,7 @@ class Net1(Cell): return res -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -def test_multi_input_multi_output_with_attr(): - """ - Feature: ALL To ALL - Description: test cases for Custom op with multiple inputs, outputs and attr. 
- Expectation: the result match with numpy result - """ - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +def multi_input_multi_output_with_attr(): dtype = np.float32 x = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]).astype(dtype) expect0 = np.array([[9.0, 9.0, 9.0], [441.0, 441.0, 441.0]]).astype(dtype) @@ -196,10 +187,38 @@ def test_multi_input_multi_output_with_attr(): raise ValueError("Precision error, compare result: {}".format(compare_res)) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net1_graph_mode(): + """ + Feature: test case for Custom op with func_type="tbe" + Description: test cases with multiple inputs, outputs and attr in GRAPH_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + multi_input_multi_output_with_attr() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net1_pynative_mode(): + """ + Feature: test case for Custom op with func_type="tbe" + Description: test cases with multiple inputs, outputs and attr in PYNATIVE_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + multi_input_multi_output_with_attr() + + def bprop(data, axis, out, dout): gradient = data * 2 dx = gradient * dout - return (dx,) + return dx, zeros_like(axis) class Net2(Cell): @@ -215,17 +234,7 @@ class Net2(Cell): return res -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -def test_bprop(): - """ - Feature: ALL To ALL - Description: test cases for bprop function of Custom op. - Expectation: the result match with numpy result - """ - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +def grad_case(): x = np.array([1.0, 4.0, 9.0]).astype(np.float32) sens = np.array([1.0, 1.0, 1.0]).astype(np.float32) expect = np.array([2.0, 8.0, 18.0]).astype(np.float32) @@ -237,3 +246,31 @@ def test_bprop(): compare_res = np.allclose(expect, dx_np, 0.0001, 0.0001) if not compare_res: raise ValueError("Precision error, compare result: {}".format(compare_res)) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net2_graph_mode(): + """ + Feature: test case for Custom op with func_type="tbe" + Description: grad test case in GRAPH_MODE. + Expectation: the result match with numpy result + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + grad_case() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_net2_pynative_mode(): + """ + Feature: test case for Custom op with func_type="tbe" + Description: grad test case in PYNATIVE_MODE. 
+ Expectation: the result match with numpy result + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + grad_case() diff --git a/tests/st/ops/graph_kernel/test_user_define.py b/tests/st/ops/graph_kernel/test_user_define.py deleted file mode 100644 index b02c2349462..00000000000 --- a/tests/st/ops/graph_kernel/test_user_define.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import inspect -import numpy as np -import pytest -from mindspore import context, ops, Tensor -from mindspore.common import dtype as mstype -from mindspore.nn import Cell - - -class UserDefined(ops.PrimitiveWithInfer): - def __init__(self, func, shape, dtype, func_type=None): - ops.PrimitiveWithInfer.__init__(self, "UserDefined") - self.add_prim_attr('akg', True) - - if "__wrapped__" in func.__dict__: - func = func.__dict__["__wrapped__"] - func_name = func.__name__ - self.add_prim_attr('func_name', func_name) - func_source_str = inspect.getsource(func) - - if func_type is None: - if "ir_builder" in func_source_str: - func_type = "ir_builder" - elif "compute" in func_source_str: - func_type = "tvm_compute" - else: - func_type = "hybrid" - - self.add_prim_attr('func_source_str', func_source_str) - self.add_prim_attr('func_type', func_type) - - self._shape = shape - self._dtype = dtype - - def infer_shape(self, *args): - if callable(self._shape): - return self._shape(*args) - return self._shape - - def infer_dtype(self, *args): - if callable(self._dtype): - return self._dtype(*args) - return self._dtype - - -def outer_product(a, b): - c = output_tensor((a.shape[0], b.shape[1]), 'float32') - - for i0 in range(a.shape[0]): - for i1 in range(b.shape[1]): - c[i0, i1] = 0.0 - for i2 in range(a.shape[1]): - c[i0, i1] = c[i0, i1] + (a[i0, i2] * b[i2, i1]) - return c - - -class TestHybrid(Cell): - def __init__(self): - super(TestHybrid, self).__init__() - - def infer_func(x, y): - return x - - self.program = UserDefined( - outer_product, shape=infer_func, dtype=infer_func) - - def construct(self, x, y): - return self.program(x, y) - - -def v_add(inputs, attrs): - def vadd_func(dst, data_1, data_2): - ib = tvm.ir_builder.create() - with ib.for_range_n(data_1.shape, "i") as i: - ib.store(dst, i, ib.load(data_1, i) + ib.load(data_2, i)) - return ib.get() - data_1, data_2 = inputs[0], inputs[1] - return tvm.extern(data_1.shape, [data_1, data_2], - lambda ins, outs: vadd_func(outs[0], ins[0], ins[1]), - name="v_add", dtype=data_1.dtype) - - -class TestIRbuilder(Cell): - def __init__(self, shape): - super(TestIRbuilder, self).__init__() - self.program = UserDefined( - v_add, shape=shape, dtype=mstype.float16) - - def construct(self, x, y): - return self.program(x, y) - - -def test_user_defined_hybrid(): - - input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32) - input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32) - - test = TestHybrid() 
- output = test(Tensor(input_x), Tensor(input_y)) - expect = np.matmul(input_x, input_y) - assert np.allclose(expect, output.asnumpy(), 0.001, 0.001) - - -def test_user_defined_irbuider(): - - shape = (4, 5) - input_x = np.random.normal(0, 1, shape).astype(np.float16) - input_y = np.random.normal(0, 1, shape).astype(np.float16) - - test = TestIRbuilder(shape) - output = test(Tensor(input_x), Tensor(input_y)) - assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_user_defined_gpu(): - context.set_context(mode=0, enable_graph_kernel=True) - test_user_defined_hybrid() - test_user_defined_irbuider()
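
For reviewers, a minimal usage sketch of the behavior this patch enables; it is not part of the patch. It assumes the square_with_bias TBE kernel and its reg info are defined as in tests/st/ops/graph_kernel/custom/test_custom_tbe.py. Because the attr input indexes are now derived at runtime from the input_names/attr_names attrs (instead of a static ConstInputToAttr registration), a Custom op whose reg info declares a trailing "bias" attr also runs in PYNATIVE_MODE:

import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops.operations.custom_ops import Custom


class SquareWithBiasNet(Cell):
    """Illustrative net; square_with_bias is assumed to be the TBE func from test_custom_tbe.py."""

    def __init__(self):
        super(SquareWithBiasNet, self).__init__()
        # "bias" is listed in attr_names of the reg info, so the backend converts the
        # trailing constant input into an attr before kernel selection.
        self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32,
                                       func_type="tbe")

    def construct(self, x):
        return self.square_with_bias(x, 1.0)


context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
net = SquareWithBiasNet()
output = net(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], np.float32)))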