Add assign op for cpu

This commit is contained in:
xiangyunwu21 2020-11-28 17:10:55 +08:00
parent 18f58f7db4
commit 9e38dfd28e
7 changed files with 598 additions and 40 deletions

View File

@ -16,67 +16,48 @@
#include "backend/kernel_compiler/cpu/assign_cpu_kernel.h"
#include <string>
#include <map>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
static std::map<TypeId, size_t> input_x_dtype_size_map = {
{kNumberTypeBool, sizeof(bool)}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4},
{kNumberTypeInt64, 8}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2}, {kNumberTypeUInt32, 4},
{kNumberTypeUInt64, 8}, {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < input_x_shape.size(); ++i) {
batch_size_ *= input_x_shape[i];
}
if (input_x_shape.size() != input_y_shape.size()) MS_LOG(EXCEPTION) << "x y must be same shape";
for (size_t i = 0; i < input_x_shape.size(); ++i) {
if (input_x_shape[i] != input_y_shape[i]) {
MS_LOG(EXCEPTION) << "x y must be same shape";
}
batch_size_ *= input_x_shape[i];
}
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
input_x_dtype_size_ = 4;
} else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
input_x_dtype_size_ = 8;
} else {
MS_LOG(EXCEPTION) << "input_x dtype only support float32, float64, int32, int64";
if (input_x_dtype_size_map.find(input_x_dtype_) == input_x_dtype_size_map.end()) {
MS_LOG(EXCEPTION) << "unsupported input_x dtype";
}
input_x_dtype_size_ = input_x_dtype_size_map[input_x_dtype_];
}
bool AssignCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (input_x_dtype_ == kNumberTypeInt32) {
LaunchKernel<int>(inputs, outputs);
} else if (input_x_dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, outputs);
} else if (input_x_dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (input_x_dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(ERROR) << "indices dtype only support float32, float64, int32, int64";
return false;
}
return true;
}
template <typename T>
void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
bool AssignCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> &outputs) {
auto max_size = inputs[0]->size;
size_t total_size = input_x_dtype_size_ * batch_size_;
if (total_size > max_size) {
MS_LOG(EXCEPTION) << "Memcpy size must <= max_size, but got memcpy size is : " << total_size
<< ", max size is : " << max_size;
}
int ret = memcpy_s(input_x, total_size, input_y, total_size);
int ret = memcpy_s(inputs[0]->addr, total_size, inputs[1]->addr, total_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -34,15 +34,24 @@ class AssignCPUKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
private:
size_t batch_size_{1};
TypeId input_x_dtype_{kTypeUnknown};
size_t input_x_dtype_size_ = 4;
};
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AssignCPUKernel);
@ -51,6 +60,27 @@ MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
AssignCPUKernel);
MS_REG_CPU_KERNEL(
Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@ -82,6 +82,31 @@ bool CastCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast<bool, float>;
mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast<bool, bool>;
mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast<bool, int>;
mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast<int8_t, int16_t>;
mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast<int8_t, int32_t>;
mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast<int8_t, int64_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast<uint8_t, int16_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast<uint8_t, int32_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast<uint8_t, int64_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast<uint8_t, uint16_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast<uint8_t, uint32_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast<uint8_t, uint64_t>;
mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast<int16_t, int32_t>;
mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast<int16_t, int64_t>;
mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast<uint16_t, int32_t>;
mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast<uint16_t, int64_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast<uint16_t, uint32_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast<uint16_t, uint64_t>;
mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast<int32_t, int64_t>;
mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast<uint32_t, int64_t>;
mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast<uint32_t, uint64_t>;
mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast<float16, float>;
mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast<float16, double>;
mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast<float, double>;
mode_map[source_dtype][target_dtype](inputs, outputs);
return true;
}

View File

