2020-03-27 14:49:12 +08:00
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
"""Generate vm_impl function for math ops"""
|
|
|
|
import copy
|
|
|
|
import numpy as np
|
2020-05-13 11:30:27 +08:00
|
|
|
|
|
|
|
from mindspore.common.dtype import dtype_to_nptype
|
2020-03-27 14:49:12 +08:00
|
|
|
from mindspore.common.tensor import Tensor
|
2020-05-13 11:30:27 +08:00
|
|
|
from mindspore.ops import operations as P
|
2020-03-27 14:49:12 +08:00
|
|
|
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
|
|
|
|
from .vm_interface import vm
|
2020-05-13 11:30:27 +08:00
|
|
|
|
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
# pylint: disable=unused-argument
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.TensorAdd)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for TensorAdd.

    The returned kernel converts both Tensor operands with ``asnumpy()``,
    adds them with the ``+`` operator and wraps the sum in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        total = lhs + rhs
        return Tensor(total)

    return vm_impl
|
|
|
|
|
2020-05-13 11:30:27 +08:00
|
|
|
|
2020-05-22 13:23:13 +08:00
|
|
|
# pylint: disable=used-before-assignment
|
2020-03-27 14:49:12 +08:00
|
|
|
@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
    """Generate vm_impl function for LogicalNot.

    The returned kernel converts its single Tensor argument with
    ``asnumpy()``, applies ``vm.logical_not`` and wraps the result in a new
    Tensor.
    """

    def vm_impl(x):
        return Tensor(vm.logical_not(x.asnumpy()))

    return vm_impl
|
2020-05-13 11:30:27 +08:00
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):
    """Generate vm_impl function for MatMul.

    The kernel reads the primitive's ``transpose_a`` / ``transpose_b`` flags
    each time it runs. When a flag is set, the matching operand is transposed
    with ``ndarray.transpose()`` before the ``@`` product is taken.
    """

    def vm_impl(x, w):
        left = x.asnumpy()
        right = w.asnumpy()
        left = left.transpose() if self.transpose_a else left
        right = right.transpose() if self.transpose_b else right
        return Tensor(left @ right)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.AddN)
def vm_impl_addn(self):
    """Generate vm_impl function for AddN.

    The returned kernel takes a sequence of Tensors and returns their
    element-wise sum as a new Tensor.
    """

    def vm_impl(inputs):
        # Deep-copy the first operand so the in-place ``+=`` below never
        # writes into the array returned by ``inputs[0].asnumpy()``.
        accumulator = copy.deepcopy(inputs[0].asnumpy())
        for tensor in inputs[1:]:
            accumulator += tensor.asnumpy()
        return Tensor(accumulator)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Neg)
def vm_impl_neg(self):
    """Generate vm_impl function for Neg.

    The returned kernel applies unary minus to the numpy view of its input
    and wraps the result in a new Tensor.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        negated = -arr
        return Tensor(negated)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Sub)
def vm_impl_Sub(self):
    """Generate vm_impl function for Sub.

    The returned kernel computes ``x - y`` on the numpy views of its two
    Tensor operands and wraps the difference in a new Tensor.
    """

    def vm_impl(x, y):
        return Tensor(x.asnumpy() - y.asnumpy())

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Mul)
def vm_impl_mul(self):
    """Generate vm_impl function for Mul.

    The returned kernel multiplies the numpy views of its two Tensor operands
    with ``*`` and wraps the product in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        product = lhs * rhs
        return Tensor(product)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square.

    The returned kernel multiplies the numpy view of its input by itself and
    wraps the result in a new Tensor.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        squared = arr * arr
        return Tensor(squared)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Sqrt)
def vm_impl_sqrt(self):
    """Generate vm_impl function for Sqrt.

    The returned kernel delegates to ``vm.sqrt`` on the numpy view of its
    input and wraps the result in a new Tensor.
    """

    def vm_impl(x):
        return Tensor(vm.sqrt(x.asnumpy()))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Pow)
def vm_impl_pow(self):
    """Generate vm_impl function for Pow.

    The returned kernel converts both operands with ``asnumpy()``, delegates
    to ``vm.power(base, exponent)`` and wraps the result in a new Tensor.
    """

    def vm_impl(x, y):
        base = x.asnumpy()
        exponent = y.asnumpy()
        return Tensor(vm.power(base, exponent))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Exp)
def vm_impl_exp(self):
    """Generate vm_impl function for Exp.

    The returned kernel delegates to ``vm.exp`` on the numpy view of its
    input and wraps the result in a new Tensor.
    """

    def vm_impl(x):
        return Tensor(vm.exp(x.asnumpy()))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.RealDiv)
