From 9e38dfd28e7a717d36b6ac38e4757026ab791b9f Mon Sep 17 00:00:00 2001
From: xiangyunwu21
Date: Sat, 28 Nov 2020 17:10:55 +0800
Subject: [PATCH] Add assign op for cpu

---
 .../kernel_compiler/cpu/assign_cpu_kernel.cc  |  49 +---
 .../kernel_compiler/cpu/assign_cpu_kernel.h   |  36 ++-
 .../kernel_compiler/cpu/cast_cpu_kernel.cc    |  25 ++
 .../kernel_compiler/cpu/cast_cpu_kernel.h     |  25 ++
 mindspore/ops/operations/other_ops.py         |   7 +-
 tests/st/ops/cpu/test_assign_op.py            | 223 ++++++++++++++
 tests/st/ops/cpu/test_cast_op.py              | 273 ++++++++++++++++++
 7 files changed, 598 insertions(+), 40 deletions(-)
 create mode 100644 tests/st/ops/cpu/test_assign_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
index 54e4daa5373..a1c4cd115bf 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
@@ -16,67 +16,48 @@
 #include "backend/kernel_compiler/cpu/assign_cpu_kernel.h"
 #include
+#include <map>
 #include "runtime/device/cpu/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
+static std::map<TypeId, size_t> input_x_dtype_size_map = {
+  {kNumberTypeBool, sizeof(bool)}, {kNumberTypeInt8, 1},    {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},
+  {kNumberTypeInt64, 8},           {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},  {kNumberTypeUInt32, 4},
+  {kNumberTypeUInt64, 8},          {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
+
 void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
   auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-  for (size_t i = 0; i < input_x_shape.size(); ++i) {
-    batch_size_ *= input_x_shape[i];
-  }
-
   if (input_x_shape.size() != input_y_shape.size()) MS_LOG(EXCEPTION) << "x y must be same shape";
   for (size_t i = 0; i < input_x_shape.size(); ++i) {
     if (input_x_shape[i] != input_y_shape[i]) {
       MS_LOG(EXCEPTION) << "x y must be same shape";
     }
+    batch_size_ *= input_x_shape[i];
   }
   input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
-  if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
-    input_x_dtype_size_ = 4;
-  } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
-    input_x_dtype_size_ = 8;
-  } else {
-    MS_LOG(EXCEPTION) << "input_x dtype only support float32, float64, int32, int64";
+  if (input_x_dtype_size_map.find(input_x_dtype_) == input_x_dtype_size_map.end()) {
+    MS_LOG(EXCEPTION) << "unsupported input_x dtype";
   }
+  input_x_dtype_size_ = input_x_dtype_size_map[input_x_dtype_];
 }
 
-bool AssignCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
-                             const std::vector<AddressPtr> & /*workspace*/,
-                             const std::vector<AddressPtr> &outputs) {
-  if (input_x_dtype_ == kNumberTypeInt32) {
-    LaunchKernel<int>(inputs, outputs);
-  } else if (input_x_dtype_ == kNumberTypeInt64) {
-    LaunchKernel<int64_t>(inputs, outputs);
-  } else if (input_x_dtype_ == kNumberTypeFloat32) {
-    LaunchKernel<float>(inputs, outputs);
-  } else if (input_x_dtype_ == kNumberTypeFloat64) {
-    LaunchKernel<double>(inputs, outputs);
-  } else {
-    MS_LOG(ERROR) << "indices dtype only support float32, float64, int32, int64";
-    return false;
-  }
-  return true;
-}
-
-template <typename T>
-void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                   const std::vector<AddressPtr> &outputs) {
-  T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
-  T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
+bool AssignCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                             const std::vector<AddressPtr> &outputs) {
   auto max_size = inputs[0]->size;
   size_t total_size = input_x_dtype_size_ * batch_size_;
   if (total_size > max_size) {
     MS_LOG(EXCEPTION) << "Memcpy size must <= max_size, but got memcpy size is : " << total_size
                       << ", max size is : " << max_size;
   }
-  int ret = memcpy_s(input_x, total_size, input_y, total_size);
+  int ret = memcpy_s(inputs[0]->addr, total_size, inputs[1]->addr, total_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
   }
+  return true;
 }
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
index b58b4f17566..4f1f34192b2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
@@ -34,15 +34,24 @@ class AssignCPUKernel : public CPUKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
 
-  template <typename T>
-  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
-
  private:
   size_t batch_size_{1};
   TypeId input_x_dtype_{kTypeUnknown};
   size_t input_x_dtype_size_ = 4;
 };
 
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+  AssignCPUKernel);
+
 MS_REG_CPU_KERNEL(
   Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   AssignCPUKernel);
@@ -51,6 +60,27 @@ MS_REG_CPU_KERNEL(
   Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   AssignCPUKernel);
 
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+  AssignCPUKernel);
+
+MS_REG_CPU_KERNEL(
+  Assign,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  AssignCPUKernel);
+
 MS_REG_CPU_KERNEL(
   Assign,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc
index d3f2d22b31c..acf0a02bf83 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc
@@ -82,6 +82,31 @@ bool CastCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
   mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast<bool, float>;
   mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast<bool, bool>;
   mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast<bool, int32_t>;
+
+  mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast<int8_t, int16_t>;
+  mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast<int8_t, int32_t>;
+  mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast<int8_t, int64_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast<uint8_t, int16_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast<uint8_t, int32_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast<uint8_t, int64_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast<uint8_t, uint16_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast<uint8_t, uint32_t>;
+  mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast<uint8_t, uint64_t>;
+
+  mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast<int16_t, int32_t>;
+  mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast<int16_t, int64_t>;
+  mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast<uint16_t, int32_t>;
+  mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast<uint16_t, int64_t>;
+  mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast<uint16_t, uint32_t>;
+  mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast<uint16_t, uint64_t>;
+
+  mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast<int32_t, int64_t>;
+  mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast<uint32_t, int64_t>;
+  mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast<uint32_t, uint64_t>;
+
+  mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast<float16, float>;
+  mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast<float16, double>;
+  mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast<float, double>;
   mode_map[source_dtype][target_dtype](inputs, outputs);
   return true;
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h
index 996086a5716..157400ad5e0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h
@@ -47,6 +47,31 @@ MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
 MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel);
 MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
 MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
+
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
+
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
+
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
+
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
+MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
 
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 88dff291bb6..4eb54804f5d 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -38,7 +38,7 @@ class Assign(PrimitiveWithCheck):
         Tensor, has the same type as original `variable`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> class Net(nn.Cell):
@@ -66,9 +66,10 @@ class Assign(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
 
     def check_dtype(self, variable, value):
+        types = mstype.number_type + (mstype.bool_,)
         if variable != mstype.type_refkey:
-            validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
-            validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
+            validator.check_tensor_dtype_valid("variable", variable, types, self.name)
+            validator.check_scalar_or_tensor_types_same({"value": value}, types, self.name)
 
 
 class InplaceAssign(PrimitiveWithInfer):
diff --git a/tests/st/ops/cpu/test_assign_op.py b/tests/st/ops/cpu/test_assign_op.py
new file mode 100644
index 00000000000..9648496ff81
--- /dev/null
+++ b/tests/st/ops/cpu/test_assign_op.py
@@ -0,0 +1,223 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.common.initializer import initializer
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+class Assign(nn.Cell):
+    def __init__(self, x, y):
+        super(Assign, self).__init__()
+        self.x = Parameter(initializer(x, x.shape), name="x")
+        self.y = Parameter(initializer(y, y.shape), name="y")
+        self.assign = P.Assign()
+
+    def construct(self):
+        self.assign(self.y, self.x)
+        return self.y
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_bool():
+    x = Tensor(np.ones([3, 3]).astype(np.bool_))
+    y = Tensor(np.zeros([3, 3]).astype(np.bool_))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.bool_)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_int8():
+    x = Tensor(np.ones([3, 3]).astype(np.int8))
+    y = Tensor(np.zeros([3, 3]).astype(np.int8))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.int8)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_uint8():
+    x = Tensor(np.ones([3, 3]).astype(np.uint8))
+    y = Tensor(np.zeros([3, 3]).astype(np.uint8))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.uint8)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_int16():
+    x = Tensor(np.ones([3, 3]).astype(np.int16))
+    y = Tensor(np.zeros([3, 3]).astype(np.int16))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.int16)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_uint16():
+    x = Tensor(np.ones([3, 3]).astype(np.uint16))
+    y = Tensor(np.zeros([3, 3]).astype(np.uint16))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.uint16)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_int32():
+    x = Tensor(np.ones([3, 3]).astype(np.int32))
+    y = Tensor(np.zeros([3, 3]).astype(np.int32))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.int32)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_uint32():
+    x = Tensor(np.ones([3, 3]).astype(np.uint32))
+    y = Tensor(np.zeros([3, 3]).astype(np.uint32))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.uint32)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_int64():
+    x = Tensor(np.ones([3, 3]).astype(np.int64))
+    y = Tensor(np.zeros([3, 3]).astype(np.int64))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.int64)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_uint64():
+    x = Tensor(np.ones([3, 3]).astype(np.uint64))
+    y = Tensor(np.zeros([3, 3]).astype(np.uint64))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.ones([3, 3]).astype(np.uint64)
+    print(output)
+    assert np.all(output == output_expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_float16():
+    x = Tensor(np.array([[0.1, 0.2, 0.3],
+                         [0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8]]).astype(np.float16))
+    y = Tensor(np.array([[0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8],
+                         [0.1, 0.2, 0.3]]).astype(np.float16))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.array([[0.1, 0.2, 0.3],
+                              [0.4, 0.5, 0.5],
+                              [0.6, 0.7, 0.8]]).astype(np.float16)
+    print(output)
+    assert np.all(np.abs(output - output_expect) < 1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_float32():
+    x = Tensor(np.array([[0.1, 0.2, 0.3],
+                         [0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8]]).astype(np.float32))
+    y = Tensor(np.array([[0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8],
+                         [0.1, 0.2, 0.3]]).astype(np.float32))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.array([[0.1, 0.2, 0.3],
+                              [0.4, 0.5, 0.5],
+                              [0.6, 0.7, 0.8]]).astype(np.float32)
+    print(output)
+    assert np.all(np.abs(output - output_expect) < 1e-6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_assign_float64():
+    x = Tensor(np.array([[0.1, 0.2, 0.3],
+                         [0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8]]).astype(np.float64))
+    y = Tensor(np.array([[0.4, 0.5, 0.5],
+                         [0.6, 0.7, 0.8],
+                         [0.1, 0.2, 0.3]]).astype(np.float64))
+    assign = Assign(x, y)
+    output = assign()
+    output = output.asnumpy()
+    output_expect = np.array([[0.1, 0.2, 0.3],
+                              [0.4, 0.5, 0.5],
+                              [0.6, 0.7, 0.8]]).astype(np.float64)
+    print(output)
+    assert np.all(np.abs(output - output_expect) < 1e-6)
diff --git a/tests/st/ops/cpu/test_cast_op.py b/tests/st/ops/cpu/test_cast_op.py
index b75110b2639..13c36ee0df8 100644
--- a/tests/st/ops/cpu/test_cast_op.py
+++ b/tests/st/ops/cpu/test_cast_op.py
@@ -74,3 +74,276 @@ def test_cast_float32():
     output = net(x2)
     type2 = output.asnumpy().dtype
     assert type2 == 'float32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int8_to_int16():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
+    t = mstype.int16
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int16'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int8_to_int32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
+    t = mstype.int32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int8_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_int16():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.int16
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int16'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_int32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.int32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_uint16():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.uint16
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint16'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_uint32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.uint32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint8_to_uint64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
+    t = mstype.uint64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int16_to_int32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))
+    t = mstype.int32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int16_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint16_to_int32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
+    t = mstype.int32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint16_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint16_to_uint32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
+    t = mstype.uint32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint16_to_uint64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
+    t = mstype.uint64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_int32_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint32_to_int64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))
+    t = mstype.int64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'int64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_uint32_to_uint64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))
+    t = mstype.uint64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'uint64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_float16_to_float32():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))
+    t = mstype.float32
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'float32'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_float16_to_float64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))
+    t = mstype.float64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'float64'
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cast_float32_to_float64():
+    x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))
+    t = mstype.float64
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    net = Net(t)
+    output = net(x)
+    dtype = output.asnumpy().dtype
+    assert dtype == 'float64'
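
With this patch applied, Assign accepts bool, integer, and float parameters on CPU, and Cast covers the widening integer and float conversions registered above. A minimal usage sketch in the style of the tests; the AssignThenCast cell is illustrative and not part of the patch:

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class AssignThenCast(nn.Cell):
    """Assign a uint16 value into a uint16 parameter, then widen it to uint32."""
    def __init__(self):
        super(AssignThenCast, self).__init__()
        self.var = Parameter(Tensor(np.zeros((2, 2)).astype(np.uint16)), name="var")
        self.assign = P.Assign()
        self.cast = P.Cast()

    def construct(self, value):
        # Assign returns the updated parameter value, which Cast then widens.
        updated = self.assign(self.var, value)    # uint16 Assign, newly supported on CPU
        return self.cast(updated, mstype.uint32)  # uint16 -> uint32 Cast, newly registered

net = AssignThenCast()
out = net(Tensor(np.ones((2, 2)).astype(np.uint16)))
print(out.dtype)  # expected: UInt32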