mindspore/tests/vm_impl/math_ops_vm_impl.py

303 lines
6.5 KiB
Python
Raw Normal View History

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for math ops"""
import copy
import numpy as np
2020-05-13 11:30:27 +08:00
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common.tensor import Tensor
2020-05-13 11:30:27 +08:00
from mindspore.ops import operations as P
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
2020-05-13 11:30:27 +08:00
# pylint: disable=unused-argument
@vm_impl_getters.register(P.TensorAdd)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for TensorAdd.

    The returned callable converts both operands with ``asnumpy()``, adds
    them with ``+`` and wraps the sum in a new ``Tensor``.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        return Tensor(lhs + rhs)

    return vm_impl
# pylint: disable=used-before-assignment
@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
    """Generate vm_impl function for LogicalNot.

    The returned callable converts ``x`` with ``asnumpy()``, negates it via
    the shared ``vm.logical_not`` helper and wraps the result in a ``Tensor``.
    """

    def vm_impl(x):
        negated = vm.logical_not(x.asnumpy())
        return Tensor(negated)

    return vm_impl
@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):
    """Generate vm_impl function for MatMul.

    The returned callable multiplies ``x @ w`` after transposing ``x`` when
    the primitive's ``transpose_a`` flag is set, and ``w`` when
    ``transpose_b`` is set. The flags are read from the primitive on every
    call.
    """

    def vm_impl(x, w):
        lhs = x.asnumpy()
        rhs = w.asnumpy()
        lhs = lhs.transpose() if self.transpose_a else lhs
        rhs = rhs.transpose() if self.transpose_b else rhs
        return Tensor(lhs @ rhs)

    return vm_impl
@vm_impl_getters.register(P.AddN)
def vm_impl_addn(self):
    """Generate vm_impl function for AddN.

    The returned callable takes an indexable sequence of tensors and returns
    the element-wise sum of all of them as a new ``Tensor``.
    """

    def vm_impl(inputs):
        # Deep-copy the first operand so the in-place ``+=`` below writes into
        # a private array, not into whatever ``asnumpy()`` returned for it.
        acc = copy.deepcopy(inputs[0].asnumpy())
        for tensor in inputs[1:]:
            acc += tensor.asnumpy()
        return Tensor(acc)

    return vm_impl
@vm_impl_getters.register(P.Neg)
def vm_impl_neg(self):
    """Generate vm_impl function for Neg.

    The returned callable applies unary minus to ``x.asnumpy()`` and wraps
    the result in a ``Tensor``.
    """

    def vm_impl(x):
        return Tensor(-x.asnumpy())

    return vm_impl
@vm_impl_getters.register(P.Sub)
def vm_impl_Sub(self):
    """Generate vm_impl function for Sub.

    The returned callable computes ``x - y`` on the ``asnumpy()`` views of
    both operands and wraps the difference in a ``Tensor``.
    """

    def vm_impl(x, y):
        minuend = x.asnumpy()
        subtrahend = y.asnumpy()
        return Tensor(minuend - subtrahend)

    return vm_impl
@vm_impl_getters.register(P.Mul)
def vm_impl_mul(self):
    """Generate vm_impl function for Mul.

    The returned callable multiplies the ``asnumpy()`` views of ``x`` and
    ``y`` with ``*`` and wraps the product in a ``Tensor``.
    """

    def vm_impl(x, y):
        product = x.asnumpy() * y.asnumpy()
        return Tensor(product)

    return vm_impl
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square.

    The returned callable multiplies ``x.asnumpy()`` by itself and wraps the
    result in a ``Tensor``.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        return Tensor(arr * arr)

    return vm_impl
@vm_impl_getters.register(P.Sqrt)
def vm_impl_sqrt(self):
    """Generate vm_impl function for Sqrt.

    The returned callable delegates to the shared ``vm.sqrt`` helper on
    ``x.asnumpy()`` and wraps the result in a ``Tensor``.
    """

    def vm_impl(x):
        return Tensor(vm.sqrt(x.asnumpy()))

    return vm_impl
@vm_impl_getters.register(P.Pow)
def vm_impl_pow(self):
    """Generate vm_impl function for Pow.

    The returned callable raises ``x`` to the power ``y`` via the shared
    ``vm.power`` helper and wraps the result in a ``Tensor``.
    """

    def vm_impl(x, y):
        base, exponent = x.asnumpy(), y.asnumpy()
        return Tensor(vm.power(base, exponent))

    return vm_impl
@vm_impl_getters.register(P.Exp)
def vm_impl_exp(self):
    """Generate vm_impl function for Exp.

    The returned callable delegates to the shared ``vm.exp`` helper on
    ``x.asnumpy()`` and wraps the result in a ``Tensor``.
    """

    def vm_impl(x):
        return Tensor(vm.exp(x.asnumpy()))

    return vm_impl
@vm_impl_getters.register(P.RealDiv)
def vm_impl_real_div(self):
    """Generate vm_impl function for RealDiv.

    The returned callable computes ``x / y`` with NumPy true division and
    casts the quotient back to the dtype of ``x`` before wrapping it in a
    ``Tensor``.
    """

    def vm_impl(x, y):
        numerator = x.asnumpy()
        denominator = y.asnumpy()
        # Cast back so the output dtype follows the first input rather than
        # whatever dtype ``/`` promoted to.
        quotient = np.array(numerator / denominator, dtype=numerator.dtype)
        return Tensor(quotient)

    return vm_impl
@vm_impl_getters.register(P.Div)
def vm_impl_div(self):
    """Generate vm_impl function for Div.

    The returned callable computes ``x / y`` with NumPy true division. Unlike
    the RealDiv implementation, the quotient is NOT cast back to the input
    dtype; it keeps whatever dtype ``/`` produces.
    """

    def vm_impl(x, y):
        numerator, denominator = x.asnumpy(), y.asnumpy()
        return Tensor(numerator / denominator)

    return vm_impl
@vm_impl_getters.register(P.ReduceMean)
def vm_impl_reduce_mean(self):
    """Generate vm_impl function for ReduceMean.

    The returned callable forwards ``x.asnumpy()`` and ``axis`` unchanged to
    the shared ``vm.mean`` helper and wraps the result in a ``Tensor``.
    """

    def vm_impl(x, axis):
        reduced = vm.mean(x.asnumpy(), axis)
        return Tensor(reduced)

    return vm_impl
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax.

    The returned callable reduces ``x`` with ``np.amax`` over ``axis``:

    * an empty axis (``()`` or ``[]``) means "reduce over every dimension"
      and is passed to NumPy as ``axis=None``;
    * a list axis is converted to a tuple, because NumPy reductions accept an
      int or a tuple of ints but not a list;
    * the primitive's ``keep_dims`` attribute, when present, is forwarded as
      ``keepdims`` so reduced dimensions are retained with length 1. When the
      attribute is absent it defaults to ``False``, matching the previous
      behaviour.
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        if isinstance(axis, (list, tuple)):
            # () / [] request a full reduction; NumPy spells that axis=None.
            axis = tuple(axis) or None
        keep_dims = bool(getattr(self, "keep_dims", False))
        out = np.amax(x, axis=axis, keepdims=keep_dims)
        return Tensor(out)

    return vm_impl
@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
    """Generate vm_impl function for Equal.

    The returned callable compares ``x`` and ``y`` via the shared
    ``vm.equal`` helper and wraps the result, passed through ``np.array``,
    in a ``Tensor``.
    """

    def vm_impl(x, y):
        mask = vm.equal(x.asnumpy(), y.asnumpy())
        return Tensor(np.array(mask))

    return vm_impl
@vm_impl_getters.register(P.NotEqual)
def vm_impl_not_equal(self):
    """Generate vm_impl function for NotEqual.

    The returned callable compares ``x`` and ``y`` via the shared
    ``vm.not_equal`` helper and wraps the result, passed through
    ``np.array``, in a ``Tensor``.
    """

    def vm_impl(x, y):
        mask = vm.not_equal(x.asnumpy(), y.asnumpy())
        return Tensor(np.array(mask))

    return vm_impl
@vm_impl_getters.register(P.Greater)
def vm_impl_greater(self):
    """Generate vm_impl function for Greater.

    The returned callable compares ``x`` and ``y`` via the shared
    ``vm.greater`` helper and wraps the result, passed through ``np.array``,
    in a ``Tensor``.
    """

    def vm_impl(x, y):
        mask = vm.greater(x.asnumpy(), y.asnumpy())
        return Tensor(np.array(mask))

    return vm_impl
@vm_impl_getters.register(P.Maximum)
def vm_impl_maximum(self):
    """Generate vm_impl function for Maximum.

    The returned callable delegates to the shared ``vm.maximum`` helper on
    the ``asnumpy()`` views of ``x`` and ``y`` and wraps the result in a
    ``Tensor``.
    """

    def vm_impl(x, y):
        larger = vm.maximum(x.asnumpy(), y.asnumpy())
        return Tensor(larger)

    return vm_impl
@vm_impl_getters.register(P.Minimum)
def vm_impl_minimum(self):
    """Generate vm_impl function for Minimum.

    The returned callable delegates to the shared ``vm.minimum`` helper on
    the ``asnumpy()`` views of ``x`` and ``y`` and wraps the result in a
    ``Tensor``.
    """

    def vm_impl(x, y):
        smaller = vm.minimum(x.asnumpy(), y.asnumpy())
        return Tensor(smaller)

    return vm_impl
@vm_impl_getters.register(P.Less)
def vm_impl_less(self):
    """Generate vm_impl function for Less.

    The returned callable compares ``x`` and ``y`` via the shared
    ``vm.less`` helper and wraps the result, passed through ``np.array``,
    in a ``Tensor``.
    """

    def vm_impl(x, y):
        mask = vm.less(x.asnumpy(), y.asnumpy())
        return Tensor(np.array(mask))

    return vm_impl
@vm_impl_getters.register(P.ScalarCast)
def vm_impl_scalar_cast(self):
    """Generate vm_impl function for ScalarCast.

    The returned callable maps the MindSpore dtype ``t`` to its NumPy scalar
    type with ``dtype_to_nptype``, constructs that type from ``x`` and
    returns the plain Python value obtained through ``.item()``. No Tensor
    is created.
    """

    def vm_impl(x, t):
        return dtype_to_nptype(t)(x).item()

    return vm_impl