apply autodiff in custom

Yang Jiao 2021-12-02 11:38:54 +08:00
parent 95f7c61d2e
commit 5acce31c72
6 changed files with 279 additions and 16 deletions

akg

@@ -1 +1 @@
-Subproject commit b3f67ab15c03ab7b201152d41cdfad7c05294609
+Subproject commit 3a7415a9a4184e3d8469b1d19109c03a32843871

View File

@@ -26,6 +26,7 @@
 #include "abstract/utils.h"
 #include "runtime/device/cpu/cpu_common.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "utils/file_utils.h"

 namespace mindspore {
 namespace kernel {
@@ -40,7 +41,12 @@ CustomAOTCpuKernel::~CustomAOTCpuKernel() {
 void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
   const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
   if (auto pos = exec_info.find(":"); pos != std::string::npos) {
-    file_path_ = exec_info.substr(0, pos);
+    auto path = exec_info.substr(0, pos);
+    auto real_path = FileUtils::GetRealPath(path.c_str());
+    if (!real_path.has_value()) {
+      MS_LOG(EXCEPTION) << "Invalid file path, " << path << " does not exist.";
+    }
+    file_path_ = real_path.value();
     func_name_ = exec_info.substr(pos + 1);
   } else {
     MS_LOG(EXCEPTION) << "Wrong execute info:" << exec_info;

View File

@@ -25,6 +25,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "utils/file_utils.h"

 namespace mindspore {
 namespace kernel {
@@ -105,7 +106,13 @@ class CustomAOTGpuKernel : public GpuKernel {
   bool Init(const CNodePtr &kernel_node) override {
     const auto &exec_info = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "func_name");
     if (auto pos = exec_info.find(":"); pos != std::string::npos) {
-      file_path_ = exec_info.substr(0, pos);
+      auto path = exec_info.substr(0, pos);
+      auto real_path = FileUtils::GetRealPath(path.c_str());
+      if (!real_path.has_value()) {
+        MS_LOG(ERROR) << "Invalid file path, " << path << " does not exist.";
+        return false;
+      }
+      file_path_ = real_path.value();
       func_name_ = exec_info.substr(pos + 1);
     } else {
       MS_LOG(ERROR) << "Wrong execute info:" << exec_info;
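
Both kernels now resolve the library path with FileUtils::GetRealPath before loading it. The path comes from the Custom op's "func_name" attribute, which for AOT kernels has the form "<library path>:<symbol>". As a hedged sketch of the Python side that produces such an attribute (file and symbol names are illustrative, registration info is omitted, and this usage is assumed from the Custom AOT workflow rather than part of this commit):

# Illustrative only: an AOT Custom op whose "func_name" attr is the string
# "./libcustom_add.so:CustomAdd"; the CPU/GPU kernels above split it at ':'
# and validate the library path with GetRealPath before loading the library.
from mindspore import ops
from mindspore.common import dtype as mstype

custom_add = ops.Custom("./libcustom_add.so:CustomAdd",
                        out_shape=(4, 4), out_dtype=mstype.float32, func_type="aot")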

View File

@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Register bprop function for Custom Hybrid Autodiff"""
from collections import UserDict

from mindspore import log as logger


class Registry(UserDict):
    """
    Registry class inherits from UserDict.
    Key: length of signatures
    Value: bprop function for custom hybrid op
    """

    def register(self, sig_number):
        def deco(fn):
            self[sig_number] = fn
            return fn
        return deco

    def get(self, sig_number):
        if sig_number not in self:
            logger.error(f"Autodiff currently doesn't support hybrid function with input num: {sig_number}. "
                         "Supported input num is from 1 to 10")
        fn = self[sig_number]
        return fn


bprop_factory = Registry()


def autodiff_bprop(n):
    return bprop_factory.get(n)


def get_outs(out, dout):
    if isinstance(out, tuple):
        tupleout = out
    else:
        tupleout = (out,)
    if isinstance(dout, tuple):
        tupledout = dout
    else:
        tupledout = (dout,)
    return tupleout + tupledout


@bprop_factory.register(1)
def bprop_one(op):
    def bprop(x1, out, dout):
        inputs = (x1,) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(2)
def bprop_two(op):
    def bprop(x1, x2, out, dout):
        inputs = (x1, x2) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(3)
def bprop_three(op):
    def bprop(x1, x2, x3, out, dout):
        inputs = (x1, x2, x3) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(4)
def bprop_four(op):
    def bprop(x1, x2, x3, x4, out, dout):
        inputs = (x1, x2, x3, x4) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(5)
def bprop_five(op):
    def bprop(x1, x2, x3, x4, x5, out, dout):
        inputs = (x1, x2, x3, x4, x5) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(6)
def bprop_six(op):
    def bprop(x1, x2, x3, x4, x5, x6, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(7)
def bprop_seven(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(8)
def bprop_eight(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(9)
def bprop_nine(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop


@bprop_factory.register(10)
def bprop_ten(op):
    def bprop(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, out, dout):
        inputs = (x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) + get_outs(out, dout)
        res = op(*inputs)
        return res
    return bprop
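
The registry is keyed purely by arity: autodiff_bprop(n) returns the wrapper registered for an n-input hybrid function, and applying that wrapper to a backward op yields a bprop with the signature MindSpore's autodiff expects. A minimal sketch of the intended use, assuming the definitions above are in scope ("backward_op" is a placeholder; the real caller is Custom._hybrid_autodiff in the next file):

# Sketch only: "backward_op" stands for the Custom op that computes the
# gradients, built from the same hybrid function (see custom.py below).
def attach_bprop(backward_op, inputs_num):
    bprop = autodiff_bprop(inputs_num)(backward_op)
    # For inputs_num == 2 this is bprop_two's inner function: it is called as
    # bprop(x1, x2, out, dout) and returns backward_op(x1, x2, out, dout),
    # i.e. the tuple of gradients (dx1, dx2).
    return bprop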

View File

@@ -24,6 +24,7 @@ from mindspore.ops import DataType
 from mindspore.common import dtype as mstype
 from mindspore._c_expression import Oplib
 from ._pyfunc_registry import add_pyfunc
+from ._custom_grad import autodiff_bprop


 class Custom(ops.PrimitiveWithInfer):
@@ -297,6 +298,8 @@ class Custom(ops.PrimitiveWithInfer):
                 self.func_type = "tvm_compute"
             else:
                 self.func_type = "hybrid"
+                if not self.bprop:
+                    self._hybrid_autodiff()

         self.add_prim_attr("func_type", self.func_type)
         self._update_attr()
@@ -571,3 +574,25 @@ class Custom(ops.PrimitiveWithInfer):
                 self.add_prim_attr("primitive_target", "GPU")
             elif registered_targets == ["CPU"]:
                 self.add_prim_attr("primitive_target", "CPU")
+
+        if callable(self.func) and callable(self.out_shape):
+            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
+                self.add_prim_attr("autodiff", True)
+            else:
+                self.add_prim_attr("autodiff", False)
+
+    def _hybrid_autodiff(self):
+        """generate backward op for a custom hybrid op"""
+        inputs_num = len(inspect.signature(self.func).parameters)
+        if inputs_num == 0:
+            logger.warning("Function with no input has no backward op.")
+        elif inputs_num > 10:
+            logger.warning("Currently autodiff for function with more than 10 inputs is not supported.")
+        else:
+            grad_func = autodiff_bprop(inputs_num)
+
+            def infer_func(*args):
+                return args[:inputs_num]
+
+            setattr(infer_func, "type", "autodiff")
+            op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
+                        func_type="akg", bprop=True)
+            self.bprop = grad_func(op)
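
With these hooks in place, a hybrid (AKG) Custom op created without a bprop gets one generated automatically, so its gradient can be taken directly. A minimal sketch mirroring the GPU test below (outer_product is the hybrid function defined in the test file; shapes, dtypes, and the net name are illustrative):

# Sketch mirroring the test below: no bprop is passed to ops.Custom, yet
# GradOperation works because _hybrid_autodiff generated one.
import numpy as np
import mindspore as ms
from mindspore import ops, Tensor
from mindspore.nn import Cell

class OuterProductNet(Cell):
    def __init__(self):
        super(OuterProductNet, self).__init__()
        self.op = ops.Custom(outer_product, out_shape=(4, 4), out_dtype=ms.float32, func_type="akg")

    def construct(self, x, y):
        return self.op(x, y)

x = Tensor(np.random.rand(4, 4).astype(np.float32))
y = Tensor(np.random.rand(4, 4).astype(np.float32))
sens = Tensor(np.ones((4, 4)).astype(np.float32))
dx, dy = ops.GradOperation(sens_param=True, get_all=True)(OuterProductNet())(x, y, sens)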

View File

@@ -15,6 +15,7 @@
 import pytest
 import numpy as np
+import mindspore as ms
 from mindspore import context, Tensor
 from mindspore.common import dtype as mstype
 from mindspore.nn import Cell
@@ -33,26 +34,69 @@ def outer_product(a, b):
     return c


-class TestHybrid(Cell):
+def cube(a):
+    c = output_tensor((a.shape[0], a.shape[1]), 'float32')
+    b = allocate((a.shape[0], a.shape[1]), 'float32', 'local')
+    for i0 in range(a.shape[0]):
+        for i1 in range(a.shape[1]):
+            b[i0, i1] = a[i0, i1] * a[i0, i1]
+            c[i0, i1] = b[i0, i1] * a[i0, i1]
+    return c
+
+
+class TestHybridTwoInputs(Cell):
     """Net definition"""

-    def __init__(self):
-        super(TestHybrid, self).__init__()
-
-        def infer_func(x, y):
-            return x
-
-        self.program = ops.Custom(outer_product, out_shape=infer_func, out_dtype=infer_func, func_type="akg")
+    def __init__(self, func, shapes, types):
+        super(TestHybridTwoInputs, self).__init__()
+        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")

     def construct(self, x, y):
         return self.program(x, y)


-def hybrid_case():
+class TestHybridOneInput(Cell):
+    """Net definition"""
+
+    def __init__(self, func, shapes, types):
+        super(TestHybridOneInput, self).__init__()
+        self.program = ops.Custom(func, out_shape=shapes, out_dtype=types, func_type="akg")
+
+    def construct(self, x):
+        return self.program(x)
+
+
+class MatMulNN(Cell):
+    """Net definition"""
+
+    def __init__(self):
+        super(MatMulNN, self).__init__()
+        self.matmul = ops.MatMul()
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
+class PowNN(Cell):
+    """Net definition"""
+
+    def __init__(self):
+        super(PowNN, self).__init__()
+        self.pow = ops.Pow()
+
+    def construct(self, x):
+        return self.pow(x, 3)
+
+
+def hybrid_outer_product():
     input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
     input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

-    test = TestHybrid()
+    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
     output = test(Tensor(input_x), Tensor(input_y))
     expect = np.matmul(input_x, input_y)
     compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
@@ -60,6 +104,34 @@ def hybrid_case():
         raise ValueError("Precision error, compare result: {}".format(compare_res))


+def hybrid_outer_product_autodiff():
+    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
+    input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
+    sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
+
+    test = TestHybridTwoInputs(outer_product, (4, 4), (ms.float32))
+    net = MatMulNN()
+    dx, dy = ops.GradOperation(sens_param=True, get_all=True)(test)(Tensor(input_x), Tensor(input_y), Tensor(sens))
+    edx, edy = ops.GradOperation(sens_param=True, get_all=True)(net)(Tensor(input_x), Tensor(input_y), Tensor(sens))
+    compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
+    compare_res &= np.allclose(edy.asnumpy(), dy.asnumpy(), 0.001, 0.001)
+    if not compare_res:
+        raise ValueError("Precision error, compare result: {}".format(compare_res))
+
+
+def hybrid_pow_autodiff():
+    input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
+    sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)
+
+    test = TestHybridOneInput(cube, (4, 4), (ms.float32))
+    net = PowNN()
+    dx = ops.GradOperation(sens_param=True)(test)(Tensor(input_x), Tensor(sens))
+    edx = ops.GradOperation(sens_param=True)(net)(Tensor(input_x), Tensor(sens))
+    compare_res = np.allclose(edx.asnumpy(), dx.asnumpy(), 0.001, 0.001)
+    if not compare_res:
+        raise ValueError("Precision error, compare result: {}".format(compare_res))
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -71,7 +143,7 @@ def test_hybrid_ascend_graph_mode():
     Expectation: the result match with numpy result
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    hybrid_case()
+    hybrid_outer_product()


 @pytest.mark.level0
@@ -85,7 +157,7 @@ def test_hybrid_ascend_pynative_mode():
     Expectation: the result match with numpy result
     """
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    hybrid_case()
+    hybrid_outer_product()


 @pytest.mark.level0
@@ -98,7 +170,9 @@ def test_hybrid_gpu_graph_mode():
     Expectation: the result match with numpy result
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    hybrid_case()
+    hybrid_outer_product()
+    hybrid_outer_product_autodiff()
+    hybrid_pow_autodiff()


 @pytest.mark.level0
@@ -111,7 +185,9 @@ def test_hybrid_gpu_pynative_mode():
     Expectation: the result match with numpy result
     """
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    hybrid_case()
+    hybrid_outer_product()
+    hybrid_outer_product_autodiff()
+    hybrid_pow_autodiff()


 v_add_ascend_info = CustomRegOp() \