forked from mindspore-Ecosystem/mindspore
!7525 fix cpu type exception
Merge pull request !7525 from baihuawei/cpu_type_exception
This commit is contained in:
commit
169bf59fd4
|
@ -51,9 +51,9 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt32) {
|
||||
} else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
|
||||
LaunchKernel<int>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_);
|
||||
|
|
|
@ -49,5 +49,37 @@ void FloatToDouble(void *dst, const void *src, size_t elem_num) {
|
|||
double_data[i] = static_cast<double>(float_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Widens `elem_num` int16_t values read from `src` into the int buffer at `dst`.
// Both buffers are supplied by the caller; no null or bounds checks happen here.
void ShortToInt(void *dst, const void *src, size_t elem_num) {
  const int16_t *in = static_cast<const int16_t *>(src);
  int *out = static_cast<int *>(dst);
  for (const int16_t *const in_end = in + elem_num; in != in_end; ++in, ++out) {
    *out = static_cast<int>(*in);
  }
}
|
||||
|
||||
// Converts `elem_num` int values read from `src` to int16_t and stores them at `dst`.
// Each element goes through a plain static_cast, so values outside the int16_t range are
// narrowed rather than clamped. No null or bounds checks happen here.
void IntToShort(void *dst, const void *src, size_t elem_num) {
  const int *in = static_cast<const int *>(src);
  int16_t *out = static_cast<int16_t *>(dst);
  for (const int *const in_end = in + elem_num; in != in_end; ++in, ++out) {
    *out = static_cast<int16_t>(*in);
  }
}
|
||||
|
||||
// Converts `elem_num` int64_t values read from `src` to int and stores them at `dst`.
// Each element goes through a plain static_cast, so values that do not fit in int are
// narrowed rather than clamped. No null or bounds checks happen here.
void LongToInt(void *dst, const void *src, size_t elem_num) {
  const int64_t *in = static_cast<const int64_t *>(src);
  int *out = static_cast<int *>(dst);
  for (const int64_t *const in_end = in + elem_num; in != in_end; ++in, ++out) {
    *out = static_cast<int>(*in);
  }
}
|
||||
|
||||
// Widens `elem_num` int values read from `src` into the int64_t buffer at `dst`.
// Both buffers are supplied by the caller; no null or bounds checks happen here.
void IntToLong(void *dst, const void *src, size_t elem_num) {
  const int *in = static_cast<const int *>(src);
  int64_t *out = static_cast<int64_t *>(dst);
  for (const int *const in_end = in + elem_num; in != in_end; ++in, ++out) {
    *out = static_cast<int64_t>(*in);
  }
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,10 @@ void HalfToFloat(void *dst, const void *src, size_t elem_num);
|
|||
void FloatToHalf(void *dst, const void *src, size_t elem_num);
|
||||
void DoubleToFloat(void *dst, const void *src, size_t elem_num);
|
||||
void FloatToDouble(void *dst, const void *src, size_t elem_num);
|
||||
// Reads elem_num int16_t values from src and writes them to dst as int.
void ShortToInt(void *dst, const void *src, size_t elem_num);
|
||||
// Reads elem_num int values from src and writes them to dst as int16_t (static_cast per element).
void IntToShort(void *dst, const void *src, size_t elem_num);
|
||||
// Reads elem_num int64_t values from src and writes them to dst as int (static_cast per element).
void LongToInt(void *dst, const void *src, size_t elem_num);
|
||||
// Reads elem_num int values from src and writes them to dst as int64_t.
void IntToLong(void *dst, const void *src, size_t elem_num);
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -25,22 +25,24 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t si
|
|||
MS_LOG(ERROR) << "The pointer ptr_ is null!";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (host_ptr == ptr_) {
|
||||
MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type == type_id_) {
|
||||
auto ret_code = memcpy_s(host_ptr, size, ptr_, size_);
|
||||
if (ret_code != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to copy tensor!";
|
||||
return false;
|
||||
}
|
||||
} else if (type == kNumberTypeFloat16) {
|
||||
} else if (type == kNumberTypeFloat16 && type_id_ == kNumberTypeFloat32) {
|
||||
FloatToHalf(host_ptr, ptr_, size / 2);
|
||||
} else if (type == kNumberTypeFloat64) {
|
||||
} else if (type == kNumberTypeFloat64 && type_id_ == kNumberTypeFloat32) {
|
||||
FloatToDouble(host_ptr, ptr_, size / sizeof(double));
|
||||
} else if (type == kNumberTypeInt16 && type_id_ == kNumberTypeInt32) {
|
||||
IntToShort(host_ptr, ptr_, size / 2);
|
||||
} else if (type == kNumberTypeInt64 && type_id_ == kNumberTypeInt32) {
|
||||
IntToLong(host_ptr, ptr_, size / sizeof(int64_t));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type)
|
||||
<< "!";
|
||||
|
@ -51,15 +53,26 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t si
|
|||
|
||||
bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t size, TypeId type,
|
||||
const void *host_ptr) const {
|
||||
if (ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The pointer ptr_ is null!";
|
||||
return false;
|
||||
}
|
||||
if (host_ptr == ptr_) {
|
||||
MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (type == kNumberTypeFloat16) {
|
||||
if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat16) {
|
||||
HalfToFloat(ptr_, host_ptr, size / 2);
|
||||
} else if (type == kNumberTypeFloat64) {
|
||||
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
|
||||
DoubleToFloat(ptr_, host_ptr, size / sizeof(double));
|
||||
} else if (type_id_ == kNumberTypeInt32 && type == kNumberTypeInt16) {
|
||||
ShortToInt(ptr_, host_ptr, size / 2);
|
||||
} else if (type_id_ == kNumberTypeInt32 && type == kNumberTypeInt64) {
|
||||
LongToInt(ptr_, host_ptr, size / sizeof(int64_t));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type)
|
||||
<< "!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -54,23 +54,17 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
|
|||
}
|
||||
auto tensor = node_value->cast<TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
size_t type_size = sizeof(float);
|
||||
if (tensor->data_type() == kNumberTypeInt64) {
|
||||
type_size = GetTypeByte(TypeIdToType(kNumberTypeInt64));
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
|
||||
if (output_type_id == kTypeUnknown) {
|
||||
output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
|
||||
}
|
||||
size_t type_size = sizeof(TypeIdToType(output_type_id));
|
||||
ShapeVector data_shape = tensor->shape();
|
||||
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
|
||||
DeviceAddressPtr address = nullptr;
|
||||
if (tensor->data_type() == kNumberTypeInt32) {
|
||||
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeInt32);
|
||||
} else if (tensor->data_type() == kNumberTypeInt64) {
|
||||
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeInt64);
|
||||
} else {
|
||||
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32);
|
||||
}
|
||||
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, output_type_id);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32 ||
|
||||
tensor->data_type() == kNumberTypeInt64) {
|
||||
if (tensor->data_type() == output_type_id) {
|
||||
address->ptr_ = tensor->data_c();
|
||||
} else {
|
||||
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
|
||||
|
@ -97,10 +91,7 @@ void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel
|
|||
output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
|
||||
}
|
||||
std::vector<size_t> fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index);
|
||||
size_t type_size = sizeof(float);
|
||||
if (output_type_id == kNumberTypeInt64) {
|
||||
type_size = GetTypeByte(TypeIdToType(kNumberTypeInt64));
|
||||
}
|
||||
size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
|
||||
size_t tensor_size =
|
||||
fmt_shape.empty() ? type_size
|
||||
: std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies<size_t>());
|
||||
|
@ -254,13 +245,12 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
|
|||
if (tensor_address != nullptr && tensor_address != address) {
|
||||
tensor->data_sync(false);
|
||||
}
|
||||
if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 ||
|
||||
tensor->data_type() == kNumberTypeInt32 || tensor->data_type() == kNumberTypeInt64) {
|
||||
if (tensor->data_type() == address->type_id_) {
|
||||
address->ptr_ = tensor->data_c();
|
||||
} else {
|
||||
ShapeVector data_shape = tensor->shape();
|
||||
size_t tensor_size =
|
||||
std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>());
|
||||
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(),
|
||||
GetTypeByte(TypeIdToType(address->type_id_)), std::multiplies<size_t>());
|
||||
address->ptr_ = resource_manager_.MemMalloc(tensor_size);
|
||||
if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
|
|
|
@ -52,6 +52,17 @@ void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vec
|
|||
}
|
||||
}
|
||||
|
||||
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
|
||||
std::vector<TypeId> *output_types) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
TypeId dtype = kTypeUnknown;
|
||||
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
||||
output_formats->emplace_back(kOpFormat_DEFAULT);
|
||||
output_types->emplace_back(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
|
||||
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -78,10 +89,53 @@ void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &ke
|
|||
}
|
||||
}
|
||||
|
||||
// Decides whether an actual input dtype is acceptable for a kernel's declared input dtype.
//
// An exact match is always accepted. When `strict` is false, two extra cases are accepted:
//   - the kernel declares int32 and the input is int16 or int64;
//   - the kernel declares float32 and the input is float16 or float64.
// Everything else is rejected.
bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
  if (InputAttr == input_type) {
    return true;
  }
  if (strict) {
    return false;
  }
  switch (InputAttr) {
    case kNumberTypeInt32:
      return input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64;
    case kNumberTypeFloat32:
      return input_type == kNumberTypeFloat16 || input_type == kNumberTypeFloat64;
    default:
      return false;
  }
}
|
||||
|
||||
std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||
const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types) {
|
||||
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
||||
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
||||
<< ", actual output num:" << output_types.size();
|
||||
return std::make_pair(0, 0);
|
||||
}
|
||||
int data_type_matched_num = 0;
|
||||
int format_matched_num = 0;
|
||||
auto output_num = output_types.size();
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
|
||||
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetOutputAttr(i).first
|
||||
<< ", actual output dtype:" << output_types[i];
|
||||
} else {
|
||||
data_type_matched_num++;
|
||||
}
|
||||
|
||||
if (kernel_attr.GetOutputAttr(i).second != output_formats[i]) {
|
||||
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetOutputAttr(i).second
|
||||
<< ", actual output format:" << output_formats[i];
|
||||
} else {
|
||||
format_matched_num++;
|
||||
}
|
||||
}
|
||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
||||
}
|
||||
|
||||
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||
const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes) {
|
||||
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
||||
return std::make_pair(0, 0);
|
||||
|
@ -98,12 +152,23 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
|||
format_matched_num++;
|
||||
continue;
|
||||
}
|
||||
if (is_not_cnode_idx) {
|
||||
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
|
||||
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
||||
<< ", actual input dtype:" << input_types[i];
|
||||
} else {
|
||||
data_type_matched_num++;
|
||||
}
|
||||
format_matched_num++;
|
||||
continue;
|
||||
}
|
||||
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
|
||||
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
||||
<< ", actual input dtype:" << input_types[i];
|
||||
} else {
|
||||
data_type_matched_num++;
|
||||
}
|
||||
|
||||
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
|
||||
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
|
||||
<< ", actual input format:" << input_formats[i];
|
||||
|
@ -141,23 +206,13 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std
|
|||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<size_t> input_not_cnode_indexes;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
||||
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
||||
auto kernel_attrs =
|
||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
if (kernel_attrs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] is not support.";
|
||||
}
|
||||
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
||||
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<size_t> &input_not_cnode_indexes,
|
||||
const std::vector<std::string> &infer_output_formats, const std::vector<TypeId> &infer_output_types,
|
||||
bool strict) {
|
||||
int max_type_matched_num = -1;
|
||||
int max_format_matched_num = -1;
|
||||
KernelAttr selected_kernel_attr;
|
||||
for (auto kernel_attr : kernel_attrs) {
|
||||
if (kernel_attr.GetAllSame()) {
|
||||
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||
|
@ -168,29 +223,61 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
continue;
|
||||
}
|
||||
std::pair<int, int> input_type_format_matched_num =
|
||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes);
|
||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes, strict);
|
||||
std::pair<int, int> output_type_format_matched_num =
|
||||
GetOutputDtypeFormatMatchedNum(kernel_attr, infer_output_formats, infer_output_types);
|
||||
// Data type first
|
||||
if (input_type_format_matched_num.first > max_type_matched_num) {
|
||||
max_type_matched_num = input_type_format_matched_num.first;
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
selected_kernel_attr = kernel_attr;
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
||||
input_type_format_matched_num.second > max_format_matched_num) {
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
selected_kernel_attr = kernel_attr;
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
||||
input_type_format_matched_num.second == max_format_matched_num) {
|
||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
}
|
||||
}
|
||||
// All formats and data types matched
|
||||
if (max_type_matched_num == SizeToInt(input_types.size()) &&
|
||||
max_format_matched_num == SizeToInt(input_types.size())) {
|
||||
break;
|
||||
max_format_matched_num == SizeToInt(input_types.size()) &&
|
||||
output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<size_t> input_not_cnode_indexes;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
std::vector<std::string> infer_output_formats;
|
||||
std::vector<TypeId> infer_output_types;
|
||||
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
||||
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
||||
GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types);
|
||||
auto kernel_attrs =
|
||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
if (kernel_attrs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] is not support.";
|
||||
}
|
||||
KernelAttr selected_kernel_attr;
|
||||
bool matched = true;
|
||||
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
||||
input_not_cnode_indexes, infer_output_formats, infer_output_types, true)) {
|
||||
matched = SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
||||
input_not_cnode_indexes, infer_output_formats, infer_output_types, false);
|
||||
}
|
||||
|
||||
if (selected_kernel_attr.GetInputSize() > 0 && ((max_type_matched_num == SizeToInt(input_types.size()) &&
|
||||
max_format_matched_num == SizeToInt(input_types.size())) ||
|
||||
input_types.size() == input_not_cnode_indexes.size())) {
|
||||
MS_LOG(INFO) << "Input format and dtype is matched, max_type_matched_num: " << max_type_matched_num
|
||||
<< ", max_format_matched_num: " << max_format_matched_num;
|
||||
if (selected_kernel_attr.GetInputSize() > 0 && (matched || input_types.size() == input_not_cnode_indexes.size())) {
|
||||
MS_LOG(INFO) << "Input format and dtype is matched";
|
||||
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
|
|
|
@ -36,9 +36,46 @@ class SquareNet(nn.Cell):
|
|||
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_square():
    """Run SquareNet on [1, 2, 3] for int16/32/64 and float16/32/64 inputs.

    For each dtype the output is printed and compared element-wise against
    [1, 4, 9] cast to the same dtype.
    """
    for dtype in (np.int16, np.int32, np.int64, np.float16, np.float32, np.float64):
        x = np.array([1, 2, 3]).astype(dtype)
        net = SquareNet()
        output = net(Tensor(x))
        expect_output = np.array([1, 4, 9]).astype(dtype)
        print(output)
        assert np.all(output.asnumpy() == expect_output)
|
||||
|
||||
test_square()
|
||||
|
|
|
@ -71,3 +71,20 @@ def test_maxpool2d():
|
|||
assert (output.asnumpy() == expect_result).all()
|
||||
print(output2.asnumpy())
|
||||
assert (output2.asnumpy() == expect_result2).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maxpool():
    """Calling Net_Pool on an int16 tensor is expected to raise an exception."""
    data = np.array([[[
        [0, 1, 2, 3, -4, -5],
        [6, 7, 8, 9, -10, -11],
        [12, 13, 14, -15, -16, -17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]
    ]]]).astype(np.int16)
    x = Tensor(data)
    maxpool2d = Net_Pool()
    with pytest.raises(Exception):
        maxpool2d(x)
|
||||
|
|
Loading…
Reference in New Issue