@ -47,6 +47,31 @@ MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -38,7 +38,7 @@ class Assign(PrimitiveWithCheck):
Tensor, has the same type as original `variable`.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> class Net(nn.Cell):
@ -66,9 +66,10 @@ class Assign(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
def check_dtype(self, variable, value):
types = mstype.number_type + (mstype.bool_,)
if variable != mstype.type_refkey:
validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("variable", variable, types, self.name)
validator.check_scalar_or_tensor_types_same({"value": value}, types, self.name)
class InplaceAssign(PrimitiveWithInfer):

View File

@ -0,0 +1,223 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Assign(nn.Cell):
def __init__(self, x, y):
super(Assign, self).__init__()
self.x = Parameter(initializer(x, x.shape), name="x")
self.y = Parameter(initializer(y, y.shape), name="y")
self.assign = P.Assign()
def construct(self):
self.assign(self.y, self.x)
return self.y
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_bool():
x = Tensor(np.ones([3, 3]).astype(np.bool_))
y = Tensor(np.zeros([3, 3]).astype(np.bool_))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.bool_)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_int8():
x = Tensor(np.ones([3, 3]).astype(np.int8))
y = Tensor(np.zeros([3, 3]).astype(np.int8))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.int8)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_uint8():
x = Tensor(np.ones([3, 3]).astype(np.uint8))
y = Tensor(np.zeros([3, 3]).astype(np.uint8))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.uint8)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_int16():
x = Tensor(np.ones([3, 3]).astype(np.int16))
y = Tensor(np.zeros([3, 3]).astype(np.int16))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.int16)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_uint16():
x = Tensor(np.ones([3, 3]).astype(np.uint16))
y = Tensor(np.zeros([3, 3]).astype(np.uint16))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.uint16)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_int32():
x = Tensor(np.ones([3, 3]).astype(np.int32))
y = Tensor(np.zeros([3, 3]).astype(np.int32))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.int32)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_uint32():
x = Tensor(np.ones([3, 3]).astype(np.uint32))
y = Tensor(np.zeros([3, 3]).astype(np.uint32))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.uint32)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_int64():
x = Tensor(np.ones([3, 3]).astype(np.int64))
y = Tensor(np.zeros([3, 3]).astype(np.int64))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.int64)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_uint64():
x = Tensor(np.ones([3, 3]).astype(np.uint64))
y = Tensor(np.zeros([3, 3]).astype(np.uint64))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.ones([3, 3]).astype(np.uint64)
print(output)
assert np.all(output == output_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_float16():
x = Tensor(np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float16))
y = Tensor(np.array([[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8],
[0.1, 0.2, 0.3]]).astype(np.float16))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float16)
print(output)
assert np.all(output - output_expect < 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_float32():
x = Tensor(np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float32))
y = Tensor(np.array([[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8],
[0.1, 0.2, 0.3]]).astype(np.float32))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float32)
print(output)
assert np.all(output - output_expect < 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_float64():
x = Tensor(np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float64))
y = Tensor(np.array([[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8],
[0.1, 0.2, 0.3]]).astype(np.float64))
assign = Assign(x, y)
output = assign()
output = output.asnumpy()
output_expect = np.array([[0.1, 0.2, 0.3],
[0.4, 0.5, 0.5],
[0.6, 0.7, 0.8]]).astype(np.float64)
print(output)
assert np.all(output - output_expect < 1e-6)

View File

@ -74,3 +74,276 @@ def test_cast_float32():
output = net(x2)
type2 = output.asnumpy().dtype
assert type2 == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int8_to_int16():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
t = mstype.int16
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int16'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int8_to_int32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
t = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int8_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int8))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_int16():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.int16
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int16'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_int32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_uint16():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.uint16
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint16'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_uint32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.uint32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint8_to_uint64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint8))
t = mstype.uint64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int16_to_int32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))
t = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int16_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int16))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint16_to_int32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
t = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint16_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint16_to_uint32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
t = mstype.uint32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint16_to_uint64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint16))
t = mstype.uint64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_int32_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.int32))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint32_to_int64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))
t = mstype.int64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'int64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_uint32_to_uint64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.uint32))
t = mstype.uint64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'uint64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_float16_to_float32():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))
t = mstype.float32
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_float16_to_float64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float16))
t = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'float64'
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cast_float32_to_float64():
x = Tensor(np.random.uniform(-2, 2, (3, 2)).astype(np.float32))
t = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net(t)
output = net(x)
dtype = output.asnumpy().dtype
assert dtype == 'float64'