add some expander ops

zengzitao 2021-03-18 09:41:45 +08:00
parent 9bac30d37f
commit d0a656f3cd
14 changed files with 532 additions and 3 deletions

View File

@@ -14,6 +14,7 @@
# ============================================================================
"""expanders init"""
from .assign_add import AssignAdd
from .bias_add import BiasAdd
from .bias_add_grad import BiasAddGrad
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
@@ -31,6 +32,11 @@ from .maximum_grad import MaximumGrad
from .minimum_grad import MinimumGrad
from .reduce_mean import ReduceMean
from .softmax import Softmax
from .sigmoid import Sigmoid
from .sigmoid_grad import SigmoidGrad
from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
from .sqrt_grad import SqrtGrad
from .square import Square
from .tanh_grad import TanhGrad

View File

@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for assign_add"""
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class AssignAdd(Expander):
"""AssignAdd expander"""
def _expand(self, graph_builder):
param, x = self.inputs
next_para = graph_builder.emit('Add', [param, x])
param_result = graph_builder.emit(
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
return param_result
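
For reference, a minimal NumPy sketch of the update this expander reproduces. It is illustrative only: assign_add_reference is not a MindSpore API, and reading 'fake_output' as "write in place, then expose the parameter as the op's result" is an assumption about the InplaceAssign contract.

import numpy as np

def assign_add_reference(param, x):
    """Illustrative stand-in for the expanded graph: Add, then in-place assign."""
    next_para = param + x     # the emitted 'Add'
    param[...] = next_para    # the in-place write done by 'InplaceAssign' (assumed semantics)
    return param              # with 'fake_output', the parameter doubles as the output

param = np.ones(3, dtype=np.float32)
print(assign_add_reference(param, np.full(3, 2.0, dtype=np.float32)))  # [3. 3. 3.]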

View File

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Sigmoid"""
from ._utils import Expander
class Sigmoid(Expander):
"""Sigmoid expander"""
def _expand(self, graph_builder):
input_x = self.inputs[0]
# Calculate sigmoid(x)
# formula is : sigmoid(x) = 1 / (1 + exp(-x))
const_one = graph_builder.value(input_x.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [input_x])
exp_neg_x = graph_builder.emit('Exp', [neg_x])
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
res = graph_builder.emit('RealDiv', [const_one, add_exp])
return res
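
As a sanity check of the decomposition, the same four-step chain (Neg, Exp, Add, RealDiv) can be written directly in NumPy and compared against the tanh identity sigmoid(x) = (1 + tanh(x / 2)) / 2. This sketch is illustrative and not part of the commit.

import numpy as np

x = np.array([-10.0, -1.0, 0.0, 1.0, 10.0], dtype=np.float32)
neg_x = -x                 # Neg
exp_neg_x = np.exp(neg_x)  # Exp
add_exp = 1.0 + exp_neg_x  # Add with const_one
res = 1.0 / add_exp        # RealDiv
assert np.allclose(res, 0.5 * (1.0 + np.tanh(x / 2.0)))  # tanh form of sigmoid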

View File

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SigmoidCrossEntropyWithLogits"""
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogits(Expander):
"""SigmoidCrossEntropyWithLogits expander"""
def _expand(self, graph_builder):
logits, label = self.inputs
# Calculate sigmoid_cross_entropy_with_logits(logits, label)
# formula is :
# sigmoid_cross_entropy_with_logits(logits, label)
# = -(label * log(sigmoid(logits)) + (1 - label) * log(1 - sigmoid(logits)))
const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits])
exp_neg_x = graph_builder.emit('Exp', [neg_x])
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
p = graph_builder.emit('RealDiv', [const_one, add_exp])
one_sub_p = graph_builder.emit('Sub', [const_one, p])
one_sub_label = graph_builder.emit('Sub', [const_one, label])
log_p = graph_builder.emit('Log', [p])
log_one_sub_p = graph_builder.emit('Log', [one_sub_p])
res_tmp_1 = graph_builder.emit('Mul', [one_sub_label, log_one_sub_p])
res_tmp_2 = graph_builder.emit('Mul', [label, log_p])
res_tmp = graph_builder.emit('Add', [res_tmp_1, res_tmp_2])
res = graph_builder.emit('Neg', [res_tmp])
return res
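
The emitted graph mirrors the textbook per-element loss step for step. A NumPy reference (illustrative, not from the commit) makes the data flow easier to follow; note that, like the expander, it uses the direct log(p) and log(1 - p) form rather than the numerically stable max(x, 0) - x * z + log(1 + exp(-|x|)) rewrite, so extreme logits can underflow the logs.

import numpy as np

def sigmoid_ce_reference(logits, label):
    """-(label * log(p) + (1 - label) * log(1 - p)) with p = sigmoid(logits)."""
    p = 1.0 / (1.0 + np.exp(-logits))          # RealDiv over Neg/Exp/Add
    return -(label * np.log(p) + (1.0 - label) * np.log(1.0 - p))

logits = np.array([[1.0, 2.0]], dtype=np.float32)
label = np.array([[0.0, 1.0]], dtype=np.float32)
print(sigmoid_ce_reference(logits, label))    # per-element loss, same shape as logits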

View File

@@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SigmoidCrossEntropyWithLogitsGrad"""
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogitsGrad(Expander):
"""SigmoidCrossEntropyWithLogitsGrad expander"""
def _expand(self, graph_builder):
logits, label, dout = self.inputs
# Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
# formula is :
# sigmoid_cross_entropy_with_logits_grad(logits, label, dout) = (sigmoid(logits) - label) * dout
const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits])
exp_neg_x = graph_builder.emit('Exp', [neg_x])
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp])
sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label])
res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout])
return res
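
The closed form follows from p = sigmoid(z) having dp/dz = p * (1 - p): differentiating -(y * log(p) + (1 - y) * log(1 - p)) gives exactly p - y, which is then scaled by the incoming gradient dout. A small illustrative check (not part of the commit):

import numpy as np

def sigmoid_ce_grad_reference(logits, label, dout):
    """d(loss)/d(logits) = (sigmoid(logits) - label) * dout."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return (p - label) * dout

def loss(z, y):
    p = 1.0 / (1.0 + np.exp(-z))
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

# Central-difference check of a single element.
z, y, eps = 0.7, 1.0, 1e-6
fd = (loss(z + eps, y) - loss(z - eps, y)) / (2 * eps)
assert np.isclose(sigmoid_ce_grad_reference(z, y, 1.0), fd, atol=1e-6)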

View File

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SigmoidGrad"""
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class SigmoidGrad(Expander):
"""SigmoidGrad expander"""
def _expand(self, graph_builder):
input_y, dy = self.inputs
# Calculate sigmoid_grad(y, dy)
# formula is : sigmoid_grad(y, dy) = (1 - y) * y * dy
const_one = graph_builder.value(input_y.dtype, 1.0)
one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
y_mul_dy = graph_builder.emit('Mul', [input_y, dy])
res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy])
return res
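
Because y = sigmoid(x) satisfies dy/dx = y * (1 - y), the backward op can be built from the forward output alone and never needs x. An illustrative check against a central difference of sigmoid (not part of the commit):

import numpy as np

def sigmoid_grad_reference(y, dy):
    """sigmoid'(x) written through the forward output y: (1 - y) * y * dy."""
    return (1.0 - y) * y * dy

sig = lambda t: 1.0 / (1.0 + np.exp(-t))
x, eps = 0.5, 1e-6
fd = (sig(x + eps) - sig(x - eps)) / (2 * eps)  # numerical d(sigmoid)/dx
assert np.isclose(sigmoid_grad_reference(sig(x), 1.0), fd, atol=1e-6)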

View File

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SoftmaxCrossEntropyWithLogits"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class SoftmaxCrossEntropyWithLogits(Expander):
"""SoftmaxCrossEntropyWithLogits expander"""
def _expand(self, graph_builder):
logits, label = self.inputs
# Calculate softmax_cross_entropy_with_logits(logits, label)
# formula is :
# softmax_cross_entropy_with_logits(logits, label) = -reduce_sum(label * log(softmax(logits)))
axis = (-1,)
max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [logits, max_x])
data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])
softmax_log = graph_builder.emit('Log', [data_softmax])
label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
'reduce_axis': axis, 'keep_dims': True})
loss = graph_builder.emit('Neg', [tmp_res])
dlogits = graph_builder.emit('Sub', [data_softmax, label])
return loss, dlogits
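
Two details worth noting: the ReduceMax subtraction is the standard max trick (softmax is invariant to shifting each row by a constant, and subtracting the row max bounds the exp() inputs at 0, preventing overflow), and the expander returns the fused pair (loss, dlogits), where dlogits = softmax - label is the gradient of the loss with respect to the logits for a one-hot label. An illustrative NumPy reference (not from the commit):

import numpy as np

def softmax_ce_reference(logits, label):
    """Return (loss, dlogits) the way the expanded graph computes them."""
    max_x = logits.max(axis=-1, keepdims=True)                 # ReduceMax
    data_exp = np.exp(logits - max_x)                          # Sub + Exp, safe from overflow
    softmax = data_exp / data_exp.sum(axis=-1, keepdims=True)  # RealDiv
    loss = -(label * np.log(softmax)).sum(axis=-1, keepdims=True)
    return loss, softmax - label

logits = np.array([[1.0, 1.0, 10.0]], dtype=np.float32)
label = np.array([[0.0, 0.0, 1.0]], dtype=np.float32)
print(softmax_ce_reference(logits, label))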

View File

@@ -58,6 +58,12 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
    prim::kPrimSoftmax,
    prim::kPrimLayerNorm,
    prim::kPrimLayerNormGrad,
    prim::kPrimSigmoid,
    prim::kPrimSigmoidGrad,
    prim::kPrimSigmoidCrossEntropyWithLogits,
    prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
    prim::kPrimSoftmaxCrossEntropyWithLogits,
    prim::kPrimAssignAdd,
#endif
  };
  return expand_ops;

View File

@@ -27,16 +27,22 @@
 namespace mindspore {
 namespace opt {
 const BaseRef SplitAssign::DefinePattern() const {
+  VarPtr v = std::make_shared<Var>();
   VarPtr Xs = std::make_shared<Var>();
   VarPtr Us = std::make_shared<Var>();
   VarPtr UMonad = std::make_shared<Var>();
-  return VectorRef({prim::kPrimAssign, Xs, Us, UMonad});
+  return VectorRef({v, Xs, Us, UMonad});
 }
+
+bool CanSplit(const AnfNodePtr &node) {
+  return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) ||
+         IsPrimitiveCNode(node, prim::kPrimAssignSub);
+}
+
 const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(node);
+  if (!CanSplit(node)) return node;
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   CheckCNodeInputSize(cnode, kAssignInputTensorNum);
@@ -49,7 +55,7 @@ const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfN
   depend_cnode->set_abstract(original_inputs[1]->abstract());
   depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
   // Create new assign node, delete U from inputs.
-  AnfNodePtrList new_assign_inputs = {NewValueNode(prim::kPrimAssign), depend_cnode, original_inputs[2]};
+  AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]};
   auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
   new_assign_cnode->set_abstract(original_abstract);
   new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());

View File

@@ -512,6 +512,7 @@ inline const PrimitivePtr kPrimScaleFusion = std::make_shared<Primitive>("ScaleF
inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusion");
inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion");
inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid");
inline const PrimitivePtr kPrimSigmoidGrad = std::make_shared<Primitive>("SigmoidGrad");
inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip");
inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh");
inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion =

View File

@@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class AssignAdd(nn.Cell):
    def __init__(self, value):
        super(AssignAdd, self).__init__()
        self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

    def construct(self, y):
        self.add(self.var, y)
        return self.var


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_add():
    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    add = AssignAdd(x2)
    result_gk_on_1 = add(y2)
    add_2 = AssignAdd(result_gk_on_1)
    result_gk_on_2 = add_2(y2)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    add_beta = AssignAdd(x2)
    result_gk_off_1 = add_beta(y2)
    add_beta_2 = AssignAdd(result_gk_off_1)
    result_gk_off_2 = add_beta_2(y2)

    assert (result_gk_on_1.asnumpy() == result_gk_off_1.asnumpy()).all()
    assert (result_gk_on_2.asnumpy() == result_gk_off_2.asnumpy()).all()

View File

@@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G


class NetSigmoid(nn.Cell):
    def __init__(self):
        super(NetSigmoid, self).__init__()
        self.sigmoid = P.Sigmoid()

    def construct(self, x):
        return self.sigmoid(x)


class NetSigmoidGrad(nn.Cell):
    def __init__(self):
        super(NetSigmoidGrad, self).__init__()
        self.sigmoid_grad = G.SigmoidGrad()

    def construct(self, y, dy):
        return self.sigmoid_grad(y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sigmoid():
    x = Tensor(np.array([[[[-1, 1, 10],
                           [1, -1, 1],
                           [10, 1, -1]]]]).astype(np.float32))
    error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    net = NetSigmoid()
    result_open_gk = net(x)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    net_beta = NetSigmoid()
    result_close_gk = net_beta(x)
    diff = result_open_gk.asnumpy() - result_close_gk.asnumpy()
    assert np.all(abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sigmoid_grad():
    y = Tensor(np.array([[[[-1, 1, 2],
                           [1, -1, 1],
                           [2, 1, -1]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[-11, 2, 4],
                            [-1, 1, -1],
                            [-4, 4, -4]]]]).astype(np.float32))
    error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    net = NetSigmoidGrad()
    result_open_gk = net(y, dy)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    net_beta = NetSigmoidGrad()
    result_close_gk = net_beta(y, dy)
    diff = result_open_gk.asnumpy() - result_close_gk.asnumpy()
    assert np.all(abs(diff) < error)

View File

@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G


class NetSigmoidCrossEntropyWithLogits(nn.Cell):
    def __init__(self):
        super(NetSigmoidCrossEntropyWithLogits, self).__init__()
        self.loss = P.SigmoidCrossEntropyWithLogits()

    def construct(self, logits, labels):
        return self.loss(logits, labels)


class NetSigmoidCrossEntropyWithLogitsGrad(nn.Cell):
    def __init__(self):
        super(NetSigmoidCrossEntropyWithLogitsGrad, self).__init__()
        self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad()

    def construct(self, logits, labels, dout):
        return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sigmoid_cross_entropy_with_logits():
    logits = Tensor(np.array([[1, 1, 2],
                              [1, 2, 1],
                              [2, 1, 1]]).astype(np.float32))
    labels = Tensor(np.array([[0, 0, 1],
                              [0, 1, 0],
                              [1, 0, 0]]).astype(np.float32))
    error = np.ones(shape=[3, 3]) * 1.0e-6

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
    result_open_gk = sigmoid_cross_entropy_with_logits(logits, labels)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    sigmoid_cross_entropy_with_logits_beta = NetSigmoidCrossEntropyWithLogits()
    result_close_gk = sigmoid_cross_entropy_with_logits_beta(logits, labels)
    diff = result_open_gk.asnumpy() - result_close_gk.asnumpy()
    assert np.all(abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sigmoid_cross_entropy_with_logits_grad():
    logits = Tensor(np.array([[1, 1, 2],
                              [1, 2, 1],
                              [2, 1, 1]]).astype(np.float32))
    labels = Tensor(np.array([[0, 0, 1],
                              [0, 1, 0],
                              [1, 0, 0]]).astype(np.float32))
    dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32))
    error = np.ones(shape=[3, 3]) * 1.0e-6

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    sigmoid_cross_entropy_with_logits_grad = NetSigmoidCrossEntropyWithLogitsGrad()
    result_open_gk = sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    sigmoid_cross_entropy_with_logits_grad_beta = NetSigmoidCrossEntropyWithLogitsGrad()
    result_close_gk = sigmoid_cross_entropy_with_logits_grad_beta(logits, labels, dout)
    diff = result_open_gk.asnumpy() - result_close_gk.asnumpy()
    assert np.all(abs(diff) < error)

View File

@@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetSoftmaxCrossEntropyWithLogits(nn.Cell):
    def __init__(self):
        super(NetSoftmaxCrossEntropyWithLogits, self).__init__()
        self.loss = P.SoftmaxCrossEntropyWithLogits()

    def construct(self, logits, labels):
        return self.loss(logits, labels)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_cross_entropy_with_logits():
    logits = Tensor(np.array([[1, 1, 10],
                              [1, 10, 1],
                              [10, 1, 1]]).astype(np.float32))
    labels = Tensor(np.array([[0, 0, 1],
                              [0, 1, 0],
                              [1, 0, 0]]).astype(np.float32))

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=True, device_target="GPU")
    softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits()
    result_open_gk = softmax_cross_entropy_with_logits(logits, labels)

    context.set_context(mode=context.GRAPH_MODE,
                        enable_graph_kernel=False, device_target="GPU")
    softmax_cross_entropy_with_logits_beta = NetSoftmaxCrossEntropyWithLogits()
    result_close_gk = softmax_cross_entropy_with_logits_beta(logits, labels)

    error0 = 1.0e-6
    diff0 = result_open_gk[0].asnumpy() - result_close_gk[0].asnumpy()
    diff1 = result_open_gk[1].asnumpy() - result_close_gk[1].asnumpy()
    assert np.all(abs(diff0) < error0)
    assert np.all(abs(diff1) < error0)