diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index f0bdd539696..ce5bd7b488b 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -23,3 +23,5 @@ from .bias_add import expand_biasadd from .bias_add_grad import expand_biasaddgrad from .fused_adam import expand_fusedadam from .fused_adam_weight_decay import expand_fusedadamweightdecay +from .reduce_mean import expand_reducemean +from .tanh_grad import expand_tanhgrad diff --git a/mindspore/_extends/graph_kernel/expanders/reduce_mean.py b/mindspore/_extends/graph_kernel/expanders/reduce_mean.py new file mode 100644 index 00000000000..932e59d4129 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/reduce_mean.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for reduce_mean""" +from mindspore._extends.graph_kernel.model import model_builder as builder + + +def expand_reducemean(expand_info): + """ReduceMean expander""" + + # get op info. + input_desc = expand_info['input_desc'][0] + attrs = expand_info['attr'] + axis = None + keep_dims = None + for item in attrs: + if 'axis' in item: + axis = item['axis'] + if 'keep_dims' in item: + keep_dims = item['keep_dims'] + graph_builder = builder.GraphBuilder() + + # generate a graph. + with graph_builder.graph_scope('main') as graph_scope: + # create tensor input. + input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) + x_shape = input_x.shape + graph_scope.set_input(input_x) + + # cal reduce_mean + # when axis = None, reduce axis are all + all_shape = 1.0 + real_axis = [] + if not axis: + for i, shape in enumerate(x_shape): + real_axis.append(i) + all_shape *= shape + else: + for idx in axis: + all_shape *= x_shape[idx] + + all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format) + + if not axis: + sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) + else: + sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) + result = graph_builder.emit('RealDiv', [sum_x, all_shape_value]) + + # set graph output. + graph_scope.set_output(result) + + graph = graph_builder.get()[0] + return graph diff --git a/mindspore/_extends/graph_kernel/expanders/tanh_grad.py b/mindspore/_extends/graph_kernel/expanders/tanh_grad.py new file mode 100644 index 00000000000..b60521bb612 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/tanh_grad.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for tanh_grad""" +from mindspore._extends.graph_kernel.model import model_builder as builder + +ONE = 1.0 + + +def expand_tanhgrad(expand_info): + """TanhGrad expander""" + + # tanh_grad(y, dy) = dy * (1- y * y) + # get op info. + input_desc_0 = expand_info['input_desc'][0] + input_desc_1 = expand_info['input_desc'][1] + graph_builder = builder.GraphBuilder() + + # generate a graph. + with graph_builder.graph_scope('main') as graph_scope: + # create tensor input. + input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) + input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format']) + const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format) + graph_scope.set_input(input_y, input_dy) + + # cal result + double_y = graph_builder.emit('Mul', [input_y, input_y]) + one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) + result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) + + # set graph output. + graph_scope.set_output(result) + + graph = graph_builder.get()[0] + return graph diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 603c451a98d..9b0a5a1fa70 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -702,9 +702,9 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector GetExpandOps() { std::unordered_set expand_ops = { - prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, - prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, - }; + prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, + prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad, + prim::kPrimReduceMean}; return expand_ops; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/value_graph_binder.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/value_graph_binder.cc index db0078833dd..ef3e5bfa627 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/value_graph_binder.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/value_graph_binder.cc @@ -29,12 +29,19 @@ bool BindValueToGraph::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto &value_nodes = kernel_graph->graph_value_nodes(); bool changed = false; + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } for (auto node : todos) { if (!GetValueNode(node)) { continue; } if (auto vptr = node->cast(); value_nodes.count(vptr) == 0) { - kernel_graph->AddValueNodeToGraph(vptr); + auto new_node = kernel_graph->NewValueNode(vptr); + mng->Replace(vptr, new_node); + kernel_graph->AddValueNodeToGraph(new_node); changed = true; } } diff --git a/tests/st/ops/graph_kernel/test_fused_adam.py b/tests/st/ops/graph_kernel/test_fused_adam.py new file mode 100644 index 00000000000..ac230a2facc --- /dev/null +++ b/tests/st/ops/graph_kernel/test_fused_adam.py @@ -0,0 +1,132 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True) + + +class Net(nn.Cell): + def __init__(self, decay_flag=True): + super(Net, self).__init__() + self.decay_flag = decay_flag + self.op_mul = P.Mul() + self.op_square = P.Square() + self.op_sqrt = P.Sqrt() + self.op_cast = P.Cast() + self.op_reshape = P.Reshape() + self.op_shape = P.Shape() + self.param = Parameter(Tensor(np.array([1, 3, 5]).astype(np.float32)), name='param') + self.m = Parameter(Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m') + self.v = Parameter(Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v') + + @ms_function + def construct(self, beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr): + param_fp32 = self.op_cast(self.param, mstype.float32) + m_fp32 = self.op_cast(self.m, mstype.float32) + v_fp32 = self.op_cast(self.v, mstype.float32) + gradient_fp32 = self.op_cast(gradient, mstype.float32) + + next_m = self.op_mul(beta1, m_fp32) + \ + self.op_mul(self.op_cast(one_sub_beta_1, mstype.float32), gradient_fp32) + next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(one_sub_beta_2, + mstype.float32), self.op_square(gradient_fp32)) + update = next_m / (eps + self.op_sqrt(next_v)) + if self.decay_flag: + update = self.op_mul(weight_decay_tensor, param_fp32) + update + update_with_lr = self.op_mul(lr, update) + next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32)) + + depend_v = F.depend(next_param, F.assign(self.param, next_param)) + depend_v = F.depend(depend_v, F.assign(self.m, next_m)) + depend_v = F.depend(depend_v, F.assign(self.v, next_v)) + return depend_v + + +def CalFusedAdam(beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v, + is_weight_decay=False): + m_expect = beta1 * m + one_sub_beta_1 * gradient + v_expect = beta2 * v + one_sub_beta_2 * gradient * gradient + update = m_expect / (np.sqrt(v_expect) + eps) + if is_weight_decay: + update += weight_decay_tensor * param + param_expect = param - lr * update + return param_expect, m_expect, v_expect + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_adam(): + np.random.seed(0) + beta1 = np.array([0.9]).astype(np.float32) + beta2 = np.array([0.999]).astype(np.float32) + one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32) + one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32) + lr = np.array([0.012]).astype(np.float32) + eps = np.array([1e-6]).astype(np.float32) + weight_decay_tensor = np.array([0.021]).astype(np.float32) + + gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32) + m = np.array([0.11, 0.33, 0.55]).astype(np.float32) + v = np.array([1.2, 3.4, 5.6]).astype(np.float32) + param = np.array([1, 3, 5]).astype(np.float32) + is_weight_decay = False + opt = Net(is_weight_decay) + _ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps), + Tensor(weight_decay_tensor), Tensor(lr)) + param_expect, m_expect, v_expect = CalFusedAdam( + beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, + param, m, v, is_weight_decay) + assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + + +def test_adam_weight_decay(): + np.random.seed(0) + beta1 = np.array([0.9]).astype(np.float32) + beta2 = np.array([0.999]).astype(np.float32) + one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32) + one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32) + lr = np.array([0.012]).astype(np.float32) + eps = np.array([1e-6]).astype(np.float32) + weight_decay_tensor = np.array([0.021]).astype(np.float32) + + gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32) + m = np.array([0.11, 0.33, 0.55]).astype(np.float32) + v = np.array([1.2, 3.4, 5.6]).astype(np.float32) + param = np.array([1, 3, 5]).astype(np.float32) + is_weight_decay = True + opt = Net(is_weight_decay) + _ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps), + Tensor(weight_decay_tensor), Tensor(lr)) + param_expect, m_expect, v_expect = CalFusedAdam( + beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, + param, m, v, is_weight_decay) + + assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) diff --git a/tests/st/ops/graph_kernel/test_reduce_mean.py b/tests/st/ops/graph_kernel/test_reduce_mean.py new file mode 100644 index 00000000000..0ce44d78135 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_reduce_mean.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell +import mindspore.ops.operations as P + +context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") + + +class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.reduce_mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + return self.reduce_mean(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_reduce_mean(): + np.random.seed(0) + input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) + expect = np.mean(input_x, keepdims=False) + net = Net() + result = net(Tensor(input_x)) + res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) + assert res diff --git a/tests/st/ops/graph_kernel/test_tanh_grad.py b/tests/st/ops/graph_kernel/test_tanh_grad.py new file mode 100644 index 00000000000..da100dbdad2 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_tanh_grad.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell +import mindspore.ops.operations._grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") + + +class TanhGradNet(Cell): + def __init__(self): + super(TanhGradNet, self).__init__() + self.tanh_grad = G.TanhGrad() + + def construct(self, y, dy): + return self.tanh_grad(y, dy) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tanh_grad(): + np.random.seed(0) + input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) + input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) + net = TanhGradNet() + result = net(Tensor(input_y), Tensor(input_dy)) + expect = input_dy * (1.0 - input_y * input_y) + res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) + assert res