mindspore/tests/vm_impl/array_ops_vm_impl.py

344 lines
8.0 KiB
Python
Raw Normal View History

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for array ops"""
import functools

import numpy as np

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
# pylint: disable=unused-argument
@vm_impl_getters.register(P.Assign)
def vm_impl_assign(self):
    """Build the VM implementation of the Assign primitive.

    The returned callable writes ``value`` into ``x`` through
    ``x.assign_value`` and returns that same ``x`` object.  The optional ``u``
    argument is accepted but never read (presumably a monad input — confirm
    against callers).
    """
    def vm_impl(x, value, u=None):
        target = x
        target.assign_value(value)
        return target
    return vm_impl
@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Build the VM implementation of the ExpandDims primitive.

    A bare Python ``float`` input is first wrapped in a one-element Tensor.
    The input is then converted to NumPy, handed to ``vm.expand_dims`` along
    with ``axis``, and the result is wrapped in a new Tensor.
    """
    def vm_impl(x, axis):
        source = Tensor(np.array([x])) if isinstance(x, float) else x
        expanded = vm.expand_dims(source.asnumpy(), axis)
        return Tensor(expanded)
    return vm_impl
@vm_impl_getters.register(P.DType)
def vm_impl_dType(self):
    """Build the VM implementation of the DType primitive.

    The returned callable reports the ``dtype`` attribute of its input.
    """
    def vm_impl(x):
        return x.dtype
    return vm_impl
@vm_impl_getters.register(P.Cast)
def vm_impl_cast(self):
    """Build the VM implementation of the Cast primitive.

    ``t`` is the destination type.  When ``t`` has the same Python type as
    ``mstype.tensor`` (a tensor type), its element type is used instead.  The
    input is converted to NumPy, cast with ``astype`` to the NumPy equivalent
    of the destination type, and wrapped back into a Tensor.
    """
    def vm_impl(x, t):
        target_type = t.element_type() if isinstance(t, type(mstype.tensor)) else t
        data = x.asnumpy()
        return Tensor(data.astype(mstype.dtype_to_nptype(target_type)))
    return vm_impl
@vm_impl_getters.register(P.Reshape)
def vm_impl_reshape(self):
    """Build the VM implementation of the Reshape primitive.

    The input is converted to NumPy, reshaped to ``shp`` by ``vm.reshape``,
    and returned as a new Tensor.
    """
    def vm_impl(x, shp):
        reshaped = vm.reshape(x.asnumpy(), shp)
        return Tensor(reshaped)
    return vm_impl
@vm_impl_getters.register(P.Shape)
def vm_impl_shape(self):
    """Build the VM implementation of the Shape primitive.

    The returned callable returns whatever ``vm.shape`` reports for the NumPy
    view of the input (it is not wrapped in a Tensor).
    """
    def vm_impl(x):
        return vm.shape(x.asnumpy())
    return vm_impl
@vm_impl_getters.register(P.Squeeze)
def vm_impl_squeeze(self):
    """Build the VM implementation of the Squeeze primitive.

    The input is converted to NumPy and squeezed by ``vm.squeeze`` using the
    primitive's ``axis`` attribute; the result is wrapped in a new Tensor.
    """
    def vm_impl(x):
        return Tensor(vm.squeeze(x.asnumpy(), self.axis))
    return vm_impl
@vm_impl_getters.register(P.Transpose)
def vm_impl_transpose(self):
    """Build the VM implementation of the Transpose primitive.

    When ``perm`` is omitted (``None``) the axis order of the input is
    reversed, i.e. ``perm`` becomes ``[ndim - 1, ..., 1, 0]``.  The permuted
    array produced by ``vm.transpose`` is returned as a new Tensor.
    """
    def vm_impl(x, perm=None):
        data = x.asnumpy()
        if perm is None:
            # Default permutation: reverse all axes.
            perm = list(range(len(data.shape) - 1, -1, -1))
        return Tensor(vm.transpose(data, perm))
    return vm_impl
