forked from mindspore-Ecosystem/mindspore
expand tanh_grad and reduce_mean, fix bug and add test_case in ci
This commit is contained in:
parent
49e3aa35a2
commit
db27783d54
|
@ -23,3 +23,5 @@ from .bias_add import expand_biasadd
|
|||
from .bias_add_grad import expand_biasaddgrad
|
||||
from .fused_adam import expand_fusedadam
|
||||
from .fused_adam_weight_decay import expand_fusedadamweightdecay
|
||||
from .reduce_mean import expand_reducemean
|
||||
from .tanh_grad import expand_tanhgrad
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for reduce_mean"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_reducemean(expand_info):
|
||||
"""ReduceMean expander"""
|
||||
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
attrs = expand_info['attr']
|
||||
axis = None
|
||||
keep_dims = None
|
||||
for item in attrs:
|
||||
if 'axis' in item:
|
||||
axis = item['axis']
|
||||
if 'keep_dims' in item:
|
||||
keep_dims = item['keep_dims']
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
x_shape = input_x.shape
|
||||
graph_scope.set_input(input_x)
|
||||
|
||||
# cal reduce_mean
|
||||
# when axis = None, reduce axis are all
|
||||
all_shape = 1.0
|
||||
real_axis = []
|
||||
if not axis:
|
||||
for i, shape in enumerate(x_shape):
|
||||
real_axis.append(i)
|
||||
all_shape *= shape
|
||||
else:
|
||||
for idx in axis:
|
||||
all_shape *= x_shape[idx]
|
||||
|
||||
all_shape_value = graph_builder.value(input_x.dtype, all_shape, input_x.data_format)
|
||||
|
||||
if not axis:
|
||||
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims})
|
||||
else:
|
||||
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
|
||||
result = graph_builder.emit('RealDiv', [sum_x, all_shape_value])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for tanh_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
ONE = 1.0
|
||||
|
||||
|
||||
def expand_tanhgrad(expand_info):
|
||||
"""TanhGrad expander"""
|
||||
|
||||
# tanh_grad(y, dy) = dy * (1- y * y)
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
const_one = graph_builder.value(input_y.dtype, ONE, input_y.data_format)
|
||||
graph_scope.set_input(input_y, input_dy)
|
||||
|
||||
# cal result
|
||||
double_y = graph_builder.emit('Mul', [input_y, input_y])
|
||||
one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
|
||||
result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
|
@ -703,8 +703,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
|
|||
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||
std::unordered_set<PrimitivePtr> expand_ops = {
|
||||
prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
|
||||
prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay,
|
||||
};
|
||||
prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad,
|
||||
prim::kPrimReduceMean};
|
||||
return expand_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,12 +29,19 @@ bool BindValueToGraph::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto &value_nodes = kernel_graph->graph_value_nodes();
|
||||
bool changed = false;
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
for (auto node : todos) {
|
||||
if (!GetValueNode<tensor::TensorPtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
if (auto vptr = node->cast<ValueNodePtr>(); value_nodes.count(vptr) == 0) {
|
||||
kernel_graph->AddValueNodeToGraph(vptr);
|
||||
auto new_node = kernel_graph->NewValueNode(vptr);
|
||||
mng->Replace(vptr, new_node);
|
||||
kernel_graph->AddValueNodeToGraph(new_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, decay_flag=True):
|
||||
super(Net, self).__init__()
|
||||
self.decay_flag = decay_flag
|
||||
self.op_mul = P.Mul()
|
||||
self.op_square = P.Square()
|
||||
self.op_sqrt = P.Sqrt()
|
||||
self.op_cast = P.Cast()
|
||||
self.op_reshape = P.Reshape()
|
||||
self.op_shape = P.Shape()
|
||||
self.param = Parameter(Tensor(np.array([1, 3, 5]).astype(np.float32)), name='param')
|
||||
self.m = Parameter(Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m')
|
||||
self.v = Parameter(Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v')
|
||||
|
||||
@ms_function
|
||||
def construct(self, beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr):
|
||||
param_fp32 = self.op_cast(self.param, mstype.float32)
|
||||
m_fp32 = self.op_cast(self.m, mstype.float32)
|
||||
v_fp32 = self.op_cast(self.v, mstype.float32)
|
||||
gradient_fp32 = self.op_cast(gradient, mstype.float32)
|
||||
|
||||
next_m = self.op_mul(beta1, m_fp32) + \
|
||||
self.op_mul(self.op_cast(one_sub_beta_1, mstype.float32), gradient_fp32)
|
||||
next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(one_sub_beta_2,
|
||||
mstype.float32), self.op_square(gradient_fp32))
|
||||
update = next_m / (eps + self.op_sqrt(next_v))
|
||||
if self.decay_flag:
|
||||
update = self.op_mul(weight_decay_tensor, param_fp32) + update
|
||||
update_with_lr = self.op_mul(lr, update)
|
||||
next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))
|
||||
|
||||
depend_v = F.depend(next_param, F.assign(self.param, next_param))
|
||||
depend_v = F.depend(depend_v, F.assign(self.m, next_m))
|
||||
depend_v = F.depend(depend_v, F.assign(self.v, next_v))
|
||||
return depend_v
|
||||
|
||||
|
||||
def CalFusedAdam(beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v,
|
||||
is_weight_decay=False):
|
||||
m_expect = beta1 * m + one_sub_beta_1 * gradient
|
||||
v_expect = beta2 * v + one_sub_beta_2 * gradient * gradient
|
||||
update = m_expect / (np.sqrt(v_expect) + eps)
|
||||
if is_weight_decay:
|
||||
update += weight_decay_tensor * param
|
||||
param_expect = param - lr * update
|
||||
return param_expect, m_expect, v_expect
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_adam():
|
||||
np.random.seed(0)
|
||||
beta1 = np.array([0.9]).astype(np.float32)
|
||||
beta2 = np.array([0.999]).astype(np.float32)
|
||||
one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32)
|
||||
one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32)
|
||||
lr = np.array([0.012]).astype(np.float32)
|
||||
eps = np.array([1e-6]).astype(np.float32)
|
||||
weight_decay_tensor = np.array([0.021]).astype(np.float32)
|
||||
|
||||
gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32)
|
||||
m = np.array([0.11, 0.33, 0.55]).astype(np.float32)
|
||||
v = np.array([1.2, 3.4, 5.6]).astype(np.float32)
|
||||
param = np.array([1, 3, 5]).astype(np.float32)
|
||||
is_weight_decay = False
|
||||
opt = Net(is_weight_decay)
|
||||
_ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps),
|
||||
Tensor(weight_decay_tensor), Tensor(lr))
|
||||
param_expect, m_expect, v_expect = CalFusedAdam(
|
||||
beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr,
|
||||
param, m, v, is_weight_decay)
|
||||
assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
|
||||
|
||||
def test_adam_weight_decay():
|
||||
np.random.seed(0)
|
||||
beta1 = np.array([0.9]).astype(np.float32)
|
||||
beta2 = np.array([0.999]).astype(np.float32)
|
||||
one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32)
|
||||
one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32)
|
||||
lr = np.array([0.012]).astype(np.float32)
|
||||
eps = np.array([1e-6]).astype(np.float32)
|
||||
weight_decay_tensor = np.array([0.021]).astype(np.float32)
|
||||
|
||||
gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32)
|
||||
m = np.array([0.11, 0.33, 0.55]).astype(np.float32)
|
||||
v = np.array([1.2, 3.4, 5.6]).astype(np.float32)
|
||||
param = np.array([1, 3, 5]).astype(np.float32)
|
||||
is_weight_decay = True
|
||||
opt = Net(is_weight_decay)
|
||||
_ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps),
|
||||
Tensor(weight_decay_tensor), Tensor(lr))
|
||||
param_expect, m_expect, v_expect = CalFusedAdam(
|
||||
beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr,
|
||||
param, m, v, is_weight_decay)
|
||||
|
||||
assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=False)
|
||||
|
||||
def construct(self, x):
|
||||
return self.reduce_mean(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_reduce_mean():
|
||||
np.random.seed(0)
|
||||
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||
expect = np.mean(input_x, keepdims=False)
|
||||
net = Net()
|
||||
result = net(Tensor(input_x))
|
||||
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
||||
assert res
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations._grad_ops as G
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||
|
||||
|
||||
class TanhGradNet(Cell):
|
||||
def __init__(self):
|
||||
super(TanhGradNet, self).__init__()
|
||||
self.tanh_grad = G.TanhGrad()
|
||||
|
||||
def construct(self, y, dy):
|
||||
return self.tanh_grad(y, dy)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_tanh_grad():
|
||||
np.random.seed(0)
|
||||
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||
input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||
net = TanhGradNet()
|
||||
result = net(Tensor(input_y), Tensor(input_dy))
|
||||
expect = input_dy * (1.0 - input_y * input_y)
|
||||
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
||||
assert res
|
Loading…
Reference in New Issue