support multiple types in print

This commit is contained in:
TFBunny 2021-05-10 11:41:10 -04:00
parent 9416502e90
commit 33255dbf60
4 changed files with 157 additions and 109 deletions

View File

@ -18,38 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, bool)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, float)
MS_REG_GPU_KERNEL(Print, PrintGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -17,7 +17,9 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
#include <utility>
#include <tuple>
#include <functional>
#include <unordered_map>
#include <string>
#include <algorithm>
#include <vector>
@ -25,12 +27,12 @@
#include "ir/tensor.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
using mindspore::tensor::Tensor;
namespace mindspore {
namespace kernel {
template <typename T>
class PrintGpuKernel : public GpuKernel {
public:
PrintGpuKernel() { ResetResource(); }
@ -43,41 +45,37 @@ class PrintGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
for (size_t i = 0; i < inputs.size(); i++) {
input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
}
std::vector<void *> input_device_data;
InitDeviceData(inputs, &input_device_data);
int *output_address = GetDeviceAddress<int>(outputs, 0);
// host initialization
std::vector<std::unique_ptr<T[]>> input_host_data;
for (size_t i = 0; i < input_size_.size(); i++) {
std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
input_host_data.push_back(std::move(value));
}
// check type
T type_value = static_cast<T>(0.0f);
auto type_id = CheckType(type_value);
if (type_id == kTypeUnknown) {
MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
// host initialization in byte for storage
std::unique_ptr<uint8_t[]> input_host_data;
int64_t sum_of_bytes = 0;
for (size_t i = 0; i < input_info_.size(); i++) {
sum_of_bytes += std::get<0>(input_info_[i]);
}
input_host_data = std::make_unique<uint8_t[]>(sum_of_bytes);
// print core function
size_t string_idx = 0;
auto offset = input_host_data.get();
for (size_t i = 0; i < input_flag_.size(); i++) {
if (input_flag_[i] == -1) {
std::cout << string_value_[string_idx] << std::endl;
string_idx++;
} else {
size_t tensor_idx = LongToSize(input_flag_[i]);
size_t size_to_move = std::get<0>(input_info_[tensor_idx]);
std::string error_msg = "cudaMemcpyAsync print loop failed at input_device_data[";
error_msg.append(std::to_string(tensor_idx));
error_msg.append("].");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_host_data[tensor_idx].get(), input_device_data_[tensor_idx],
input_size_[tensor_idx] * sizeof(T), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
cudaMemcpyAsync(offset, input_device_data[tensor_idx], size_to_move,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
error_msg);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - Print");
auto current_string = GetTensorString(&input_shape_, tensor_idx, type_id, &input_host_data, &input_size_);
auto current_string = GetString(tensor_idx, i, offset);
std::cout << current_string << std::endl;
offset += size_to_move;
}
}
int output = 1;
@ -93,71 +91,90 @@ class PrintGpuKernel : public GpuKernel {
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
string_pos_ = GetAttr<std::vector<int64_t>>(kernel_node, "string_pos");
auto value_type = GetAttr<std::vector<int64_t>>(kernel_node, "value_type");
auto value_type_pos = GetAttr<std::vector<int64_t>>(kernel_node, "value_type_pos");
for (size_t i = 0; i < value_type.size(); i++) {
value_type_[value_type_pos[i]] = value_type[i];
}
}
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
input_flag_ = SetInputFlag(&string_pos_, input_tensor_num);
input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
std::vector<size_t> value_shape;
for (size_t i = 0; i < input_tensor_num; i++) {
size_t value = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
value *= input_shape[j];
value_shape.push_back(input_shape[j]);
}
input_size_.push_back(value);
input_shape_.push_back(value_shape);
value_shape.clear();
auto type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, i);
size_t unit_size = UnitSizeInBytes(type_id);
auto size_in_byte = std::accumulate(input_shape.begin(), input_shape.end(), unit_size, std::multiplies<size_t>());
input_info_.push_back(std::make_tuple(size_in_byte, type_id));
input_shape_.push_back(input_shape);
}
InitSizeLists();
return true;
}
protected:
void ResetResource() noexcept override {
string_value_.clear();
string_pos_.clear();
input_flag_.clear();
input_device_data_ = nullptr;
input_size_.clear();
value_type_.clear();
input_info_.clear();
input_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
for (size_t i = 0; i < input_size_.size(); i++) {
input_size_list_.push_back(input_size_[i] * sizeof(T));
for (size_t i = 0; i < input_info_.size(); i++) {
input_size_list_.push_back(std::get<0>(input_info_[i]));
}
output_size_list_.push_back(sizeof(int));
}
TypeId CheckType(T value) {
if (std::is_same<T, bool>::value) {
return kNumberTypeBool;
} else if (std::is_same<T, int8_t>::value) {
return kNumberTypeInt8;
} else if (std::is_same<T, int16_t>::value) {
return kNumberTypeInt16;
} else if (std::is_same<T, int>::value) {
return kNumberTypeInt32;
} else if (std::is_same<T, int64_t>::value) {
return kNumberTypeInt64;
} else if (std::is_same<T, uint8_t>::value) {
return kNumberTypeUInt8;
} else if (std::is_same<T, uint16_t>::value) {
return kNumberTypeUInt16;
} else if (std::is_same<T, uint32_t>::value) {
return kNumberTypeUInt32;
} else if (std::is_same<T, uint64_t>::value) {
return kNumberTypeUInt64;
} else if (std::is_same<T, half>::value) {
return kNumberTypeFloat16;
} else if (std::is_same<T, float>::value) {
return kNumberTypeFloat32;
void InitDeviceData(const std::vector<AddressPtr> &inputs, std::vector<void *> *input_device_data) {
for (size_t i = 0; i < inputs.size(); i++) {
TypeId type_id = std::get<1>(input_info_[i]);
switch (type_id) {
case kNumberTypeBool:
input_device_data->push_back(GetDeviceAddress<bool>(inputs, i));
break;
case kNumberTypeInt8:
input_device_data->push_back(GetDeviceAddress<int8_t>(inputs, i));
break;
case kNumberTypeInt16:
input_device_data->push_back(GetDeviceAddress<int16_t>(inputs, i));
break;
case kNumberTypeInt32:
input_device_data->push_back(GetDeviceAddress<int32_t>(inputs, i));
break;
case kNumberTypeInt64:
input_device_data->push_back(GetDeviceAddress<int64_t>(inputs, i));
break;
case kNumberTypeUInt8:
input_device_data->push_back(GetDeviceAddress<uint8_t>(inputs, i));
break;
case kNumberTypeUInt16:
input_device_data->push_back(GetDeviceAddress<uint16_t>(inputs, i));
break;
case kNumberTypeUInt32:
input_device_data->push_back(GetDeviceAddress<uint32_t>(inputs, i));
break;
case kNumberTypeUInt64:
input_device_data->push_back(GetDeviceAddress<uint64_t>(inputs, i));
break;
case kNumberTypeFloat16:
input_device_data->push_back(GetDeviceAddress<half>(inputs, i));
break;
case kNumberTypeFloat32:
input_device_data->push_back(GetDeviceAddress<float>(inputs, i));
break;
case kNumberTypeFloat64:
input_device_data->push_back(GetDeviceAddress<double>(inputs, i));
break;
default:
MS_LOG(EXCEPTION) << "TypeId: " << type_id << " is not supported in Print.";
}
}
return kTypeUnknown;
}
std::vector<int64_t> SetInputFlag(std::vector<int64_t> *string_pos, size_t input_tensor_num) {
@ -186,12 +203,23 @@ class PrintGpuKernel : public GpuKernel {
return res;
}
std::string GetTensorString(std::vector<std::vector<size_t>> *input_shape, size_t index, TypeId type_id,
std::vector<std::unique_ptr<T[]>> *input_host_data, std::vector<size_t> *input_size) {
std::string GetString(size_t tensor_index, size_t original_index, void *input_host_data) {
ShapeVector shape;
(void)std::transform((*input_shape)[index].begin(), (*input_shape)[index].end(), std::back_inserter(shape),
[](const size_t &value) { return static_cast<int64_t>(value); });
Tensor current_tensor(type_id, shape, (*input_host_data)[index].get(), (*input_size)[index] * sizeof(T));
size_t size_in_byte = std::get<0>(input_info_[tensor_index]);
TypeId type_id = std::get<1>(input_info_[tensor_index]);
(void)std::transform(input_shape_[tensor_index].begin(), input_shape_[tensor_index].end(),
std::back_inserter(shape), [](const size_t &value) { return static_cast<int64_t>(value); });
Tensor current_tensor(type_id, shape, input_host_data, size_in_byte);
if (value_type_.count(original_index) > 0) {
// not a tensor
auto out = current_tensor.data().ToString(type_id, shape, true);
if (value_type_[original_index] != 0) {
// tuple, not scalar
(void)std::replace(out.begin(), out.end(), '[', '(');
(void)std::replace(out.begin(), out.end(), ']', ')');
}
return out;
}
return current_tensor.ToStringNoLimit();
}
@ -199,8 +227,9 @@ class PrintGpuKernel : public GpuKernel {
std::vector<std::string> string_value_;
std::vector<int64_t> string_pos_;
std::vector<int64_t> input_flag_;
std::unique_ptr<T *[]> input_device_data_;
std::vector<size_t> input_size_;
std::unordered_map<int64_t, int64_t> value_type_;
// size_in_byte, typeid
std::vector<std::tuple<size_t, TypeId>> input_info_;
std::vector<std::vector<size_t>> input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

View File

@ -53,36 +53,50 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr> *opt_list,
std::vector<std::vector<int64_t>> *string_pos_vec,
std::vector<std::vector<std::string>> *string_value_vec) {
std::vector<std::vector<std::string>> *string_value_vec,
std::vector<std::vector<std::pair<int64_t, int64_t>>> *not_tensor_pos_vec) {
for (auto &node : node_list) {
// {prim::kPrimPrint} only print with string will be reduced
// {prim::kPrimPrint} reduction only applies on print with string, tensor(scalar or tuple)
std::vector<int64_t> string_pos;
std::vector<std::string> string_value;
std::vector<std::pair<int64_t, int64_t>> value_type;
if (IsPrimitiveCNode(node, prim::kPrimPrint)) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_num; i++) {
auto current_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
// not a string
// not tensor(tuple, scalar, string)
if (current_node->cast<ValueNodePtr>() == nullptr) {
continue;
}
auto value_node = current_node->cast<ValueNodePtr>()->value();
// not a string
if (value_node->type() == nullptr) {
auto value_node = current_node->cast<ValueNodePtr>();
auto shape_node = dyn_cast<abstract::Shape>(value_node->abstract()->GetShapeTrack());
if (shape_node != nullptr) {
// a scalar or tuple
auto shape_size = shape_node->shape().size();
if (shape_size != 0) {
value_type.push_back(std::make_pair(i, 1));
} else {
value_type.push_back(std::make_pair(i, 0));
}
}
auto node_value = value_node->value();
if (node_value->type() == nullptr) {
// not a string
continue;
}
if (value_node->type()->generic_type_id() == kObjectTypeString) {
auto current_string_value = GetValue<std::string>(value_node);
if (node_value->type()->generic_type_id() == kObjectTypeString) {
auto current_string_value = GetValue<std::string>(node_value);
string_pos.push_back(i);
string_value.push_back(std::string(current_string_value));
} else {
MS_LOG(EXCEPTION) << "Current value node is not string or tensor";
}
}
if (string_pos.size() != 0) {
if (string_pos.size() != 0 || value_type.size() != 0) {
opt_list->push_back(node);
string_pos_vec->push_back(string_pos);
string_value_vec->push_back(string_value);
not_tensor_pos_vec->push_back(value_type);
}
}
}
@ -100,7 +114,9 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> opt_list;
std::vector<std::vector<int64_t>> string_pos_vec;
std::vector<std::vector<std::string>> string_value_vec;
if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec)) {
// first is pos, second is type: 0 is Scalar, 1 is ValueTuple
std::vector<std::vector<std::pair<int64_t, int64_t>>> not_tensor_pos_vec;
if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec, &not_tensor_pos_vec)) {
return false;
}
for (size_t idx = 0; idx < opt_list.size(); idx++) {
@ -131,11 +147,21 @@ bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(monad_node);
inputs.push_back(monad_node);
auto string_value = string_value_vec[idx];
auto value_type_vec = not_tensor_pos_vec[idx];
// split value type and pos
std::vector<int64_t> value_type_pos;
std::vector<int64_t> value_type;
(void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type_pos),
[](const std::pair<int64_t, int64_t> &value) { return value.first; });
(void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type),
[](const std::pair<int64_t, int64_t> &value) { return value.second; });
// create new cnode
auto print_fused = graph->NewCNode(inputs);
// hand over the attrs to new print
AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused);
AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused);
AnfAlgo::SetNodeAttr("value_type", MakeValue<std::vector<int64_t>>(value_type), print_fused);
AnfAlgo::SetNodeAttr("value_type_pos", MakeValue<std::vector<int64_t>>(value_type_pos), print_fused);
// set output type and shape
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;

View File

@ -71,6 +71,7 @@ def print_testcase(nptype):
net_2(x, y)
net_3(x)
class PrintNetString(nn.Cell):
def __init__(self):
super(PrintNetString, self).__init__()
@ -83,6 +84,7 @@ class PrintNetString(nn.Cell):
self.op("The first Tensor is", x, y, "is the second Tensor")
return x
def print_testcase_string(nptype):
x = np.ones(18).astype(nptype)
y = np.arange(9).reshape(3, 3).astype(nptype)
@ -93,6 +95,29 @@ def print_testcase_string(nptype):
net = PrintNetString()
net(x, y)
class PrintTypes(nn.Cell):
def __init__(self):
super(PrintTypes, self).__init__()
self.op = P.Print()
def construct(self, x, y, z):
self.op("This is a scalar:", 34, "This is int:", x, "This is float64:", y, "This is int64:", z)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_multiple_types():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))
y = Tensor(np.array([[1], [3], [4], [6], [3]]).astype(np.float64))
z = Tensor(np.arange(9).reshape(3, 3).astype(np.int64))
net = PrintTypes()
net(x, y, z)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard