forked from mindspore-Ecosystem/mindspore
apply autodiff in custom
This commit is contained in:
parent
95f7c61d2e
commit
5acce31c72
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit b3f67ab15c03ab7b201152d41cdfad7c05294609
|
||||
Subproject commit 3a7415a9a4184e3d8469b1d19109c03a32843871
|
|
@ -26,6 +26,7 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "runtime/device/cpu/cpu_common.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -40,7 +41,12 @@ CustomAOTCpuKernel::~CustomAOTCpuKernel() {
|
|||
void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
|
||||
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
|
||||
file_path_ = exec_info.substr(0, pos);
|
||||
auto path = exec_info.substr(0, pos);
|
||||
auto real_path = FileUtils::GetRealPath(path.c_str());
|
||||
if (!real_path.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
|
||||
}
|
||||
file_path_ = real_path.value();
|
||||
func_name_ = exec_info.substr(pos + 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Wrong execute info:" << exec_info;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -105,7 +106,13 @@ class CustomAOTGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
|
||||
if (auto pos = exec_info.find(":"); pos != std::string::npos) {
|
||||
file_path_ = exec_info.substr(0, pos);
|
||||
auto path = exec_info.substr(0, pos);
|
||||
auto real_path = FileUtils::GetRealPath(path.c_str());
|
||||
if (!real_path.has_value()) {
|
||||
MS_LOG(ERROR) << "Invalid file path, " << path << " does not exist.";
|
||||
return false;
|
||||
}
|
||||
file_path_ = real_path.value();
|
||||
func_name_ = exec_info.substr(pos + 1);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Wrong execute info:" << exec_info;
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Register bprop function for Custom Hybrid Autodiff"""
|
||||
|
||||
from collections import UserDict
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
class Registry(UserDict):
|
||||
"""
|
||||
Registry class inherits from UserDict.
|
||||
Key: length of signatures
|
||||
Value : bprop function for custom hybrid op
|
||||
"""
|
||||
|
||||
def register(self, sig_number):
|
||||
def deco(fn):
|
||||
self[sig_number] = fn
|
||||
return fn
|
||||
return deco
|
||||
|
||||
def get(self, sig_number):
|
||||
if sig_number not in self:
|
||||
logger.error(f"Autodiff currently doesn't support hyrbrid function with input num :{sig_number}. \
|
||||
Supported input num is from 1 to 10")
|
||||
fn = self[sig_number]
|
||||
return fn
|
||||
|
||||
|
||||
bprop_factory = Registry()
|
||||
|
||||
|
||||
def autodiff_bprop(n):
|
||||
return bprop_factory.get(n)
|
||||
|
||||
|
||||
def get_outs(out, dout):
|
||||
if isinstance(out, tuple):
|
||||
tupleout = out
|
||||
else:
|
||||
tupleout = (out,)
|
||||
if isinstance(dout, tuple):
|
||||
tupledout = dout
|
||||
else:
|
||||
tupledout = (dout,)
|
||||
return tupleout + tupledout
|
||||
|
||||
|
||||
@bprop_factory.register(1)
|
||||
def bprop_one(op):
|
||||
def bprop(x1, out, dout):
|
||||
inputs = (x1,) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(2)
|
||||
def bprop_two(op):
|
||||
def bprop(x1, x2, out, dout):
|
||||
inputs = (x1, x2) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(3)
|
||||
def bprop_three(op):
|
||||
def bprop(x1, x2, x3, out, dout):
|
||||
inputs = (x1, x2, x3) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(4)
|
||||
def bprop_four(op):
|
||||
def bprop(x1, x2, x3, x4, out, dout):
|
||||
inputs = (x1, x2, x3, x4) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(5)
|
||||
def bprop_five(op):
|
||||
def bprop(x1, x2, x3, x4, x5, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(6)
|
||||
def bprop_six(op):
|
||||
def bprop(x1, x2, x3, x4, x5, x6, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5, x6) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(7)
|
||||
def bprop_seven(op):
|
||||
def bprop(x1, x2, x3, x4, x5, x6, x7, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5, x6, x7) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(8)
|
||||
def bprop_eight(op):
|
||||
def bprop(x1, x2, x3, x4, x5, x6, x7, x8, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5, x6, x7, x8) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(9)
|
||||
def bprop_nine(op):
|
||||
def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_factory.register(10)
|
||||
def bprop_ten(op):
|
||||
def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, out, dout):
|
||||
inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) + get_outs(out, dout)
|
||||
res = op(*inputs)
|
||||
return res
|
||||
return bprop
|
|
@ -24,6 +24,7 @@ from mindspore.ops import DataType
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore._c_expression import Oplib
|
||||
from ._pyfunc_registry import add_pyfunc
|
||||
from ._custom_grad import autodiff_bprop
|
||||
|
||||
|
||||
class Custom(ops.PrimitiveWithInfer):
|
||||
|
@ -297,6 +298,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.func_type = "tvm_compute"
|
||||
else:
|
||||
self.func_type = "hybrid"
|
||||
if not self.bprop:
|
||||
self._hybrid_autodiff()
|
||||
self.add_prim_attr("func_type", self.func_type)
|
||||
self._update_attr()
|
||||
|
||||
|
@ -571,3 +574,25 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.add_prim_attr("primitive_target", "GPU")
|
||||
elif registered_targets == ["CPU"]:
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
if callable(self.func) and callable(self.out_shape):
|
||||
if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
|
||||
self.add_prim_attr("autodiff", True)
|
||||
else:
|
||||
self.add_prim_attr("autodiff", False)
|
||||
|
||||
def _hybrid_autodiff(self):
|
||||
"""generate backward op for a custom hybrid op"""
|
||||
inputs_num = len(inspect.signature(self.func).parameters)
|
||||
if inputs_num == 0:
|
||||
logger.warning("Function with no input has no backward op.")
|
||||
elif inputs_num > 10:
|
||||
logger.warning("Currently autodiff for function with more than 10 inputs is not supported.")
|
||||
else:
|
||||
grad_func = autodiff_bprop(inputs_num)
|
||||
|
||||
def infer_func(*args):
|
||||
return args[:inputs_num]
|
||||
setattr(infer_func, "type", "autodiff")
|
||||
op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
|
||||
func_type="akg", bprop=True)
|
||||
self.bprop = grad_func(op)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
|
@ -33,26 +34,69 @@ def outer_product(a, b):
|
|||
return c
|
||||
|
||||
|
||||
class TestHybrid(Cell):
|
||||
def cube(a):
|
||||
c = output_tensor((a.shape[0], a.shape[1]), 'float32')
|
||||
b = allocate((a.shape[0], a.shape[1]), 'float32', 'local')
|
||||
|
||||
for i0 in range(a.shape[0]):
|
||||
for i1 in range(a.shape[1]):
|
||||
b[i0, i1] = a[i0, i1] * a[i0, i1]
|
||||
c[i0, i1] = b[i0, i1] * a[i0, i1]
|
||||
|
||||
return c
|
||||
|
||||
|
||||
class TestHybridTwoInputs(Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(TestHybrid, self).__init__()
|
||||
def __init__(self, func, shapes, types):
|
||||
super(TestHybridTwoInputs, self).__init__()
|
||||
|
||||
def infer_func(x, y):
|
||||
return x
|
||||
|
||||
self.program = ops.Custom(outer_product, out_shape=infer_func, out_dtype=infer_func, func_type="akg")
|
||||
self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.program(x, y)
|
||||
|
||||
|
||||
def hybrid_case():
|
||||
class TestHybridOneInput(Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self, func, shapes, types):
|
||||
super(TestHybridOneInput, self).__init__()
|
||||
|
||||
self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
|
||||
|
||||
def construct(self, x):
|
||||
return self.program(x)
|
||||
|
||||
|
||||
class MatMulNN(Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(MatMulNN, self).__init__()
|
||||
self.matmul = ops.MatMul()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.matmul(x, y)
|
||||
|
||||
|
||||
class PowNN(Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(PowNN, self).__init__()
|
||||
self.pow = ops.Pow()
|
||||
|
||||
def construct(self, x):
|
||||
return self.pow(x, 3)
|
||||
|
||||
|
||||
def hybrid_outer_product():
|
||||
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
|
||||
test = TestHybrid()
|
||||
test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
|
||||
output = test(Tensor(input_x), Tensor(input_y))
|
||||
expect = np.matmul(input_x, input_y)
|
||||
compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
|
||||
|
@ -60,6 +104,34 @@ def hybrid_case():
|
|||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def hybrid_outer_product_autodiff():
|
||||
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
|
||||
test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
|
||||
net = MatMulNN()
|
||||
dx, dy = ops.GradOperation(sens_param=True, get_all=True)(test)(Tensor(input_x), Tensor(input_y), Tensor(sens))
|
||||
edx, edy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(input_x), Tensor(input_y), Tensor(sens))
|
||||
compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
|
||||
compare_res &= np.allclose(edy.asnumpy(), dy.asnumpy(), 0.001, 0.001)
|
||||
if not compare_res:
|
||||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
def hybrid_pow_autodiff():
|
||||
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
|
||||
|
||||
test = TestHybridOneInput(cube, (4, 4), (ms.float32))
|
||||
net = PowNN()
|
||||
dx = ops.GradOperation(sens_param=True)(test)(Tensor(input_x), Tensor(sens))
|
||||
edx = ops.GradOperation(sens_param=True)(net)(Tensor(input_x), Tensor(sens))
|
||||
compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
|
||||
if not compare_res:
|
||||
raise ValueError("Precision error, compare result: {}".format(compare_res))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -71,7 +143,7 @@ def test_hybrid_ascend_graph_mode():
|
|||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
hybrid_case()
|
||||
hybrid_outer_product()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -85,7 +157,7 @@ def test_hybrid_ascend_pynative_mode():
|
|||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
hybrid_case()
|
||||
hybrid_outer_product()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -98,7 +170,9 @@ def test_hybrid_gpu_graph_mode():
|
|||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
hybrid_case()
|
||||
hybrid_outer_product()
|
||||
hybrid_outer_product_autodiff()
|
||||
hybrid_pow_autodiff()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -111,7 +185,9 @@ def test_hybrid_gpu_pynative_mode():
|
|||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
hybrid_case()
|
||||
hybrid_outer_product()
|
||||
hybrid_outer_product_autodiff()
|
||||
hybrid_pow_autodiff()
|
||||
|
||||
|
||||
v_add_ascend_info = CustomRegOp() \
|
||||
|
|
Loading…
Reference in New Issue