pass lambda function to out_shape and out_dtype
parent bf0142ae4b
commit ae5532fb9e
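The hunks below switch every `out_shape`/`out_dtype` argument of `ops.Custom` from hard-coded shapes and dtypes to lambdas that compute the outputs from the inputs. A minimal sketch of that usage, assembled from the "pyfunc" example updated further down (the three driver lines at the end are illustrative and not part of the commit):

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.nn import Cell


def func_multi_output(x1, x2):
    return (x1 + x2), (x1 - x2)


class PyFuncNet(Cell):
    def __init__(self):
        super(PyFuncNet, self).__init__()
        # Both outputs keep the shape and dtype of the first input, so lambdas
        # replace the previously hard-coded ((2, 3), (2, 3)) and (float32, float32).
        self.func = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")

    def construct(self, x1, x2):
        return self.func(x1, x2)


x1 = Tensor(np.ones((2, 3)).astype(np.float32))  # illustrative inputs
x2 = Tensor(np.ones((2, 3)).astype(np.float32))
out_sum, out_diff = PyFuncNet()(x1, x2)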
build.sh
@@ -46,7 +46,11 @@ update_submodule()
   git submodule update --init metadef
   cd "${BASEPATH}"
   if [[ "X$ENABLE_AKG" = "Xon" ]]; then
-    git submodule update --init --recursive akg
+    if [[ "X$ENABLE_D" == "Xon" ]]; then
+      git submodule update --init akg
+    else
+      GIT_LFS_SKIP_SMUDGE=1 git submodule update --init akg
+    fi
   fi
 }
@@ -1,18 +0,0 @@
-if(ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/incubator-tvm/repository/archive/v0.6.0.tar.gz")
-    set(MD5 "7b22965745cf1c6208a4e367fb86a585")
-else()
-    set(REQ_URL
-        "https://github.com/apache/incubator-tvm/release/download/v0.6.0/apache-tvm-src-v0.6.0-incubating.tar.gz")
-    set(MD5 "2d77a005f0046d937b99c67de82f6438")
-endif()
-set(incubator_tvm_predict_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
-set(incubator_tvm_predict_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
-mindspore_add_pkg(incubator_tvm_predict
-    VER 0.6.0
-    HEAD_ONLY ./
-    URL ${REQ_URL}
-    MD5 ${MD5}
-    PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch)
-include_directories(${incubator_tvm_predict_INC})
-add_library(mindspore::incubator_tvm_predict ALIAS incubator_tvm_predict)
@@ -215,7 +215,7 @@ class OpInfoExtractor {
       }
     }
     if (op_attr->type().empty()) {
-      MS_LOG(WARNING) << "Unknown type, ignore attr: " << name;
+      MS_LOG(DEBUG) << "Unknown type, ignore attr: " << name;
       continue;
     }
     op_info->add_attrs_ptr(op_attr);
@@ -97,10 +97,15 @@ void AddMissingAttrs(const CNodePtr &cnode, kernel::OpImplyType imply_type,
   bool need_update = false;
   for (const auto &attr : all_attrs) {
     auto attr_name = attr->name();
-    if (missing_attrs.find(attr_name) == missing_attrs.end() || attr->param_type() != "required") {
+    if (missing_attrs.find(attr_name) == missing_attrs.end()) {
       continue;
     }
+    // If attr's param_type is required, it should have default value.
+    // If attr have default value, we should parse it no matter whether its param_type is required or not.
     auto default_value = attr->default_value();
+    if (default_value.empty() && attr->param_type() != "required") {
+      continue;
+    }
     if (default_value.empty()) {
       MS_LOG(EXCEPTION) << "attr [" << attr_name << "] in the registration information of op [" << op_name
                         << "] does not have a value." << trace::DumpSourceLines(cnode);
@@ -129,8 +134,8 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
   auto func_type = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
-  // Only AKG needs to process attr, TBE will process later in the json creating phase.
-  if (kCustomTypeAkg.find(func_type) == kCustomTypeAkg.end()) {
+  // AKG/AICPU need to process attr, TBE will process later in the json creating phase.
+  if (kCustomTypeAkg.find(func_type) == kCustomTypeAkg.end() || func_type == kCustomTypeAICPU) {
     return nullptr;
   }
   // Early return if current node does not have attr
@@ -149,7 +154,9 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
   if (missing_attrs.empty()) {
     return nullptr;
   }
-  AddMissingAttrs(cnode, kernel::OpImplyType::kAKG, missing_attrs);
+  kernel::OpImplyType imply_type =
+    func_type == kCustomTypeAICPU ? kernel::OpImplyType::kAICPU : kernel::OpImplyType::kAKG;
+  AddMissingAttrs(cnode, imply_type, missing_attrs);

   return node;
 }
@@ -92,7 +92,8 @@ def op_info_register(op_info):

 def custom_info_register(*reg_info):
     r"""
-    A decorator which is used to bind the registration information to the `func` parameter of `Custom` op.
+    A decorator which is used to bind the registration information to the `func` parameter of
+    :class:`mindspore.ops.Custom`.

     Note:
         The 'reg_info' will be added into oplib.
@@ -645,7 +646,7 @@ class TBERegOp(RegOp):

 class CustomRegOp(RegOp):
     r"""
-    Class for `Custom` operator info register.
+    Class used for generating the registration information for the `func` parameter of :class:`mindspore.ops.Custom`.

     Args:
         op_name (str): kernel name. No need to set this value as `Custom` operator will generate a unique name
@@ -136,7 +136,7 @@ class Custom(ops.PrimitiveWithInfer):
             represents the registration information of `func` in a specific target. You need to invoke `CustomRegOp`
             or the subclass of `RegOp` to generate the reg info for `func`. Then you can invoke
             `custom_info_register` to bind the reg info to `func` or just pass the reg info to `reg_info` parameter.
-            The `reg_info` parameter takes higher priority then `custom_info_register` and the reg info in a
+            The `reg_info` parameter takes higher priority than `custom_info_register` and the reg info in a
             specific target will be registered only once.

             If reg info is not set, then we will infer the data types and formats from the inputs of `Custom` operator.
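For the priority rule in the docstring change above, a short sketch of the two ways to attach the registration information, using the same `CustomRegOp` chain and `reg_info` parameter that appear in the test hunks later in this diff (the TBE kernel body is elided here; its full definition is in the test file below):

from mindspore.ops import CustomRegOp, DataType, custom_info_register
import mindspore.ops as ops

square_with_bias_op_info = CustomRegOp() \
    .attr("bias", "required", "float") \
    .input(0, "x") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .target("Ascend") \
    .get_op_info()


# Option 1: bind the reg info to the func with the decorator.
@custom_info_register(square_with_bias_op_info)
def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
    ...  # TBE DSL body, see the full definition in the test hunks below


# Option 2: pass it through the reg_info parameter, which takes priority over the decorator.
op = ops.Custom(square_with_bias, lambda x, _: x, lambda x, _: x, func_type="tbe",
                reg_info=square_with_bias_op_info)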
@@ -161,7 +161,6 @@ class Custom(ops.PrimitiveWithInfer):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
-        >>> import mindspore as ms
         >>> import mindspore.ops as ops
         >>> from mindspore.ops import CustomRegOp, custom_info_register, DataType
         >>> from mindspore.common import dtype as mstype
@@ -169,7 +168,7 @@ class Custom(ops.PrimitiveWithInfer):
         >>>
         >>> # Example, func_type = "akg"
         >>> def outer_product(a, b):
-        ...     c = output_tensor((a.shape[0], b.shape[1]), 'float32')
+        ...     c = output_tensor(a.shape, a.dtype)
         ...     for i0 in range(a.shape[0]):
         ...         for i1 in range(b.shape[1]):
         ...             c[i0, i1] = 0.0
@@ -202,19 +201,18 @@ class Custom(ops.PrimitiveWithInfer):
         ... def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
         ...     import te.lang.cce
         ...     from te import tvm
-        ...     from topi import generic
         ...     from topi.cce import util
         ...
         ...     shape = input_x.get("shape")
         ...     dtype = input_x.get("dtype").lower()
         ...
         ...     shape = util.shape_refine(shape)
-        ...     data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+        ...     data = tvm.placeholder(shape, name="data", dtype=dtype)
         ...
         ...     with tvm.target.cce():
         ...         res0 = te.lang.cce.vmul(data, data)
         ...         res = te.lang.cce.vadds(res0, bias)
-        ...         sch = generic.auto_schedule(res)
+        ...         sch = te.lang.cce.auto_schedule(res)
         ...
         ...     config = {"print_ir": False,
         ...               "name": kernel_name,
@@ -225,8 +223,8 @@ class Custom(ops.PrimitiveWithInfer):
         >>> class TbeNet(Cell):
         ...     def __init__(self):
         ...         super(TbeNet, self).__init__()
-        ...         self.square_with_bias = ops.Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, \
-        ...                                            func_type="tbe")
+        ...         self.square_with_bias = ops.Custom(square_with_bias, out_shape=lambda x, _: x, \
+        ...                                            out_dtype=lambda x, _: x, func_type="tbe")
         ...     def construct(self, x):
         ...         res = self.square_with_bias(x, 1.0)
         ...         return res
@@ -258,9 +256,9 @@ class Custom(ops.PrimitiveWithInfer):
         >>>
         >>> # Example, func_type = "aot"
         >>> class AOTSingleOutputNet(Cell):
-        ...     def __init__(self, func, out_shapes, out_types, reg=None):
+        ...     def __init__(self, out_shapes, out_types):
         ...         super(AOTSingleOutputNet, self).__init__()
-        ...         self.program = ops.Custom("./reorganize.so:CustomReorganize", (2, 3), mstype.float32, "aot")
+        ...         self.program = ops.Custom("./reorganize.so:CustomReorganize", out_shapes, out_types, "aot")
         ...     def construct(self, x, y):
         ...         return self.program(x, y)
         >>>
@@ -269,9 +267,9 @@ class Custom(ops.PrimitiveWithInfer):
         ...     return (x1 + x2), (x1 - x2)
         >>>
         >>> class PyFuncNet(Cell):
-        ...     def __init__(self, fn, out_shapes, out_types):
-        ...         super().__init__()
-        ...         self.func = ops.Custom(func_multi_output, ((2, 3), (2, 3)), (ms.float32, ms.float32), "pyfunc")
+        ...     def __init__(self):
+        ...         super(PyFuncNet, self).__init__()
+        ...         self.func = ops.Custom(func_multi_output, lambda x, _: (x, x), lambda x, _: (x, x), "pyfunc")
         ...     def construct(self, x1, x2):
         ...         return self.func(x1, x2)
     """
@@ -15,16 +15,14 @@

 import pytest
 import numpy as np
-import mindspore as ms
 from mindspore import context, Tensor
-from mindspore.common import dtype as mstype
 from mindspore.nn import Cell
 import mindspore.ops as ops
 from mindspore.ops import DataType, CustomRegOp, custom_info_register


 def outer_product(a, b):
-    c = output_tensor((a.shape[0], b.shape[1]), 'float32')
+    c = output_tensor(a.shape, a.dtype)

     for i0 in range(a.shape[0]):
         for i1 in range(b.shape[1]):
@@ -35,8 +33,8 @@ def outer_product(a, b):


 def cube(a):
-    c = output_tensor((a.shape[0], a.shape[1]), 'float32')
-    b = allocate((a.shape[0], a.shape[1]), 'float32', 'local')
+    c = output_tensor(a.shape, a.dtype)
+    b = allocate(a.shape, a.dtype, 'local')

     for i0 in range(a.shape[0]):
         for i1 in range(a.shape[1]):
@@ -49,10 +47,10 @@ def cube(a):
 class TestHybridTwoInputs(Cell):
     """Net definition"""

-    def __init__(self, func, shapes, types):
+    def __init__(self, func, out_shape, out_dtype):
         super(TestHybridTwoInputs, self).__init__()

-        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
+        self.program = ops.Custom(func, out_shape=out_shape, out_dtype=out_dtype, func_type="akg")

     def construct(self, x, y):
         return self.program(x, y)
@@ -61,10 +59,10 @@ class TestHybridTwoInputs(Cell):
 class TestHybridOneInput(Cell):
     """Net definition"""

-    def __init__(self, func, shapes, types):
+    def __init__(self, func, out_shape, out_dtype):
         super(TestHybridOneInput, self).__init__()

-        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
+        self.program = ops.Custom(func, out_shape=out_shape, out_dtype=out_dtype, func_type="akg")

     def construct(self, x):
         return self.program(x)
@@ -96,7 +94,7 @@ def hybrid_outer_product():
     input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
+    test = TestHybridTwoInputs(outer_product, lambda x, _: x, lambda x, _: x)
     output = test(Tensor(input_x), Tensor(input_y))
     expect = np.matmul(input_x, input_y)
     compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
@@ -109,7 +107,7 @@ def hybrid_outer_product_autodiff():
     input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
+    test = TestHybridTwoInputs(outer_product, lambda x, _: x, lambda x, _: x)
     net = MatMulNN()
     dx, dy = ops.GradOperation(sens_param=True, get_all=True)(test)(Tensor(input_x), Tensor(input_y), Tensor(sens))
     edx, edy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(input_x), Tensor(input_y), Tensor(sens))
@@ -123,7 +121,7 @@ def hybrid_pow_autodiff():
     input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybridOneInput(cube, (4, 4), (ms.float32))
+    test = TestHybridOneInput(cube, lambda x: x, lambda x: x)
     net = PowNN()
     dx = ops.GradOperation(sens_param=True)(test)(Tensor(input_x), Tensor(sens))
     edx = ops.GradOperation(sens_param=True)(net)(Tensor(input_x), Tensor(sens))
@@ -222,9 +220,9 @@ def v_add(inputs, attrs):
 class TestIRbuilder(Cell):
     """Net definition"""

-    def __init__(self, shape):
+    def __init__(self):
         super(TestIRbuilder, self).__init__()
-        self.program = ops.Custom(v_add, out_shape=shape, out_dtype=mstype.float16, func_type="akg")
+        self.program = ops.Custom(v_add, out_shape=lambda x: x[0], out_dtype=lambda x: x[0], func_type="akg")

     def construct(self, x, y):
         return self.program([x, y])
@@ -235,7 +233,7 @@ def irbuilder_case():
     input_x = np.random.normal(0, 1, shape).astype(np.float16)
     input_y = np.random.normal(0, 1, shape).astype(np.float16)

-    test = TestIRbuilder(shape)
+    test = TestIRbuilder()
     output = test(Tensor(input_x), Tensor(input_y))
     compare_res = np.allclose(input_x + input_y, output.asnumpy(), 0.001, 0.001)
     if not compare_res:
@@ -16,39 +16,35 @@
 import pytest
 import numpy as np
 from mindspore import context, Tensor
 from mindspore.common import dtype as mstype
 from mindspore.nn import Cell
 import mindspore.ops as ops
 from mindspore.ops import TBERegOp, DataType, CustomRegOp, custom_info_register
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like

-square_with_bias_op_info = CustomRegOp() \
-    .attr("bias", "required", "float") \
-    .input(0, "x") \
-    .output(0, "y") \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .target("Ascend") \
-    .get_op_info()
-
-
-@custom_info_register(square_with_bias_op_info)
+@custom_info_register(CustomRegOp() \
+                      .attr("bias", "required", "float") \
+                      .input(0, "x") \
+                      .output(0, "y") \
+                      .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+                      .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+                      .target("Ascend") \
+                      .get_op_info())
 def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"):
     import te.lang.cce
     from te import tvm
-    from topi import generic
     from topi.cce import util

     shape = input_x.get("shape")
     dtype = input_x.get("dtype").lower()

     shape = util.shape_refine(shape)
-    data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+    data = tvm.placeholder(shape, name="data", dtype=dtype)

     with tvm.target.cce():
         res0 = te.lang.cce.vmul(data, data)
         res = te.lang.cce.vadds(res0, bias)
-        sch = generic.auto_schedule(res)
+        sch = te.lang.cce.auto_schedule(res)

     config = {"print_ir": False,
               "name": kernel_name,
@@ -57,18 +53,15 @@ def square_with_bias(input_x, output_y, bias=0.0, kernel_name="square_with_bias"
     te.lang.cce.cce_build_code(sch, config)


-square_with_bias_v2_op_info = CustomRegOp() \
-    .attr("bias", "required", "float") \
-    .input(0, "input_x") \
-    .output(0, "output1") \
-    .output(1, "output2") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .target("Ascend") \
-    .get_op_info()
-
-
-@custom_info_register(square_with_bias_v2_op_info)
+@custom_info_register(CustomRegOp() \
+                      .attr("bias", "required", "float") \
+                      .input(0, "input_x") \
+                      .output(0, "output1") \
+                      .output(1, "output2") \
+                      .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                      .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                      .target("Ascend") \
+                      .get_op_info())
 def square_with_bias_v2(input_x, output1, output2, bias=0.0, kernel_name="square_with_bias_v2"):
     import te.lang.cce
     from te import tvm
@@ -76,7 +69,7 @@ def square_with_bias_v2(input_x, output1, output2, bias=0.0, kernel_name="square
     shape = input_x.get("shape")
     dtype = input_x.get("dtype").lower()

-    data = tvm.placeholder(shape, name="data", dtype=dtype.lower())
+    data = tvm.placeholder(shape, name="data", dtype=dtype)

     res0 = te.lang.cce.vmul(data, data)
     res1 = te.lang.cce.vadds(res0, bias)
@@ -113,7 +106,7 @@ def add_n_with_bias(inputs, output, bias, kernel_name="add_n_with_bias"):
     for i, d in enumerate(inputs):
         shape = d.get("shape")
         dtype = d.get("dtype").lower()
-        data.append(tvm.placeholder(shape, name="input_" + str(i), dtype=dtype.lower()))
+        data.append(tvm.placeholder(shape, name="input_" + str(i), dtype=dtype))

     res = data[0]
     for i in range(1, len(data)):
@@ -138,14 +131,13 @@ class Net1(Cell):
     def __init__(self):
         super(Net1, self).__init__()
         # TBE dsl with attr
-        self.square_with_bias = ops.Custom(square_with_bias, out_shape=[2, 3], out_dtype=mstype.float32,
-                                           func_type="tbe")
+        self.square_with_bias = ops.Custom(square_with_bias, lambda x, _: x, lambda x, _: x, func_type="tbe")
         # TBE dsl with multiple inputs and attr
-        self.add_n_with_bias = ops.Custom(add_n_with_bias, out_shape=[2, 3], out_dtype=mstype.float32, func_type="tbe",
+        self.add_n_with_bias = ops.Custom(add_n_with_bias, lambda x, _: x[0], lambda x, _: x[0], func_type="tbe",
                                           reg_info=add_n_with_bias_op_info)
         # TBE dsl with multiple outputs and attr
-        self.square_with_bias_v2 = ops.Custom(square_with_bias_v2, out_shape=([2, 3], [2, 3]),
-                                              out_dtype=(mstype.float32, mstype.float32), func_type="tbe")
+        self.square_with_bias_v2 = ops.Custom(square_with_bias_v2, lambda x, _: (x, x), lambda x, _: (x, x),
+                                              func_type="tbe")
         self.neg = ops.Neg()

     def construct(self, x):
@@ -224,7 +216,7 @@ class Net2(Cell):

     def __init__(self, bprop_func):
         super(Net2, self).__init__()
-        self.square_with_bias = ops.Custom(square_with_bias, out_shape=[3], out_dtype=mstype.float32, bprop=bprop_func,
+        self.square_with_bias = ops.Custom(square_with_bias, lambda x, _: x, lambda x, _: x, bprop=bprop_func,
                                            func_type="tbe")

     def construct(self, x):
@@ -276,28 +268,25 @@ def test_net2_pynative_mode():
     grad_case(bprop)


-square_with_bias_grad_info = CustomRegOp() \
-    .input(0, "x") \
-    .input(1, "dout") \
-    .output(0, "y") \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .target("Ascend") \
-    .get_op_info()
-
-
-@custom_info_register(square_with_bias_grad_info)
+@custom_info_register(CustomRegOp() \
+                      .input(0, "x") \
+                      .input(1, "dout") \
+                      .output(0, "y") \
+                      .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+                      .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+                      .target("Ascend") \
+                      .get_op_info())
 def square_with_bias_grad(input_x, dout, output_y, kernel_name="square_with_bias_grad"):
     import te.lang.cce
     from te import tvm

     shape1 = input_x.get("shape")
     dtype1 = input_x.get("dtype").lower()
-    data1 = tvm.placeholder(shape1, name="data1", dtype=dtype1.lower())
+    data1 = tvm.placeholder(shape1, name="data1", dtype=dtype1)

     shape2 = dout.get("shape")
     dtype2 = dout.get("dtype").lower()
-    data2 = tvm.placeholder(shape2, name="data2", dtype=dtype2.lower())
+    data2 = tvm.placeholder(shape2, name="data2", dtype=dtype2)

     res0 = te.lang.cce.vmuls(data1, 2.0)
     res = te.lang.cce.vmul(res0, data2)
@@ -312,7 +301,7 @@ def square_with_bias_grad(input_x, dout, output_y, kernel_name="square_with_bias


 def bprop1():
-    op = ops.Custom(square_with_bias_grad, [3], mstype.float32, func_type="tbe")
+    op = ops.Custom(square_with_bias_grad, lambda x, _: x, lambda x, _: x, func_type="tbe")

     def custom_bprop(data, axis, out, dout):
         dx = op(data, dout)