def vm_impl_real_div(self):
    """Generate vm_impl function for RealDiv.

    The returned kernel divides the numpy views of its operands with ``/``.
    It then casts the quotient back to the dividend's dtype before wrapping
    it in a new Tensor.
    """

    def vm_impl(x, y):
        numerator = x.asnumpy()
        denominator = y.asnumpy()
        # numpy's true division may promote the dtype (e.g. int / int gives
        # float64); force the result back to the dividend's dtype.
        quotient = np.array(numerator / denominator, numerator.dtype)
        return Tensor(quotient)

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Div)
def vm_impl_div(self):
    """Generate vm_impl function for Div.

    The returned kernel divides the numpy views of its operands with ``/``
    and wraps the quotient in a new Tensor, leaving the result dtype as
    numpy chooses it.
    """

    def vm_impl(x, y):
        return Tensor(x.asnumpy() / y.asnumpy())

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.ReduceMean)
def vm_impl_reduce_mean(self):
    """Generate vm_impl function for ReduceMean.

    The returned kernel passes the numpy view of ``x`` and the ``axis``
    argument unchanged to ``vm.mean``, then wraps the result in a new Tensor.
    """

    def vm_impl(x, axis):
        reduced = vm.mean(x.asnumpy(), axis)
        return Tensor(reduced)

    return vm_impl
|
|
|
|
|
2020-07-09 08:29:07 +08:00
|
|
|
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax.

    The returned kernel takes ``(x, axis)`` and reduces the numpy view of
    ``x`` with ``np.amax``. An empty tuple ``axis`` means "reduce over every
    dimension" and is mapped to ``axis=None``.

    The primitive's ``keep_dims`` attribute is read each time the kernel
    runs, in the same way MatMul reads ``transpose_a``/``transpose_b``. It is
    forwarded as ``keepdims``, so reduced axes can be kept with length 1.
    If the attribute is absent it defaults to ``False``, which matches the
    previous behaviour.
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        # An explicit type check avoids ``axis == ()``, which for a numpy
        # scalar axis would compare element-wise against an empty sequence
        # and yield an empty array whose truth value is ill-defined.
        if isinstance(axis, tuple) and not axis:
            axis = None
        keep_dims = bool(getattr(self, "keep_dims", False))
        out = np.amax(x, axis=axis, keepdims=keep_dims)
        return Tensor(out)

    return vm_impl
|
2020-03-27 14:49:12 +08:00
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
    """Generate vm_impl function for Equal.

    The returned kernel compares the numpy views of its operands with
    ``vm.equal`` and wraps ``np.array`` of the result in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        mask = vm.equal(lhs, rhs)
        return Tensor(np.array(mask))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.NotEqual)
def vm_impl_not_equal(self):
    """Generate vm_impl function for NotEqual.

    The returned kernel compares the numpy views of its operands with
    ``vm.not_equal`` and wraps ``np.array`` of the result in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        mask = vm.not_equal(lhs, rhs)
        return Tensor(np.array(mask))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Greater)
def vm_impl_greater(self):
    """Generate vm_impl function for Greater.

    The returned kernel compares the numpy views of its operands with
    ``vm.greater`` and wraps ``np.array`` of the result in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        mask = vm.greater(lhs, rhs)
        return Tensor(np.array(mask))

    return vm_impl
|
|
|
|
|
2020-05-13 11:30:27 +08:00
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
@vm_impl_getters.register(P.Maximum)
def vm_impl_maximum(self):
    """Generate vm_impl function for Maximum.

    The returned kernel delegates to ``vm.maximum`` on the numpy views of its
    two operands and wraps the result in a new Tensor.
    """

    def vm_impl(x, y):
        return Tensor(vm.maximum(x.asnumpy(), y.asnumpy()))

    return vm_impl
|
|
|
|
|
|
|
|
|
|
|
|
@vm_impl_getters.register(P.Minimum)
def vm_impl_minimum(self):
    """Generate vm_impl function for Minimum.

    The returned kernel delegates to ``vm.minimum`` on the numpy views of its
    two operands and wraps the result in a new Tensor.
    """

    def vm_impl(x, y):
        return Tensor(vm.minimum(x.asnumpy(), y.asnumpy()))

    return vm_impl
|
|
|
|
|
2020-05-13 11:30:27 +08:00
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
@vm_impl_getters.register(P.Less)
def vm_impl_less(self):
    """Generate vm_impl function for Less.

    The returned kernel compares the numpy views of its operands with
    ``vm.less`` and wraps ``np.array`` of the result in a new Tensor.
    """

    def vm_impl(x, y):
        lhs, rhs = x.asnumpy(), y.asnumpy()
        mask = vm.less(lhs, rhs)
        return Tensor(np.array(mask))

    return vm_impl
|
|
|
|
|
2020-05-13 11:30:27 +08:00
|
|
|
|
2020-03-27 14:49:12 +08:00
|
|
|
@vm_impl_getters.register(P.ScalarCast)
def vm_impl_scalar_cast(self):
    """Generate vm_impl function for ScalarCast.

    The returned kernel maps the MindSpore dtype ``t`` to its numpy scalar
    type and converts ``x`` with it. It then unwraps the numpy scalar to a
    plain Python value with ``.item()``.
    """

    def vm_impl(x, t):
        target_type = dtype_to_nptype(t)
        return target_type(x).item()

    return vm_impl
|