forked from mindspore-Ecosystem/mindspore
!25181 add custom op pynative testcases
Merge pull request !25181 from looop5/custom_cases_commit
commit 0b6ecd5b4c
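For orientation before the diff: every new test case follows the same pattern — wrap `Custom` in a small `Cell`, run it once under GRAPH_MODE and once under PYNATIVE_MODE, and compare against NumPy. A condensed sketch of that pattern (class and helper names here are illustrative, not taken from the patch; the `Custom` arguments and tolerances mirror the test code added below):

```python
import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops.operations.custom_ops import Custom


class AkgNet(Cell):
    """Illustrative wrapper: runs an akg-DSL function as a Custom op."""

    def __init__(self, func, shape):
        super(AkgNet, self).__init__()
        self.program = Custom(func, out_shape=shape, out_dtype=mstype.float32, func_type="akg")

    def construct(self, x, y):
        return self.program(x, y)


def run_case(func, mode):
    """Run the same kernel in the requested execution mode and check it against NumPy."""
    context.set_context(mode=mode, device_target="Ascend")  # the tests also cover "GPU"
    x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    out = AkgNet(func, [4, 4])(Tensor(x), Tensor(y))
    assert np.allclose(np.matmul(x, y), out.asnumpy(), 0.001, 0.001)

# run_case(outer_product, context.GRAPH_MODE)     # existing graph-mode coverage
# run_case(outer_product, context.PYNATIVE_MODE)  # the new pynative coverage in this PR
```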
akg (submodule)
@@ -1 +1 @@
Subproject commit b0f0f9e3ef5e6aaa788802f35f1daca657b5642a
Subproject commit 5f5eeb31ffdf5a1dc973e7f904dc88ad7581bc5d
@@ -1018,5 +1018,29 @@ int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
  auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
  return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
}

void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes) {
  if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(indexes);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  auto attr_names = primitive->GetAttr(kAttrAttrNames);
  if (input_names == nullptr || attr_names == nullptr) {
    return;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
  if (input_names_vec.size() >= attr_names_vec.size()) {
    size_t offset = input_names_vec.size() - attr_names_vec.size();
    for (size_t i = offset; i < input_names_vec.size(); ++i) {
      if (input_names_vec[i] != attr_names_vec[i - offset]) {
        MS_LOG(EXCEPTION) << primitive->name() << " found mismatching attr name " << input_names_vec[i]
                          << "in input_names and " << attr_names_vec[i - offset] << " in attr_names";
      }
      indexes->insert(i);
    }
  }
}
} // namespace opt
} // namespace mindspore
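For readers cross-checking the C++ above: `GetCustomOpAttrIndex` treats the trailing entries of `input_names` as attr inputs, verifying that they match `attr_names` one-to-one after an offset. A rough Python rendering of the same index computation (an illustrative helper, not part of the codebase):

```python
def custom_op_attr_indexes(input_names, attr_names):
    """Return the positions in input_names that are actually attrs.

    Mirrors GetCustomOpAttrIndex above: attrs are registered as the trailing inputs,
    so the last len(attr_names) entries of input_names must match attr_names in order.
    """
    if len(input_names) < len(attr_names):
        return set()
    offset = len(input_names) - len(attr_names)
    indexes = set()
    for i in range(offset, len(input_names)):
        if input_names[i] != attr_names[i - offset]:
            raise ValueError("mismatching attr name {} vs {}".format(input_names[i], attr_names[i - offset]))
        indexes.add(i)
    return indexes


# e.g. input_names = ["x", "y", "bias"], attr_names = ["bias"] -> {2}
print(custom_op_attr_indexes(["x", "y", "bias"], ["bias"]))
```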
@@ -229,6 +229,9 @@ std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_gra

// Get total used number of node's output
int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);

// Get custom operator attr input indexes
void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
@@ -15,13 +15,10 @@
 */
#include "backend/optimizer/pass/custom_op_const_input_to_attr.h"

#include <string>
#include <memory>
#include <vector>
#include <unordered_set>

#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
@@ -41,20 +38,8 @@ const AnfNodePtr CustomOpConstInputToAttr::Process(const FuncGraphPtr &, const A
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  auto attr_names = primitive->GetAttr("attr_names");
  if (input_names == nullptr || attr_names == nullptr) {
    return nullptr;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
  std::unordered_set<size_t> attr_indices;
  if (input_names_vec.size() >= attr_names_vec.size()) {
    size_t offset = input_names_vec.size() - attr_names_vec.size();
    for (size_t i = offset; i < input_names_vec.size(); ++i) {
      attr_indices.insert(i);
    }
  }
  GetCustomOpAttrIndex(primitive, &attr_indices);
  if (attr_indices.empty()) {
    return nullptr;
  }
@@ -470,9 +470,9 @@ std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  if (primitive != nullptr) {
    if (primitive->name() == "Custom") {
      auto func_name = primitive->GetAttr("func_name");
      if (func_name) {
        return GetValue<std::string>(func_name);
      auto uniq_name = primitive->GetAttr("uniq_name");
      if (uniq_name) {
        return GetValue<std::string>(uniq_name);
      }
    }
    return primitive->name();
@@ -503,7 +503,18 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
  MS_EXCEPTION_IF_NULL(op_prim);
  // Checking whether attr conversion is needed.
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  bool reg_exist = false;
  if (op_run_info->op_name == prim::kPrimCustom->name()) {
    // Custom op needs to set reg dynamically
    std::unordered_set<size_t> attr_indexes;
    opt::GetCustomOpAttrIndex(op_prim, &attr_indexes);
    if (!attr_indexes.empty()) {
      reg_exist = true;
      reg.SetConstInputToAttr(attr_indexes);
    }
  } else {
    reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  }
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(DEBUG) << "current node is dynamic shape: " << op_run_info->op_name;
@@ -336,6 +336,7 @@ constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter";

// attr key name
constexpr auto kAttrInputNames = "input_names";
constexpr auto kAttrAttrNames = "attr_names";
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
@@ -91,7 +91,7 @@ class CustomRegOp(RegOp):
            Please note that target and the `func_type` of `Custom` op have some constraints.
            If func_type is "akg", target can be one of ["Ascend", "GPU"].
            If func_type is "tbe", target can only be "Ascend".
            If func_type is "aot", target can only be "GPU".
            If func_type is "aot", target can only be ["GPU", "CPU"].
            If func_type is "py_func", target can only be "CPU".
            Default: None.
    """
@@ -264,7 +264,7 @@ class Custom(ops.PrimitiveWithInfer):
        >>>
        >>> class Net(Cell):
        ...     def __init__(self):
        ...         super(Net1, self).__init__()
        ...         super(Net, self).__init__()
        ...         self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
        ...                                        func_type="tbe")
        ...
@@ -296,6 +296,7 @@ class Custom(ops.PrimitiveWithInfer):
    """

    registered_func = {}
    attr_dict = {}  # Save input_names and attr_names for func.

    def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
        ops.PrimitiveWithInfer.__init__(self, "Custom")
@@ -303,16 +304,23 @@ class Custom(ops.PrimitiveWithInfer):
        self.supported_targets = ["Ascend", "GPU", "CPU"]
        self.func = func
        self.func_name = ""
        self.uniq_name = ""
        self.imply_path = ""
        if callable(self.func):
            # Get the original function if func is decorated
            if "__wrapped__" in self.func.__dict__:
                self.func = self.func.__dict__["__wrapped__"]
            self.imply_path = os.path.realpath(inspect.getfile(self.func))
            self.func_name = self.func.__name__
            self.uniq_name = self.name + "_" + self.func_name + "_" + str(id(self.func))
        elif isinstance(self.func, str):
            self.func_name = self.func
            self.uniq_name = self.name + "_" + self.func_name
        else:
            raise TypeError("func should be of type function or str, but got {}".format(type(self.func)))
        self.add_prim_attr("func_name", self.func_name)
        self.add_prim_attr("uniq_name", self.uniq_name)
        self.add_prim_attr("imply_path", self.imply_path)
        self.out_shape = out_shape
        self.out_dtype = out_dtype
        self.bprop = bprop
@@ -333,6 +341,7 @@ class Custom(ops.PrimitiveWithInfer):
        else:
            self.func_type = "hybrid"
        self.add_prim_attr("func_type", self.func_type)
        self.update_attr()

    def infer_shape(self, *args):
        if callable(self.out_shape):
@@ -353,7 +362,6 @@ class Custom(ops.PrimitiveWithInfer):
        if reg_info is None and hasattr(self.func, "reg_info"):
            reg_info = getattr(self.func, "reg_info")
        reg_info_list = self.get_expanded_list(reg_info)
        already_add_attr = False
        for reg_info in reg_info_list:
            if not isinstance(reg_info, (str, dict)):
                continue
@@ -366,17 +374,10 @@ class Custom(ops.PrimitiveWithInfer):
            # Register
            reg_info = self.reformat_reg_info(reg_info, target)
            reg_info_str = json.dumps(reg_info)
            if isinstance(self.func, str):
                imply_path = self.func
            else:
                imply_path = os.path.realpath(inspect.getfile(self.func))
            op_lib = Oplib()
            if not op_lib.reg_op(reg_info_str, imply_path):
                raise ValueError('Invalid reg info {}: {}\n'.format(imply_path, reg_info_str))
            # Add inputs name to attr
            if not already_add_attr:
                self.add_inputs_name_to_attr(reg_info)
                already_add_attr = True
            if not op_lib.reg_op(reg_info_str, self.imply_path):
                raise ValueError('Invalid reg info {}: {}\n'.format(self.imply_path, reg_info_str))
            self.save_attr(reg_info)
            self.save_register_status(target)

    def get_expanded_list(self, data):
@@ -405,7 +406,7 @@ class Custom(ops.PrimitiveWithInfer):
            registered_targets = getattr(self.func, "registered_targets", [])
            registered_targets.append(target)
            setattr(self.func, "registered_targets", registered_targets)
        elif isinstance(self, str):
        elif isinstance(self.func, str):
            if isinstance(Custom.registered_func.get(self.func), list):
                Custom.registered_func[self.func].append(target)
            else:
@@ -415,7 +416,7 @@ class Custom(ops.PrimitiveWithInfer):
        """Reformat registration information."""
        if not isinstance(reg_info, dict):
            raise TypeError("reg_info should be of type dict, but got {}".format(type(reg_info)))
        reg_info["op_name"] = self.func_name
        reg_info["op_name"] = self.uniq_name
        reg_info["imply_type"] = self.get_imply_type(reg_info, target)
        # Supplement necessary info for TBE if these information is missing in reg_info
        if reg_info["imply_type"] == "TBE":
@@ -468,13 +469,19 @@ class Custom(ops.PrimitiveWithInfer):
        func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "py_func": target}
        return func_type_to_imply_type.get(self.func_type, "AKG")

    def add_inputs_name_to_attr(self, reg_info):
        """Save inputs name to primitive's attr."""
    def save_attr(self, reg_info):
        """Save input_names and attr_names of current func."""
        if not isinstance(reg_info, dict):
            return
        tensor_inputs = reg_info.get("inputs", [])
        attr = reg_info.get("attr", [])
        if not isinstance(tensor_inputs, (list, tuple)):
            tensor_inputs = [tensor_inputs]
        if not isinstance(attr, (list, tuple)):
            attr = [attr]
        # input_names include tensor input names and attr input names
        input_names = []
        # attr_names only includes attr input names
        attr_names = []
        for item in tensor_inputs:
            if isinstance(item, dict) and item.get("name") is not None:
@@ -483,6 +490,46 @@ class Custom(ops.PrimitiveWithInfer):
            if isinstance(item, dict) and item.get("name") is not None:
                input_names.append(item["name"])
                attr_names.append(item["name"])
        # input_names include tensor input names and attr names
        self.add_prim_attr("input_names", input_names)
        self.add_prim_attr("attr_names", attr_names)
        cur_attr = {"input_names": input_names, "attr_names": attr_names}
        # If func does not have attr, save current attr.
        # Else, check if current attr is same as previous saved one.
        prev_input_names = input_names
        prev_attr_names = attr_names
        if callable(self.func):
            func_attr = getattr(self.func, "func_attr", None)
            if not isinstance(func_attr, dict):
                setattr(self.func, "func_attr", cur_attr)
            else:
                prev_input_names = func_attr.get("input_names")
                prev_attr_names = func_attr.get("attr_names")
        elif isinstance(self.func, str):
            func_attr = Custom.attr_dict.get(self.func)
            if not isinstance(func_attr, dict):
                Custom.attr_dict[self.func] = cur_attr
            else:
                prev_input_names = func_attr.get("input_names")
                prev_attr_names = func_attr.get("attr_names")
        if not isinstance(prev_input_names, list):
            raise TypeError("func {}: previous saved input_names should be a list, but got {}"
                            .format(self.func, type(prev_input_names)))
        if len(input_names) != len(prev_input_names):
            raise ValueError("func {}: input_names's length {} is different from previous saved one {}"
                             .format(self.func, len(input_names), len(prev_input_names)))
        if attr_names != prev_attr_names:
            raise ValueError("func {}: attr_names {} is different from previous saved one {}"
                             .format(self.func, attr_names, prev_attr_names))

    def update_attr(self):
        """Add input_names and attr_names to primitive's attr."""
        func_attr = {}
        if callable(self.func):
            func_attr = getattr(self.func, "func_attr", None)
        elif isinstance(self.func, str):
            func_attr = Custom.attr_dict.get(self.func)
        if isinstance(func_attr, dict):
            input_names = func_attr.get("input_names")
            attr_names = func_attr.get("attr_names")
            if input_names:
                self.add_prim_attr("input_names", input_names)
            if attr_names:
                self.add_prim_attr("attr_names", attr_names)
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
@@ -53,6 +54,7 @@ def outer_product(a, b):

class TestHybrid(Cell):
    """Net definition"""

    def __init__(self):
        super(TestHybrid, self).__init__()
@@ -65,12 +67,7 @@ class TestHybrid(Cell):
        return self.program(x, y)


def test_hybrid():
    """
    Feature: ALL To ALL
    Description: hybrid test cases.
    Expectation: the result match with numpy result
    """
def hybrid_case():
    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
@@ -82,24 +79,58 @@ def test_hybrid():
        raise ValueError("Precision error, compare result: {}".format(compare_res))


def test_hybrid_ascend():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_hybrid_ascend_graph_mode():
    """
    Feature: ALL To ALL
    Description: hybrid ascend test cases.
    Feature: test case for Custom op with func_type="akg"
    Description: ascend test case, akg dsl using hybrid grammar in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_hybrid()
    hybrid_case()


def test_hybrid_gpu():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_hybrid_ascend_pynative_mode():
    """
    Feature: ALL To ALL
    Description: hybrid gpu test cases.
    Feature: test case for Custom op with func_type="akg"
    Description: ascend test case, akg dsl using hybrid grammar in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    hybrid_case()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hybrid_gpu_graph_mode():
    """
    Feature: test case for Custom op with func_type="akg"
    Description: gpu test case, akg dsl using hybrid grammar in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_hybrid()
    hybrid_case()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hybrid_gpu_pynative_mode():
    """
    Feature: test case for Custom op with func_type="akg"
    Description: gpu test case, akg dsl using hybrid grammar in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    hybrid_case()


v_add_ascend_info = CustomRegOp() \
@@ -135,6 +166,7 @@ def v_add(inputs, attrs):

class TestIRbuilder(Cell):
    """Net definition"""

    def __init__(self, shape):
        super(TestIRbuilder, self).__init__()
        self.program = Custom(v_add, out_shape=shape, out_dtype=mstype.float16, func_type="akg")
@@ -143,12 +175,7 @@ class TestIRbuilder(Cell):
        return self.program([x, y])


def test_irbuider():
    """
    Feature: ALL To ALL
    Description: irbuider test cases.
    Expectation: the result match with numpy result
    """
def irbuider_case():
    shape = (4, 5)
    input_x = np.random.normal(0, 1, shape).astype(np.float16)
    input_y = np.random.normal(0, 1, shape).astype(np.float16)
@@ -160,21 +187,55 @@ def test_irbuider():
        raise ValueError("Precision error, compare result: {}".format(compare_res))


def test_irbuider_ascend():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_irbuider_ascend_graph_mode():
    """
    Feature: ALL To ALL
    Description: irbuider ascend test cases.
    Feature: test case for Custom op with func_type="akg"
    Description: ascend test case, akg dsl using irbuider grammar in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_irbuider()
    irbuider_case()


def test_irbuider_gpu():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_irbuider_ascend_pynative_mode():
    """
    Feature: ALL To ALL
    Description: irbuider gpu test cases.
    Feature: test case for Custom op with func_type="akg"
    Description: ascend test case, akg dsl using irbuider grammar in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    irbuider_case()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_irbuider_gpu_graph_mode():
    """
    Feature: test case for Custom op with func_type="akg"
    Description: gpu test case, akg dsl using irbuider grammar in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_irbuider()
    irbuider_case()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_irbuider_gpu_pynative_mode():
    """
    Feature: test case for Custom op with func_type="akg"
    Description: gpu test case, akg dsl using irbuider grammar in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    irbuider_case()
@@ -21,6 +21,7 @@ from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.op_info_register import TBERegOp, DataType
from mindspore.ops.operations.custom_ops import Custom, CustomRegOp, custom_op_info_register
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like

square_with_bias_op_info = CustomRegOp() \
    .fusion_type("OPAQUE") \
@@ -158,17 +159,7 @@ class Net1(Cell):
        return res


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_input_multi_output_with_attr():
    """
    Feature: ALL To ALL
    Description: test cases for Custom op with multiple inputs, outputs and attr.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def multi_input_multi_output_with_attr():
    dtype = np.float32
    x = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]).astype(dtype)
    expect0 = np.array([[9.0, 9.0, 9.0], [441.0, 441.0, 441.0]]).astype(dtype)
@@ -196,10 +187,38 @@ def test_multi_input_multi_output_with_attr():
        raise ValueError("Precision error, compare result: {}".format(compare_res))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net1_graph_mode():
    """
    Feature: test case for Custom op with func_type="tbe"
    Description: test cases with multiple inputs, outputs and attr in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    multi_input_multi_output_with_attr()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net1_pynative_mode():
    """
    Feature: test case for Custom op with func_type="tbe"
    Description: test cases with multiple inputs, outputs and attr in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    multi_input_multi_output_with_attr()


def bprop(data, axis, out, dout):
    gradient = data * 2
    dx = gradient * dout
    return (dx,)
    return dx, zeros_like(axis)


class Net2(Cell):
@@ -215,17 +234,7 @@ class Net2(Cell):
        return res


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bprop():
    """
    Feature: ALL To ALL
    Description: test cases for bprop function of Custom op.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def grad_case():
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
    expect = np.array([2.0, 8.0, 18.0]).astype(np.float32)
@@ -237,3 +246,31 @@ def test_bprop():
    compare_res = np.allclose(expect, dx_np, 0.0001, 0.0001)
    if not compare_res:
        raise ValueError("Precision error, compare result: {}".format(compare_res))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net2_graph_mode():
    """
    Feature: test case for Custom op with func_type="tbe"
    Description: grad test case in GRAPH_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    grad_case()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net2_pynative_mode():
    """
    Feature: test case for Custom op with func_type="tbe"
    Description: grad test case in PYNATIVE_MODE.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    grad_case()
@@ -1,135 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import inspect
import numpy as np
import pytest
from mindspore import context, ops, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell


class UserDefined(ops.PrimitiveWithInfer):
    def __init__(self, func, shape, dtype, func_type=None):
        ops.PrimitiveWithInfer.__init__(self, "UserDefined")
        self.add_prim_attr('akg', True)

        if "__wrapped__" in func.__dict__:
            func = func.__dict__["__wrapped__"]
        func_name = func.__name__
        self.add_prim_attr('func_name', func_name)
        func_source_str = inspect.getsource(func)

        if func_type is None:
            if "ir_builder" in func_source_str:
                func_type = "ir_builder"
            elif "compute" in func_source_str:
                func_type = "tvm_compute"
            else:
                func_type = "hybrid"

        self.add_prim_attr('func_source_str', func_source_str)
        self.add_prim_attr('func_type', func_type)

        self._shape = shape
        self._dtype = dtype

    def infer_shape(self, *args):
        if callable(self._shape):
            return self._shape(*args)
        return self._shape

    def infer_dtype(self, *args):
        if callable(self._dtype):
            return self._dtype(*args)
        return self._dtype


def outer_product(a, b):
    c = output_tensor((a.shape[0], b.shape[1]), 'float32')

    for i0 in range(a.shape[0]):
        for i1 in range(b.shape[1]):
            c[i0, i1] = 0.0
            for i2 in range(a.shape[1]):
                c[i0, i1] = c[i0, i1] + (a[i0, i2] * b[i2, i1])
    return c


class TestHybrid(Cell):
    def __init__(self):
        super(TestHybrid, self).__init__()

        def infer_func(x, y):
            return x

        self.program = UserDefined(
            outer_product, shape=infer_func, dtype=infer_func)

    def construct(self, x, y):
        return self.program(x, y)


def v_add(inputs, attrs):
    def vadd_func(dst, data_1, data_2):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(data_1.shape, "i") as i:
            ib.store(dst, i, ib.load(data_1, i) + ib.load(data_2, i))
        return ib.get()
    data_1, data_2 = inputs[0], inputs[1]
    return tvm.extern(data_1.shape, [data_1, data_2],
                      lambda ins, outs: vadd_func(outs[0], ins[0], ins[1]),
                      name="v_add", dtype=data_1.dtype)


class TestIRbuilder(Cell):
    def __init__(self, shape):
        super(TestIRbuilder, self).__init__()
        self.program = UserDefined(
            v_add, shape=shape, dtype=mstype.float16)

    def construct(self, x, y):
        return self.program(x, y)


def test_user_defined_hybrid():

    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    test = TestHybrid()
    output = test(Tensor(input_x), Tensor(input_y))
    expect = np.matmul(input_x, input_y)
    assert np.allclose(expect, output.asnumpy(), 0.001, 0.001)


def test_user_defined_irbuider():

    shape = (4, 5)
    input_x = np.random.normal(0, 1, shape).astype(np.float16)
    input_y = np.random.normal(0, 1, shape).astype(np.float16)

    test = TestIRbuilder(shape)
    output = test(Tensor(input_x), Tensor(input_y))
    assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_defined_gpu():
    context.set_context(mode=0, enable_graph_kernel=True)
    test_user_defined_hybrid()
    test_user_defined_irbuider()