@vm_impl_getters.register(P.Split)
def vm_impl_split(self):
    """Build the VM implementation of the Split primitive.

    The input is converted to NumPy and cut at index ``self.pos`` with
    ``np.array_split``; the two resulting pieces are returned as a pair of
    Tensors.
    """
    # NOTE(review): SOURCE does not show where ``pos`` is set on the Split
    # primitive — confirm the attribute exists.
    def vm_impl(x):
        pieces = np.array_split(x.asnumpy(), (self.pos,))
        head, tail = pieces[0], pieces[1]
        return Tensor(head), Tensor(tail)
    return vm_impl
@vm_impl_getters.register(P.Fill)
def vm_impl_fill(self):
    """Build the VM implementation of the Fill primitive.

    Creates an array of shape ``dims`` filled with ``x``.  An ``int`` fill
    value produces an ``int32`` array; anything else produces ``float32``.
    """
    def vm_impl(dims, x):
        np_dtype = np.int32 if isinstance(x, int) else np.float32
        return Tensor(np.full(dims, x, np_dtype))
    return vm_impl
@vm_impl_getters.register(P.Eye)
def vm_impl_eye(self):
    """Build the VM implementation of the Eye primitive.

    Produces an ``n`` x ``m`` ``np.eye`` array whose dtype is the NumPy
    equivalent of the MindSpore type ``t``, wrapped in a Tensor.
    """
    def vm_impl(n, m, t):
        return Tensor(np.eye(n, m, dtype=mstype.dtype_to_nptype(t)))
    return vm_impl
@vm_impl_getters.register(P.InvertPermutation)
def vm_impl_invert_permutation(self):
    """Build the VM implementation of the InvertPermutation primitive.

    The input is passed straight to ``vm.invert_permutation`` and its result
    is returned unchanged (no Tensor wrapping).
    """
    def vm_impl(x):
        return vm.invert_permutation(x)
    return vm_impl
@vm_impl_getters.register(P.Argmax)
def vm_impl_argmax(self):
    """Build the VM implementation of the Argmax primitive.

    Computes ``np.argmax`` along the primitive's ``axis`` attribute and
    returns the indices flattened (``ravel``) into a 1-D Tensor.
    """
    def vm_impl(x):
        indices = np.argmax(x.asnumpy(), axis=self.axis)
        return Tensor(indices.ravel())
    return vm_impl
@vm_impl_getters.register(P.Tile)
def vm_impl_tile(self):
    """Build the VM implementation of the Tile primitive.

    Repeats the NumPy view of ``x`` according to ``multiples`` via
    ``np.tile`` and returns the result as a Tensor.
    """
    def vm_impl(x, multiples):
        return Tensor(np.tile(x.asnumpy(), multiples))
    return vm_impl
@vm_impl_getters.register(P.ReduceAll)
def vm_impl_all(self):
    """Build the VM implementation of the ReduceAll primitive.

    Delegates to ``vm.all`` with the given ``axis`` and the primitive's
    ``keep_dims`` attribute; the reduction is returned as a Tensor.
    """
    def vm_impl(x, axis):
        reduced = vm.all(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)
    return vm_impl
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
    """Build the VM implementation of the ReduceAny primitive.

    Delegates to ``vm.any`` with the given ``axis`` and the primitive's
    ``keep_dims`` attribute; the reduction is returned as a Tensor.
    """
    def vm_impl(x, axis):
        reduced = vm.any(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)
    return vm_impl
@vm_impl_getters.register(P.Concat)
def vm_impl_concatV2(self):
    """Build the VM implementation of the Concat primitive.

    Converts ``x`` to NumPy and joins it along the primitive's ``axis``
    attribute via ``vm.Concat``; the result is returned as a Tensor.
    """
    # NOTE(review): ``x`` is treated as a single object exposing ``asnumpy``;
    # confirm this matches how Concat inputs (possibly a tuple) arrive here.
    def vm_impl(x):
        joined = vm.Concat(x.asnumpy(), self.axis)
        return Tensor(joined)
    return vm_impl
@vm_impl_getters.register(P.Slice)
def vm_impl_slice(self):
    """Build the VM implementation of the Slice primitive.

    ``x``, ``begin`` and ``size`` are each converted to NumPy (in that order)
    and passed to ``vm.Slice``; the extracted region is returned as a Tensor.
    """
    def vm_impl(x, begin, size):
        data, start, extent = x.asnumpy(), begin.asnumpy(), size.asnumpy()
        return Tensor(vm.Slice(data, start, extent))
    return vm_impl
@vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self):
    """Build the VM implementation of the ConcatOffset grad primitive.

    The input is forwarded to ``vm.ConcatOffset`` and its tuple result is
    returned as-is.
    """
    def vm_impl(x):
        return vm.ConcatOffset(x)
    return vm_impl
@vm_impl_getters.register(P.ReduceSum)
def vm_impl_sum(self):
    """Generate vm_impl function for Sum.

    The returned callable sums ``x`` over ``axis`` and returns the result as
    a Tensor.  An empty tuple ``axis`` means "reduce over every axis".  NumPy
    would read ``axis=()`` as "reduce over no axis", so it is translated to
    ``axis=None``.

    The primitive's ``keep_dims`` attribute is honoured, matching the
    ReduceAll / ReduceAny implementations.  When the attribute is absent it
    defaults to ``False``, which reproduces the previous behaviour.
    """
    def vm_impl(x, axis):
        x = x.asnumpy()
        # Read lazily (like ReduceAll/ReduceAny do with self.keep_dims).
        keep_dims = getattr(self, "keep_dims", False)
        # MindSpore spells "all axes" as (); NumPy spells it None.
        reduce_axis = None if axis == () else axis
        out = np.sum(x, axis=reduce_axis, keepdims=keep_dims)
        # np.sum returns a NumPy scalar for full reductions; np.array makes
        # it a 0-d array so Tensor receives an ndarray in every case.
        return Tensor(np.array(out))
    return vm_impl
@vm_impl_getters.register(P.Select)
def vm_impl_select(self):
    """Build the VM implementation of the Select primitive.

    All three inputs are converted to NumPy and passed to ``vm.select``; the
    result is returned as a Tensor.
    """
    def vm_impl(cond, x, y):
        """Delegate to ``vm.select`` on the NumPy views of the inputs.

        Args:
            cond: a boolean `Tensor` condition.
            x: a `Tensor` which may have the same shape as ``cond``.
            y: a `Tensor` with the same shape and type as ``x``.
        """
        condition, on_true, on_false = cond.asnumpy(), x.asnumpy(), y.asnumpy()
        return Tensor(vm.select(condition, on_true, on_false))
    return vm_impl
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Build the VM implementation of the Square primitive.

    Multiplies the NumPy view of ``x`` by itself element-wise and returns the
    product as a Tensor.
    """
    def vm_impl(x):
        data = x.asnumpy()
        squared = data * data
        return Tensor(squared)
    return vm_impl
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Generate vm_impl function for ZerosLike.

    The returned callable produces a Tensor of zeros with the same shape and
    dtype as the NumPy view of ``x`` (``np.zeros_like``).
    """
    def vm_impl(x):
        return Tensor(np.zeros_like(x.asnumpy()))
    # Every getter must hand the callable back to the registry; without this
    # return the getter yields None.
    return vm_impl
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
    """Generate vm_impl function for Partial.

    The returned callable treats ``args[0]`` as the callee and binds
    ``args[1:]`` as its leading positional arguments.  It returns a
    ``functools.partial`` over ``args[0].__call__``.

    Relies on the module-level ``import functools``, which was previously
    missing and made every call raise NameError.
    """
    def vm_impl(*args):
        func = args[0].__call__
        partial_func = functools.partial(func, *args[1:])
        return partial_func
    return vm_impl
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
    """Build the VM implementation of the Depend primitive.

    The returned callable forwards ``value`` unchanged; ``expr`` is accepted
    but not used.
    """
    def vm_impl(value, expr):
        # Only the first operand carries data through this node.
        return value
    return vm_impl
@vm_impl_getters.register(P.UpdateState)
def vm_impl_updatestate(self):
    """Build the VM implementation of the UpdateState primitive.

    The returned callable passes ``monad`` through unchanged; ``expr`` is
    accepted but not used.
    """
    def vm_impl(monad, expr):
        # The monad itself is the result of this node.
        return monad
    return vm_impl
@vm_impl_getters.register(P.Load)
def vm_impl_load(self):
    """Build the VM implementation of the Load primitive.

    The returned callable hands back ``value`` unchanged; the optional ``u``
    argument is accepted but not used.
    """
    def vm_impl(value, u=None):
        # Loading in the VM is a pass-through of the stored value.
        return value
    return vm_impl