support multiple types in print
This commit is contained in:
parent 9416502e90
commit 33255dbf60
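A minimal usage sketch of what this change enables, adapted from the test case added at the end of this diff (illustrative only; it assumes a MindSpore build with the GPU backend installed):

# Sketch based on the new test_print_multiple_types test below: P.Print on GPU
# can now interleave strings, a scalar, and Tensors of different dtypes in one call.
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P


class PrintTypes(nn.Cell):
    def __init__(self):
        super(PrintTypes, self).__init__()
        self.op = P.Print()

    def construct(self, x, y, z):
        self.op("This is a scalar:", 34, "This is int:", x, "This is float64:", y, "This is int64:", z)
        return x


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = PrintTypes()
net(Tensor(np.array([[1], [3]], dtype=np.int32)),
    Tensor(np.array([[1.0], [3.0]], dtype=np.float64)),
    Tensor(np.arange(9).reshape(3, 3).astype(np.int64)))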
@ -18,38 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, bool)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, float)
MS_REG_GPU_KERNEL(Print, PrintGpuKernel)
} // namespace kernel
} // namespace mindspore
@ -17,7 +17,9 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_

#include <utility>
#include <tuple>
#include <functional>
#include <unordered_map>
#include <string>
#include <algorithm>
#include <vector>
@ -25,12 +27,12 @@
#include "ir/tensor.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/data/dataset_utils.h"

using mindspore::tensor::Tensor;

namespace mindspore {
namespace kernel {
template <typename T>
class PrintGpuKernel : public GpuKernel {
public:
PrintGpuKernel() { ResetResource(); }
@ -43,41 +45,37 @@ class PrintGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
for (size_t i = 0; i < inputs.size(); i++) {
input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
}
std::vector<void *> input_device_data;
InitDeviceData(inputs, &input_device_data);
int *output_address = GetDeviceAddress<int>(outputs, 0);
// host initialization
std::vector<std::unique_ptr<T[]>> input_host_data;
for (size_t i = 0; i < input_size_.size(); i++) {
std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
input_host_data.push_back(std::move(value));
}
// check type
T type_value = static_cast<T>(0.0f);
auto type_id = CheckType(type_value);
if (type_id == kTypeUnknown) {
MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
// host initialization in byte for storage
std::unique_ptr<uint8_t[]> input_host_data;
int64_t sum_of_bytes = 0;
for (size_t i = 0; i < input_info_.size(); i++) {
sum_of_bytes += std::get<0>(input_info_[i]);
}
input_host_data = std::make_unique<uint8_t[]>(sum_of_bytes);
// print core function
size_t string_idx = 0;
auto offset = input_host_data.get();
for (size_t i = 0; i < input_flag_.size(); i++) {
if (input_flag_[i] == -1) {
std::cout << string_value_[string_idx] << std::endl;
string_idx++;
} else {
size_t tensor_idx = LongToSize(input_flag_[i]);
size_t size_to_move = std::get<0>(input_info_[tensor_idx]);
std::string error_msg = "cudaMemcpyAsync print loop failed at input_device_data[";
error_msg.append(std::to_string(tensor_idx));
error_msg.append("].");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_host_data[tensor_idx].get(), input_device_data_[tensor_idx],
input_size_[tensor_idx] * sizeof(T), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
cudaMemcpyAsync(offset, input_device_data[tensor_idx], size_to_move,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
error_msg);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - Print");
auto current_string = GetTensorString(&input_shape_, tensor_idx, type_id, &input_host_data, &input_size_);
auto current_string = GetString(tensor_idx, i, offset);
std::cout << current_string << std::endl;
offset += size_to_move;
}
}
int output = 1;
@ -93,71 +91,90 @@ class PrintGpuKernel : public GpuKernel {
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
string_pos_ = GetAttr<std::vector<int64_t>>(kernel_node, "string_pos");
auto value_type = GetAttr<std::vector<int64_t>>(kernel_node, "value_type");
auto value_type_pos = GetAttr<std::vector<int64_t>>(kernel_node, "value_type_pos");
for (size_t i = 0; i < value_type.size(); i++) {
value_type_[value_type_pos[i]] = value_type[i];
}
}
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
input_flag_ = SetInputFlag(&string_pos_, input_tensor_num);
input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
std::vector<size_t> value_shape;
for (size_t i = 0; i < input_tensor_num; i++) {
size_t value = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
value *= input_shape[j];
value_shape.push_back(input_shape[j]);
}
input_size_.push_back(value);
input_shape_.push_back(value_shape);
value_shape.clear();
auto type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, i);
size_t unit_size = UnitSizeInBytes(type_id);
auto size_in_byte = std::accumulate(input_shape.begin(), input_shape.end(), unit_size, std::multiplies<size_t>());
input_info_.push_back(std::make_tuple(size_in_byte, type_id));
input_shape_.push_back(input_shape);
}
InitSizeLists();
return true;
}

protected:
void ResetResource() noexcept override {
string_value_.clear();
string_pos_.clear();
input_flag_.clear();
input_device_data_ = nullptr;
input_size_.clear();
value_type_.clear();
input_info_.clear();
input_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}

protected:
void InitSizeLists() override {
for (size_t i = 0; i < input_size_.size(); i++) {
input_size_list_.push_back(input_size_[i] * sizeof(T));
for (size_t i = 0; i < input_info_.size(); i++) {
input_size_list_.push_back(std::get<0>(input_info_[i]));
}
output_size_list_.push_back(sizeof(int));
}

TypeId CheckType(T value) {
if (std::is_same<T, bool>::value) {
return kNumberTypeBool;
} else if (std::is_same<T, int8_t>::value) {
return kNumberTypeInt8;
} else if (std::is_same<T, int16_t>::value) {
return kNumberTypeInt16;
} else if (std::is_same<T, int>::value) {
return kNumberTypeInt32;
} else if (std::is_same<T, int64_t>::value) {
return kNumberTypeInt64;
} else if (std::is_same<T, uint8_t>::value) {
return kNumberTypeUInt8;
} else if (std::is_same<T, uint16_t>::value) {
return kNumberTypeUInt16;
} else if (std::is_same<T, uint32_t>::value) {
return kNumberTypeUInt32;
} else if (std::is_same<T, uint64_t>::value) {
return kNumberTypeUInt64;
} else if (std::is_same<T, half>::value) {
return kNumberTypeFloat16;
} else if (std::is_same<T, float>::value) {
return kNumberTypeFloat32;
void InitDeviceData(const std::vector<AddressPtr> &inputs, std::vector<void *> *input_device_data) {
for (size_t i = 0; i < inputs.size(); i++) {
TypeId type_id = std::get<1>(input_info_[i]);
switch (type_id) {
case kNumberTypeBool:
input_device_data->push_back(GetDeviceAddress<bool>(inputs, i));
break;
case kNumberTypeInt8:
input_device_data->push_back(GetDeviceAddress<int8_t>(inputs, i));
break;
case kNumberTypeInt16:
input_device_data->push_back(GetDeviceAddress<int16_t>(inputs, i));
break;
case kNumberTypeInt32:
input_device_data->push_back(GetDeviceAddress<int32_t>(inputs, i));
break;
case kNumberTypeInt64:
input_device_data->push_back(GetDeviceAddress<int64_t>(inputs, i));
break;
case kNumberTypeUInt8:
input_device_data->push_back(GetDeviceAddress<uint8_t>(inputs, i));
break;
case kNumberTypeUInt16:
input_device_data->push_back(GetDeviceAddress<uint16_t>(inputs, i));
break;
case kNumberTypeUInt32:
input_device_data->push_back(GetDeviceAddress<uint32_t>(inputs, i));
break;
case kNumberTypeUInt64:
input_device_data->push_back(GetDeviceAddress<uint64_t>(inputs, i));
break;
case kNumberTypeFloat16:
input_device_data->push_back(GetDeviceAddress<half>(inputs, i));
break;
case kNumberTypeFloat32:
input_device_data->push_back(GetDeviceAddress<float>(inputs, i));
break;
case kNumberTypeFloat64:
input_device_data->push_back(GetDeviceAddress<double>(inputs, i));
break;
default:
MS_LOG(EXCEPTION) << "TypeId: " << type_id << " is not supported in Print.";
}
}
return kTypeUnknown;
}

std::vector<int64_t> SetInputFlag(std::vector<int64_t> *string_pos, size_t input_tensor_num) {
@ -186,12 +203,23 @@ class PrintGpuKernel : public GpuKernel {
return res;
}

std::string GetTensorString(std::vector<std::vector<size_t>> *input_shape, size_t index, TypeId type_id,
std::vector<std::unique_ptr<T[]>> *input_host_data, std::vector<size_t> *input_size) {
std::string GetString(size_t tensor_index, size_t original_index, void *input_host_data) {
ShapeVector shape;
(void)std::transform((*input_shape)[index].begin(), (*input_shape)[index].end(), std::back_inserter(shape),
[](const size_t &value) { return static_cast<int64_t>(value); });
Tensor current_tensor(type_id, shape, (*input_host_data)[index].get(), (*input_size)[index] * sizeof(T));
size_t size_in_byte = std::get<0>(input_info_[tensor_index]);
TypeId type_id = std::get<1>(input_info_[tensor_index]);
(void)std::transform(input_shape_[tensor_index].begin(), input_shape_[tensor_index].end(),
std::back_inserter(shape), [](const size_t &value) { return static_cast<int64_t>(value); });
Tensor current_tensor(type_id, shape, input_host_data, size_in_byte);
if (value_type_.count(original_index) > 0) {
// not a tensor
auto out = current_tensor.data().ToString(type_id, shape, true);
if (value_type_[original_index] != 0) {
// tuple, not scalar
(void)std::replace(out.begin(), out.end(), '[', '(');
(void)std::replace(out.begin(), out.end(), ']', ')');
}
return out;
}
return current_tensor.ToStringNoLimit();
}
@ -199,8 +227,9 @@ class PrintGpuKernel : public GpuKernel {
std::vector<std::string> string_value_;
std::vector<int64_t> string_pos_;
std::vector<int64_t> input_flag_;
std::unique_ptr<T *[]> input_device_data_;
std::vector<size_t> input_size_;
std::unordered_map<int64_t, int64_t> value_type_;
// size_in_byte, typeid
std::vector<std::tuple<size_t, TypeId>> input_info_;
std::vector<std::vector<size_t>> input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@ -53,36 +53,50 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr> *opt_list,
std::vector<std::vector<int64_t>> *string_pos_vec,
std::vector<std::vector<std::string>> *string_value_vec) {
std::vector<std::vector<std::string>> *string_value_vec,
std::vector<std::vector<std::pair<int64_t, int64_t>>> *not_tensor_pos_vec) {
for (auto &node : node_list) {
// {prim::kPrimPrint} only print with string will be reduced
// {prim::kPrimPrint} reduction only applies on print with string, tensor(scalar or tuple)
std::vector<int64_t> string_pos;
std::vector<std::string> string_value;
std::vector<std::pair<int64_t, int64_t>> value_type;
if (IsPrimitiveCNode(node, prim::kPrimPrint)) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_num; i++) {
auto current_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
// not a string
// not tensor(tuple, scalar, string)
if (current_node->cast<ValueNodePtr>() == nullptr) {
continue;
}
auto value_node = current_node->cast<ValueNodePtr>()->value();
// not a string
if (value_node->type() == nullptr) {
auto value_node = current_node->cast<ValueNodePtr>();
auto shape_node = dyn_cast<abstract::Shape>(value_node->abstract()->GetShapeTrack());
if (shape_node != nullptr) {
// a scalar or tuple
auto shape_size = shape_node->shape().size();
if (shape_size != 0) {
value_type.push_back(std::make_pair(i, 1));
} else {
value_type.push_back(std::make_pair(i, 0));
}
}
auto node_value = value_node->value();
if (node_value->type() == nullptr) {
// not a string
continue;
}
if (value_node->type()->generic_type_id() == kObjectTypeString) {
auto current_string_value = GetValue<std::string>(value_node);
if (node_value->type()->generic_type_id() == kObjectTypeString) {
auto current_string_value = GetValue<std::string>(node_value);
string_pos.push_back(i);
string_value.push_back(std::string(current_string_value));
} else {
MS_LOG(EXCEPTION) << "Current value node is not string or tensor";
}
}
if (string_pos.size() != 0) {
if (string_pos.size() != 0 || value_type.size() != 0) {
opt_list->push_back(node);
string_pos_vec->push_back(string_pos);
string_value_vec->push_back(string_value);
not_tensor_pos_vec->push_back(value_type);
}
}
}
@ -100,7 +114,9 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> opt_list;
std::vector<std::vector<int64_t>> string_pos_vec;
std::vector<std::vector<std::string>> string_value_vec;
if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec)) {
// first is pos, second is type: 0 is Scalar, 1 is ValueTuple
std::vector<std::vector<std::pair<int64_t, int64_t>>> not_tensor_pos_vec;
if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec, &not_tensor_pos_vec)) {
return false;
}
for (size_t idx = 0; idx < opt_list.size(); idx++) {
@ -131,11 +147,21 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(monad_node);
inputs.push_back(monad_node);
auto string_value = string_value_vec[idx];
auto value_type_vec = not_tensor_pos_vec[idx];
// split value type and pos
std::vector<int64_t> value_type_pos;
std::vector<int64_t> value_type;
(void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type_pos),
[](const std::pair<int64_t, int64_t> &value) { return value.first; });
(void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type),
[](const std::pair<int64_t, int64_t> &value) { return value.second; });
// create new cnode
auto print_fused = graph->NewCNode(inputs);
// hand over the attrs to new print
AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused);
AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused);
AnfAlgo::SetNodeAttr("value_type", MakeValue<std::vector<int64_t>>(value_type), print_fused);
AnfAlgo::SetNodeAttr("value_type_pos", MakeValue<std::vector<int64_t>>(value_type_pos), print_fused);
// set output type and shape
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
@ -71,6 +71,7 @@ def print_testcase(nptype):
    net_2(x, y)
    net_3(x)


class PrintNetString(nn.Cell):
    def __init__(self):
        super(PrintNetString, self).__init__()
@ -83,6 +84,7 @@ class PrintNetString(nn.Cell):
        self.op("The first Tensor is", x, y, "is the second Tensor")
        return x


def print_testcase_string(nptype):
    x = np.ones(18).astype(nptype)
    y = np.arange(9).reshape(3, 3).astype(nptype)
@ -93,6 +95,29 @@ def print_testcase_string(nptype):
    net = PrintNetString()
    net(x, y)


class PrintTypes(nn.Cell):
    def __init__(self):
        super(PrintTypes, self).__init__()
        self.op = P.Print()

    def construct(self, x, y, z):
        self.op("This is a scalar:", 34, "This is int:", x, "This is float64:", y, "This is int64:", z)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_multiple_types():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))
    y = Tensor(np.array([[1], [3], [4], [6], [3]]).astype(np.float64))
    z = Tensor(np.arange(9).reshape(3, 3).astype(np.int64))
    net = PrintTypes()
    net(x, y, z)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard