add custom op pynative testcases

This commit is contained in:
looop5 2021-10-25 16:36:54 +08:00
parent 14729d5199
commit 0b34bf33bc
11 changed files with 261 additions and 227 deletions

2
akg

@ -1 +1 @@
Subproject commit b0f0f9e3ef5e6aaa788802f35f1daca657b5642a
Subproject commit 5f5eeb31ffdf5a1dc973e7f904dc88ad7581bc5d

View File

@ -1018,5 +1018,29 @@ int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, cons
auto output_used_num = GetNodeOutputUsedNum(kernel_graph, node);
return std::accumulate(output_used_num.begin(), output_used_num.end(), int64_t(0));
}
// Collects the input indexes of a Custom primitive that correspond to attr inputs.
//
// The primitive's kAttrInputNames attr lists every input name and kAttrAttrNames lists
// only the attr names. The attr names are expected to form the tail of the input names,
// so the attr indexes are the last attr_names.size() positions of input_names.
//
// primitive: primitive to inspect; nothing is done if it is null or is not the Custom primitive.
// indexes:   output set (must not be null) that receives the attr input indexes. Existing
//            content is kept; it is left untouched when either attr is missing or when there
//            are fewer input names than attr names.
//
// Reports through MS_LOG(EXCEPTION) when a trailing input name differs from the attr name
// at the matching position.
void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes) {
  if (primitive == nullptr || primitive->name() != prim::kPrimCustom->name()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(indexes);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  auto attr_names = primitive->GetAttr(kAttrAttrNames);
  if (input_names == nullptr || attr_names == nullptr) {
    return;
  }
  const auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  const auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
  // With fewer inputs than attrs the tail alignment below is impossible; report no attr inputs.
  if (input_names_vec.size() < attr_names_vec.size()) {
    return;
  }
  // Attr inputs occupy the tail of input_names, starting at this offset.
  const size_t offset = input_names_vec.size() - attr_names_vec.size();
  for (size_t i = offset; i < input_names_vec.size(); ++i) {
    const auto &input_name = input_names_vec[i];
    const auto &attr_name = attr_names_vec[i - offset];
    if (input_name != attr_name) {
      MS_LOG(EXCEPTION) << primitive->name() << " found mismatching attr name " << input_name
                        << " in input_names and " << attr_name << " in attr_names";
    }
    indexes->insert(i);
  }
}
} // namespace opt
} // namespace mindspore

View File

@ -229,6 +229,9 @@ std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_gra
// Get total used number of node's output
int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
// Get custom operator attr input indexes
void GetCustomOpAttrIndex(const PrimitivePtr &primitive, std::unordered_set<size_t> *indexes);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_

View File

@ -15,13 +15,10 @@
*/
#include "backend/optimizer/pass/custom_op_const_input_to_attr.h"
#include <string>
#include <memory>
#include <vector>
#include <unordered_set>
#include "backend/optimizer/common/helper.h"
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
@ -41,20 +38,8 @@ const AnfNodePtr CustomOpConstInputToAttr::Process(const FuncGraphPtr &, const A
}
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
auto input_names = primitive->GetAttr(kAttrInputNames);
auto attr_names = primitive->GetAttr("attr_names");
if (input_names == nullptr || attr_names == nullptr) {
return nullptr;
}
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
std::unordered_set<size_t> attr_indices;
if (input_names_vec.size() >= attr_names_vec.size()) {
size_t offset = input_names_vec.size() - attr_names_vec.size();
for (size_t i = offset; i < input_names_vec.size(); ++i) {
attr_indices.insert(i);
}
}
GetCustomOpAttrIndex(primitive, &attr_indices);
if (attr_indices.empty()) {
return nullptr;
}

View File

@ -470,9 +470,9 @@ std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
auto primitive = AnfAlgo::GetCNodePrimitive(node);
if (primitive != nullptr) {
if (primitive->name() == "Custom") {
auto func_name = primitive->GetAttr("func_name");
if (func_name) {
return GetValue<std::string>(func_name);
auto uniq_name = primitive->GetAttr("uniq_name");
if (uniq_name) {
return GetValue<std::string>(uniq_name);
}
}
return primitive->name();

View File

@ -503,7 +503,18 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
MS_EXCEPTION_IF_NULL(op_prim);
// Checking whether attr conversion is needed.
opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
bool reg_exist = false;
if (op_run_info->op_name == prim::kPrimCustom->name()) {
// Custom op needs to set reg dynamically
std::unordered_set<size_t> attr_indexes;
opt::GetCustomOpAttrIndex(op_prim, &attr_indexes);
if (!attr_indexes.empty()) {
reg_exist = true;
reg.SetConstInputToAttr(attr_indexes);
}
} else {
reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
}
if (op_run_info->is_dynamic_shape &&
dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
MS_LOG(DEBUG) << "current node is dynamic shape: " << op_run_info->op_name;

View File

@ -336,6 +336,7 @@ constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter";
// attr key name
constexpr auto kAttrInputNames = "input_names";
constexpr auto kAttrAttrNames = "attr_names";
constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";

View File

@ -91,7 +91,7 @@ class CustomRegOp(RegOp):
Please note that target and the `func_type` of `Custom` op have some constraints.
If func_type is "akg", target can be one of ["Ascend", "GPU"].
If func_type is "tbe", target can only be "Ascend".
If func_type is "aot", target can only be "GPU".
If func_type is "aot", target can be one of ["GPU", "CPU"].
If func_type is "py_func", target can only be "CPU".
Default: None.
"""
@ -264,7 +264,7 @@ class Custom(ops.PrimitiveWithInfer):
>>>
>>> class Net(Cell):
... def __init__(self):
... super(Net1, self).__init__()
... super(Net, self).__init__()
... self.square_with_bias = Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
... func_type="tbe")
...
@ -296,6 +296,7 @@ class Custom(ops.PrimitiveWithInfer):
"""
registered_func = {}
attr_dict = {} # Save input_names and attr_names for func.
def __init__(self, func, out_shape, out_dtype, func_type, bprop=None, reg_info=None):
ops.PrimitiveWithInfer.__init__(self, "Custom")
@ -303,16 +304,23 @@ class Custom(ops.PrimitiveWithInfer):
self.supported_targets = ["Ascend", "GPU", "CPU"]
self.func = func
self.func_name = ""
self.uniq_name = ""
self.imply_path = ""
if callable(self.func):
# Get the original function if func is decorated
if "__wrapped__" in self.func.__dict__:
self.func = self.func.__dict__["__wrapped__"]
self.imply_path = os.path.realpath(inspect.getfile(self.func))
self.func_name = self.func.__name__
self.uniq_name = self.name + "_" + self.func_name + "_" + str(id(self.func))
elif isinstance(self.func, str):
self.func_name = self.func
self.uniq_name = self.name + "_" + self.func_name
else:
raise TypeError("func should be of type function or str, but got {}".format(type(self.func)))
self.add_prim_attr("func_name", self.func_name)
self.add_prim_attr("uniq_name", self.uniq_name)
self.add_prim_attr("imply_path", self.imply_path)
self.out_shape = out_shape
self.out_dtype = out_dtype
self.bprop = bprop
@ -333,6 +341,7 @@ class Custom(ops.PrimitiveWithInfer):
else:
self.func_type = "hybrid"
self.add_prim_attr("func_type", self.func_type)
self.update_attr()
def infer_shape(self, *args):
if callable(self.out_shape):
@ -353,7 +362,6 @@ class Custom(ops.PrimitiveWithInfer):
if reg_info is None and hasattr(self.func, "reg_info"):
reg_info = getattr(self.func, "reg_info")
reg_info_list = self.get_expanded_list(reg_info)
already_add_attr = False
for reg_info in reg_info_list:
if not isinstance(reg_info, (str, dict)):
continue
@ -366,17 +374,10 @@ class Custom(ops.PrimitiveWithInfer):
# Register
reg_info = self.reformat_reg_info(reg_info, target)
reg_info_str = json.dumps(reg_info)
if isinstance(self.func, str):
imply_path = self.func
else:
imply_path = os.path.realpath(inspect.getfile(self.func))
op_lib = Oplib()
if not op_lib.reg_op(reg_info_str, imply_path):
raise ValueError('Invalid reg info {}: {}\n'.format(imply_path, reg_info_str))
# Add inputs name to attr
if not already_add_attr:
self.add_inputs_name_to_attr(reg_info)
already_add_attr = True
if not op_lib.reg_op(reg_info_str, self.imply_path):
raise ValueError('Invalid reg info {}: {}\n'.format(self.imply_path, reg_info_str))
self.save_attr(reg_info)
self.save_register_status(target)
def get_expanded_list(self, data):
@ -405,7 +406,7 @@ class Custom(ops.PrimitiveWithInfer):
registered_targets = getattr(self.func, "registered_targets", [])
registered_targets.append(target)
setattr(self.func, "registered_targets", registered_targets)
elif isinstance(self, str):
elif isinstance(self.func, str):
if isinstance(Custom.registered_func.get(self.func), list):
Custom.registered_func[self.func].append(target)
else:
@ -415,7 +416,7 @@ class Custom(ops.PrimitiveWithInfer):
"""Reformat registration information."""
if not isinstance(reg_info, dict):
raise TypeError("reg_info should be of type dict, but got {}".format(type(reg_info)))
reg_info["op_name"] = self.func_name
reg_info["op_name"] = self.uniq_name
reg_info["imply_type"] = self.get_imply_type(reg_info, target)
# Supplement necessary info for TBE if this information is missing in reg_info
if reg_info["imply_type"] == "TBE":
@ -468,13 +469,19 @@ class Custom(ops.PrimitiveWithInfer):
func_type_to_imply_type = {"akg": "AKG", "tbe": "TBE", "aot": target, "py_func": target}
return func_type_to_imply_type.get(self.func_type, "AKG")
def add_inputs_name_to_attr(self, reg_info):
"""Save inputs name to primitive's attr."""
def save_attr(self, reg_info):
"""Save input_names and attr_names of current func."""
if not isinstance(reg_info, dict):
return
tensor_inputs = reg_info.get("inputs", [])
attr = reg_info.get("attr", [])
if not isinstance(tensor_inputs, (list, tuple)):
tensor_inputs = [tensor_inputs]
if not isinstance(attr, (list, tuple)):
attr = [attr]
# input_names include tensor input names and attr input names
input_names = []
# attr_names only includes attr input names
attr_names = []
for item in tensor_inputs:
if isinstance(item, dict) and item.get("name") is not None:
@ -483,6 +490,46 @@ class Custom(ops.PrimitiveWithInfer):
if isinstance(item, dict) and item.get("name") is not None:
input_names.append(item["name"])
attr_names.append(item["name"])
# input_names include tensor input names and attr names
cur_attr = {"input_names": input_names, "attr_names": attr_names}
# If func does not have attr, save current attr.
# Else, check if current attr is same as previous saved one.
prev_input_names = input_names
prev_attr_names = attr_names
if callable(self.func):
func_attr = getattr(self.func, "func_attr", None)
if not isinstance(func_attr, dict):
setattr(self.func, "func_attr", cur_attr)
else:
prev_input_names = func_attr.get("input_names")
prev_attr_names = func_attr.get("attr_names")
elif isinstance(self.func, str):
func_attr = Custom.attr_dict.get(self.func)
if not isinstance(func_attr, dict):
Custom.attr_dict[self.func] = cur_attr
else:
prev_input_names = func_attr.get("input_names")
prev_attr_names = func_attr.get("attr_names")
if not isinstance(prev_input_names, list):
raise TypeError("func {}: previous saved input_names should be a list, but got {}"
.format(self.func, type(prev_input_names)))
if len(input_names) != len(prev_input_names):
raise ValueError("func {}: input_names's length {} is different from previous saved one {}"
.format(self.func, len(input_names), len(prev_input_names)))
if attr_names != prev_attr_names:
raise ValueError("func {}: attr_names {} is different from previous saved one {}"
.format(self.func, attr_names, prev_attr_names))
# Copies the saved "input_names" / "attr_names" lists for self.func onto this primitive
# as prim attrs of the same names. Does nothing when no saved dict is found.
def update_attr(self):
"""Add input_names and attr_names to primitive's attr."""
func_attr = {}
# For a callable func, the saved dict is read from the function object's "func_attr" attribute.
if callable(self.func):
func_attr = getattr(self.func, "func_attr", None)
# For a func given as a string, the saved dict is looked up in the class-level Custom.attr_dict.
elif isinstance(self.func, str):
func_attr = Custom.attr_dict.get(self.func)
# Anything other than a dict (e.g. None when nothing was saved) is ignored.
if isinstance(func_attr, dict):
input_names = func_attr.get("input_names")
attr_names = func_attr.get("attr_names")
# Falsy values (None or an empty list) are not written to the primitive.
if input_names:
self.add_prim_attr("input_names", input_names)
if attr_names:
self.add_prim_attr("attr_names", attr_names)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
@ -53,6 +54,7 @@ def outer_product(a, b):
class TestHybrid(Cell):
"""Net definition"""
def __init__(self):
super(TestHybrid, self).__init__()
@ -65,12 +67,7 @@ class TestHybrid(Cell):
return self.program(x, y)
def test_hybrid():
"""
Feature: ALL To ALL
Description: hybrid test cases.
Expectation: the result match with numpy result
"""
def hybrid_case():
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
@ -82,24 +79,58 @@ def test_hybrid():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def test_hybrid_ascend():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_hybrid_ascend_graph_mode():
"""
Feature: ALL To ALL
Description: hybrid ascend test cases.
Feature: test case for Custom op with func_type="akg"
Description: ascend test case, akg dsl using hybrid grammar in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_hybrid()
hybrid_case()
def test_hybrid_gpu():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_hybrid_ascend_pynative_mode():
"""
Feature: ALL To ALL
Description: hybrid gpu test cases.
Feature: test case for Custom op with func_type="akg"
Description: ascend test case, akg dsl using hybrid grammar in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
hybrid_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hybrid_gpu_graph_mode():
"""
Feature: test case for Custom op with func_type="akg"
Description: gpu test case, akg dsl using hybrid grammar in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_hybrid()
hybrid_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hybrid_gpu_pynative_mode():
"""
Feature: test case for Custom op with func_type="akg"
Description: gpu test case, akg dsl using hybrid grammar in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
hybrid_case()
v_add_ascend_info = CustomRegOp() \
@ -135,6 +166,7 @@ def v_add(inputs, attrs):
class TestIRbuilder(Cell):
"""Net definition"""
def __init__(self, shape):
super(TestIRbuilder, self).__init__()
self.program = Custom(v_add, out_shape=shape, out_dtype=mstype.float16, func_type="akg")
@ -143,12 +175,7 @@ class TestIRbuilder(Cell):
return self.program([x, y])
def test_irbuider():
"""
Feature: ALL To ALL
Description: irbuider test cases.
Expectation: the result match with numpy result
"""
def irbuider_case():
shape = (4, 5)
input_x = np.random.normal(0, 1, shape).astype(np.float16)
input_y = np.random.normal(0, 1, shape).astype(np.float16)
@ -160,21 +187,55 @@ def test_irbuider():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def test_irbuider_ascend():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_irbuider_ascend_graph_mode():
"""
Feature: ALL To ALL
Description: irbuider ascend test cases.
Feature: test case for Custom op with func_type="akg"
Description: ascend test case, akg dsl using irbuider grammar in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_irbuider()
irbuider_case()
def test_irbuider_gpu():
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_irbuider_ascend_pynative_mode():
"""
Feature: ALL To ALL
Description: irbuider gpu test cases.
Feature: test case for Custom op with func_type="akg"
Description: ascend test case, akg dsl using irbuider grammar in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
irbuider_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_irbuider_gpu_graph_mode():
"""
Feature: test case for Custom op with func_type="akg"
Description: gpu test case, akg dsl using irbuider grammar in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_irbuider()
irbuider_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_irbuider_gpu_pynative_mode():
"""
Feature: test case for Custom op with func_type="akg"
Description: gpu test case, akg dsl using irbuider grammar in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
irbuider_case()

View File

@ -21,6 +21,7 @@ from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.op_info_register import TBERegOp, DataType
from mindspore.ops.operations.custom_ops import Custom, CustomRegOp, custom_op_info_register
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
square_with_bias_op_info = CustomRegOp() \
.fusion_type("OPAQUE") \
@ -158,17 +159,7 @@ class Net1(Cell):
return res
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_input_multi_output_with_attr():
"""
Feature: ALL To ALL
Description: test cases for Custom op with multiple inputs, outputs and attr.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def multi_input_multi_output_with_attr():
dtype = np.float32
x = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]).astype(dtype)
expect0 = np.array([[9.0, 9.0, 9.0], [441.0, 441.0, 441.0]]).astype(dtype)
@ -196,10 +187,38 @@ def test_multi_input_multi_output_with_attr():
raise ValueError("Precision error, compare result: {}".format(compare_res))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net1_graph_mode():
"""
Feature: test case for Custom op with func_type="tbe"
Description: test cases with multiple inputs, outputs and attr in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
multi_input_multi_output_with_attr()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net1_pynative_mode():
"""
Feature: test case for Custom op with func_type="tbe"
Description: test cases with multiple inputs, outputs and attr in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
multi_input_multi_output_with_attr()
def bprop(data, axis, out, dout):
gradient = data * 2
dx = gradient * dout
return (dx,)
return dx, zeros_like(axis)
class Net2(Cell):
@ -215,17 +234,7 @@ class Net2(Cell):
return res
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bprop():
"""
Feature: ALL To ALL
Description: test cases for bprop function of Custom op.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def grad_case():
x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
expect = np.array([2.0, 8.0, 18.0]).astype(np.float32)
@ -237,3 +246,31 @@ def test_bprop():
compare_res = np.allclose(expect, dx_np, 0.0001, 0.0001)
if not compare_res:
raise ValueError("Precision error, compare result: {}".format(compare_res))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net2_graph_mode():
"""
Feature: test case for Custom op with func_type="tbe"
Description: grad test case in GRAPH_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad_case()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net2_pynative_mode():
"""
Feature: test case for Custom op with func_type="tbe"
Description: grad test case in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_case()

View File

@ -1,135 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import inspect
import numpy as np
import pytest
from mindspore import context, ops, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
# Test-only primitive named "UserDefined" that wraps a Python function and exposes its
# source text and a function-type tag as primitive attrs.
class UserDefined(ops.PrimitiveWithInfer):
# func:      function to wrap; if it carries "__wrapped__" in its __dict__, the wrapped
#            original is used instead.
# shape:     output shape, or a callable that receives the infer_shape arguments.
# dtype:     output dtype, or a callable that receives the infer_dtype arguments.
# func_type: optional type tag; when None it is guessed from the function's source text.
def __init__(self, func, shape, dtype, func_type=None):
ops.PrimitiveWithInfer.__init__(self, "UserDefined")
self.add_prim_attr('akg', True)
# Get the original function if func is decorated.
if "__wrapped__" in func.__dict__:
func = func.__dict__["__wrapped__"]
func_name = func.__name__
self.add_prim_attr('func_name', func_name)
func_source_str = inspect.getsource(func)
# Guess the type by substring search in the source: "ir_builder" is checked first,
# then "compute"; anything else is tagged "hybrid".
if func_type is None:
if "ir_builder" in func_source_str:
func_type = "ir_builder"
elif "compute" in func_source_str:
func_type = "tvm_compute"
else:
func_type = "hybrid"
self.add_prim_attr('func_source_str', func_source_str)
self.add_prim_attr('func_type', func_type)
self._shape = shape
self._dtype = dtype
# Returns the stored shape, or the result of calling it with the given arguments.
def infer_shape(self, *args):
if callable(self._shape):
return self._shape(*args)
return self._shape
# Returns the stored dtype, or the result of calling it with the given arguments.
def infer_dtype(self, *args):
if callable(self._dtype):
return self._dtype(*args)
return self._dtype
def outer_product(a, b):
c = output_tensor((a.shape[0], b.shape[1]), 'float32')
for i0 in range(a.shape[0]):
for i1 in range(b.shape[1]):
c[i0, i1] = 0.0
for i2 in range(a.shape[1]):
c[i0, i1] = c[i0, i1] + (a[i0, i2] * b[i2, i1])
return c
class TestHybrid(Cell):
def __init__(self):
super(TestHybrid, self).__init__()
def infer_func(x, y):
return x
self.program = UserDefined(
outer_product, shape=infer_func, dtype=infer_func)
def construct(self, x, y):
return self.program(x, y)
def v_add(inputs, attrs):
def vadd_func(dst, data_1, data_2):
ib = tvm.ir_builder.create()
with ib.for_range_n(data_1.shape, "i") as i:
ib.store(dst, i, ib.load(data_1, i) + ib.load(data_2, i))
return ib.get()
data_1, data_2 = inputs[0], inputs[1]
return tvm.extern(data_1.shape, [data_1, data_2],
lambda ins, outs: vadd_func(outs[0], ins[0], ins[1]),
name="v_add", dtype=data_1.dtype)
class TestIRbuilder(Cell):
def __init__(self, shape):
super(TestIRbuilder, self).__init__()
self.program = UserDefined(
v_add, shape=shape, dtype=mstype.float16)
def construct(self, x, y):
return self.program(x, y)
def test_user_defined_hybrid():
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
test = TestHybrid()
output = test(Tensor(input_x), Tensor(input_y))
expect = np.matmul(input_x, input_y)
assert np.allclose(expect, output.asnumpy(), 0.001, 0.001)
def test_user_defined_irbuider():
shape = (4, 5)
input_x = np.random.normal(0, 1, shape).astype(np.float16)
input_y = np.random.normal(0, 1, shape).astype(np.float16)
test = TestIRbuilder(shape)
output = test(Tensor(input_x), Tensor(input_y))
assert np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_defined_gpu():
context.set_context(mode=0, enable_graph_kernel=True)
test_user_defined_hybrid()
test_user_defined_irbuider()