apply autodiff in custom

Yang Jiao 2021-12-02 11:38:54 +08:00
parent 95f7c61d2e
commit 5acce31c72
6 changed files with 279 additions and 16 deletions

akg

@@ -1 +1 @@
Subproject commit b3f67ab15c03ab7b201152d41cdfad7c05294609
Subproject commit 3a7415a9a4184e3d8469b1d19109c03a32843871

View File

@@ -26,6 +26,7 @@
#include "abstract/utils.h"
#include "runtime/device/cpu/cpu_common.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace kernel {
@@ -40,7 +41,12 @@ CustomAOTCpuKernel::~CustomAOTCpuKernel() {
void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
  const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
  if (auto pos = exec_info.find(":"); pos != std::string::npos) {
    file_path_ = exec_info.substr(0, pos);
    auto path = exec_info.substr(0, pos);
    auto real_path = FileUtils::GetRealPath(path.c_str());
    if (!real_path.has_value()) {
      MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
    }
    file_path_ = real_path.value();
    func_name_ = exec_info.substr(pos + 1);
  } else {
    MS_LOG(EXCEPTION) << "Wrong execute info:" << exec_info;

View File

@@ -25,6 +25,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace kernel {
@@ -105,7 +106,13 @@ class CustomAOTGpuKernel : public GpuKernel {
  bool Init(const CNodePtr &kernel_node) override {
    const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
    if (auto pos = exec_info.find(":"); pos != std::string::npos) {
      file_path_ = exec_info.substr(0, pos);
      auto path = exec_info.substr(0, pos);
      auto real_path = FileUtils::GetRealPath(path.c_str());
      if (!real_path.has_value()) {
        MS_LOG(ERROR) << "Invalid file path, " << path << " does not exist.";
        return false;
      }
      file_path_ = real_path.value();
      func_name_ = exec_info.substr(pos + 1);
    } else {
      MS_LOG(ERROR) << "Wrong execute info:" << exec_info;
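Aside (illustration only, not part of this diff): both the CPU and GPU kernels above parse the "func_name" attribute as "<file path>:<function name>", and with this change the path part is resolved to a real path and validated before the library is loaded. A minimal user-level sketch of how such an AOT custom op might be declared; the .so path and symbol name below are hypothetical:

    import mindspore as ms
    from mindspore import ops

    # The string is split at ':' into file_path_ and func_name_ by the kernels above;
    # after this change the file part must resolve to an existing path on disk.
    op = ops.Custom("./add_kernel.so:CustomAdd", out_shape=(4,),
                    out_dtype=ms.float32, func_type="aot")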

View File

@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Register bprop function for Custom Hybrid Autodiff"""
from collections import UserDict
from mindspore import log as logger
class Registry(UserDict):
"""
Registry class inherits from UserDict.
Key: length of signatures
Value : bprop function for custom hybrid op
"""
def register(self, sig_number):
def deco(fn):
self[sig_number] = fn
return fn
return deco
def get(self, sig_number):
if sig_number not in self:
logger.error(f"Autodiff currently doesn't support hyrbrid function with input num :{sig_number}. \
Supported input num is from 1 to 10")
fn = self[sig_number]
return fn
bprop_factory = Registry()
def autodiff_bprop(n):
    """Return the bprop generator registered for a hybrid function with n inputs."""
    return bprop_factory.get(n)
def get_outs(out, dout):
    """Flatten out and dout into a single tuple to be appended after the forward inputs."""
    if isinstance(out, tuple):
        tupleout = out
    else:
        tupleout = (out,)
    if isinstance(dout, tuple):
        tupledout = dout
    else:
        tupledout = (dout,)
    return tupleout + tupledout
@bprop_factory.register(1)
def bprop_one(op):
    def bprop(x1, out, dout):
        inputs = (x1,) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(2)
def bprop_two(op):
    def bprop(x1, x2, out, dout):
        inputs = (x1, x2) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(3)
def bprop_three(op):
    def bprop(x1, x2, x3, out, dout):
        inputs = (x1, x2, x3) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(4)
def bprop_four(op):
    def bprop(x1, x2, x3, x4, out, dout):
        inputs = (x1, x2, x3, x4) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(5)
def bprop_five(op):
    def bprop(x1, x2, x3, x4, x5, out, dout):
        inputs = (x1, x2, x3, x4, x5) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(6)
def bprop_six(op):
    def bprop(x1, x2, x3, x4, x5, x6, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(7)
def bprop_seven(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(8)
def bprop_eight(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(9)
def bprop_nine(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
@bprop_factory.register(10)
def bprop_ten(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
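Aside (illustration only, not part of the new file): each registered wrapper builds a bprop that simply forwards the forward inputs plus out and dout to a backward Custom op. A minimal sketch of how it is consumed, where backward_op is a hypothetical stand-in for the op built by Custom._hybrid_autodiff below:

    # Hypothetical wiring; backward_op would be a Custom op that AKG autodiffs.
    bprop_gen = autodiff_bprop(2)      # fetch the wrapper for a two-input hybrid function
    bprop = bprop_gen(backward_op)     # bprop(x1, x2, out, dout) calls backward_op(x1, x2, out, dout)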

View File

@@ -24,6 +24,7 @@ from mindspore.ops import DataType
from mindspore.common import dtype as mstype
from mindspore._c_expression import Oplib
from ._pyfunc_registry import add_pyfunc
from ._custom_grad import autodiff_bprop
class Custom(ops.PrimitiveWithInfer):
@@ -297,6 +298,8 @@ class Custom(ops.PrimitiveWithInfer):
                self.func_type = "tvm_compute"
            else:
                self.func_type = "hybrid"
                if not self.bprop:
                    self._hybrid_autodiff()
        self.add_prim_attr("func_type", self.func_type)
        self._update_attr()
@@ -571,3 +574,25 @@ class Custom(ops.PrimitiveWithInfer):
                self.add_prim_attr("primitive_target", "GPU")
            elif registered_targets == ["CPU"]:
                self.add_prim_attr("primitive_target", "CPU")
        if callable(self.func) and callable(self.out_shape):
            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
                self.add_prim_attr("autodiff", True)
            else:
                self.add_prim_attr("autodiff", False)
    def _hybrid_autodiff(self):
        """Generate the backward op for a custom hybrid op."""
        inputs_num = len(inspect.signature(self.func).parameters)
        if inputs_num == 0:
            logger.warning("A function with no inputs has no backward op.")
        elif inputs_num > 10:
            logger.warning("Currently autodiff is not supported for functions with more than 10 inputs.")
        else:
            grad_func = autodiff_bprop(inputs_num)
            def infer_func(*args):
                return args[:inputs_num]
            setattr(infer_func, "type", "autodiff")
            op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
                        func_type="akg", bprop=True)
            self.bprop = grad_func(op)
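Aside (illustration only, not part of the diff): with _hybrid_autodiff in place, a hybrid Custom op defined without an explicit bprop gets its backward op generated automatically, and gradients can be taken as usual. A hedged sketch reusing the outer_product hybrid function from the test file below; the shapes, dtype, and class name are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import ops, Tensor
    from mindspore.nn import Cell

    class OuterProductNet(Cell):
        def __init__(self):
            super(OuterProductNet, self).__init__()
            # No bprop is passed, so _hybrid_autodiff builds one via autodiff_bprop(2).
            self.op = ops.Custom(outer_product, out_shape=(4, 4),
                                 out_dtype=ms.float32, func_type="akg")
        def construct(self, x, y):
            return self.op(x, y)

    x = Tensor(np.ones((4, 4), np.float32))
    y = Tensor(np.ones((4, 4), np.float32))
    dx, dy = ops.GradOperation(get_all=True)(OuterProductNet())(x, y)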

View File

@@ -15,6 +15,7 @@
import pytest
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
@@ -33,26 +34,69 @@ def outer_product(a, b):
    return c
class TestHybrid(Cell):
def cube(a):
    c = output_tensor((a.shape[0], a.shape[1]), 'float32')
    b = allocate((a.shape[0], a.shape[1]), 'float32', 'local')
    for i0 in range(a.shape[0]):
        for i1 in range(a.shape[1]):
            b[i0, i1] = a[i0, i1] * a[i0, i1]
            c[i0, i1] = b[i0, i1] * a[i0, i1]
    return c
class TestHybridTwoInputs(Cell):
    """Net definition"""
    def __init__(self):
        super(TestHybrid, self).__init__()
    def __init__(self, func, shapes, types):
        super(TestHybridTwoInputs, self).__init__()
        def infer_func(x, y):
            return x
        self.program = ops.Custom(outer_product, out_shape=infer_func, out_dtype=infer_func, func_type="akg")
        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
    def construct(self, x, y):
        return self.program(x, y)
def hybrid_case():
class TestHybridOneInput(Cell):
    """Net definition"""
    def __init__(self, func, shapes, types):
        super(TestHybridOneInput, self).__init__()
        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
    def construct(self, x):
        return self.program(x)
class MatMulNN(Cell):
    """Net definition"""
    def __init__(self):
        super(MatMulNN, self).__init__()
        self.matmul = ops.MatMul()
    def construct(self, x, y):
        return self.matmul(x, y)
class PowNN(Cell):
    """Net definition"""
    def __init__(self):
        super(PowNN, self).__init__()
        self.pow = ops.Pow()
    def construct(self, x):
        return self.pow(x, 3)
def hybrid_outer_product():
    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    test = TestHybrid()
    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
    output = test(Tensor(input_x), Tensor(input_y))
    expect = np.matmul(input_x, input_y)
    compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
@@ -60,6 +104,34 @@ def hybrid_case():
        raise ValueError("Precision error, compare result: {}".format(compare_res))
def hybrid_outer_product_autodiff():
    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
    net = MatMulNN()
    dx, dy = ops.GradOperation(sens_param=True, get_all=True)(test)(Tensor(input_x), Tensor(input_y), Tensor(sens))
    edx, edy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(input_x), Tensor(input_y), Tensor(sens))
    compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
    compare_res &= np.allclose(edy.asnumpy(), dy.asnumpy(), 0.001, 0.001)
    if not compare_res:
        raise ValueError("Precision error, compare result: {}".format(compare_res))
def hybrid_pow_autodiff():
    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    test = TestHybridOneInput(cube, (4, 4), (ms.float32))
    net = PowNN()
    dx = ops.GradOperation(sens_param=True)(test)(Tensor(input_x), Tensor(sens))
    edx = ops.GradOperation(sens_param=True)(net)(Tensor(input_x), Tensor(sens))
    compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
    if not compare_res:
        raise ValueError("Precision error, compare result: {}".format(compare_res))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
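Aside (illustration only, not part of the test file): the two new helpers compare the automatically generated gradients against analytic ones. For out = matmul(x, y) the expected gradients are dx = matmul(dout, y.T) and dy = matmul(x.T, dout), and for the element-wise cube out = x**3 the expected gradient is 3 * x**2 * dout. In NumPy terms, using the array names from the helpers above:

    # Expected values the autodiff results are checked against (NumPy sketch).
    dx_expect = sens @ input_y.T            # gradient of matmul w.r.t. its first input
    dy_expect = input_x.T @ sens            # gradient of matmul w.r.t. its second input
    dcube_expect = 3 * input_x ** 2 * sens  # gradient of the element-wise cube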
@@ -71,7 +143,7 @@ def test_hybrid_ascend_graph_mode():
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    hybrid_case()
    hybrid_outer_product()
@pytest.mark.level0
@@ -85,7 +157,7 @@ def test_hybrid_ascend_pynative_mode():
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    hybrid_case()
    hybrid_outer_product()
@pytest.mark.level0
@@ -98,7 +170,9 @@ def test_hybrid_gpu_graph_mode():
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    hybrid_case()
    hybrid_outer_product()
    hybrid_outer_product_autodiff()
    hybrid_pow_autodiff()
@pytest.mark.level0
@@ -111,7 +185,9 @@ def test_hybrid_gpu_pynative_mode():
    Expectation: the result matches the numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    hybrid_case()
    hybrid_outer_product()
    hybrid_outer_product_autodiff()
    hybrid_pow_autodiff()
v_add_ascend_info = CustomRegOp() \