commit
62d272f648
|
@ -19,37 +19,37 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, bool)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
PrintGpuKernel, bool)
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, uint16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, uint32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Print,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, uint64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
PrintGpuKernel, uint16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
PrintGpuKernel, uint32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
PrintGpuKernel, uint64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
PrintGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,11 +17,17 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/tensor.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
using mindspore::tensor::Tensor;
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
|
@ -37,19 +43,42 @@ class PrintGpuKernel : public GpuKernel {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
VARIABLE_NOT_USED(outputs);
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudaMemcpy(&input_host_data_[0], &input_device_data_[0], input_size_ * sizeof(T), cudaMemcpyDeviceToHost),
|
||||
"cudaMemcpy output failed");
|
||||
for (size_t i = 0; i < input_num_.size(); i++) {
|
||||
for (size_t j = 0; j < input_num_[i]; j++) {
|
||||
std::cout << input_host_data_[i][j];
|
||||
}
|
||||
int *output_address = GetDeviceAddress<int>(outputs, 0);
|
||||
// host initialization
|
||||
std::vector<std::unique_ptr<T[]> > input_host_data;
|
||||
for (size_t i = 0; i < input_size_.size(); i++) {
|
||||
std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
|
||||
input_host_data.push_back(std::move(value));
|
||||
}
|
||||
// check type
|
||||
T type_value = static_cast<T>(0.0f);
|
||||
auto type_id = CheckType(type_value);
|
||||
if (type_id == kTypeUnknown) {
|
||||
MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
|
||||
}
|
||||
// print core function
|
||||
for (size_t i = 0; i < input_host_data.size(); i++) {
|
||||
std::string error_msg = "cudaMemcpy print loop failed at input_device_data[";
|
||||
error_msg.append(std::to_string(i));
|
||||
error_msg.append("].");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost),
|
||||
error_msg);
|
||||
ShapeVector shape;
|
||||
(void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape),
|
||||
[](const size_t &value) { return static_cast<int64_t>(value); });
|
||||
Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T));
|
||||
std::cout << current_tensor.ToString() << std::endl;
|
||||
}
|
||||
int output = 1;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(output_address, &output, sizeof(int), cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -57,38 +86,70 @@ class PrintGpuKernel : public GpuKernel {
|
|||
kernel_node_ = kernel_node;
|
||||
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
|
||||
input_host_data_ = std::make_unique<T *[]>(input_tensor_num);
|
||||
std::vector<size_t> value_shape;
|
||||
for (size_t i = 0; i < input_tensor_num; i++) {
|
||||
size_t counter = 0;
|
||||
size_t value = 1;
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
|
||||
for (size_t j = 0; j < input_shape.size(); j++) {
|
||||
input_size_ *= input_shape[j];
|
||||
counter++;
|
||||
value *= input_shape[j];
|
||||
value_shape.push_back(input_shape[j]);
|
||||
}
|
||||
input_num_.push_back(counter);
|
||||
input_size_.push_back(value);
|
||||
input_shape_.push_back(value_shape);
|
||||
value_shape.clear();
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = 1;
|
||||
input_device_data_ = nullptr;
|
||||
input_host_data_ = nullptr;
|
||||
input_num_.clear();
|
||||
input_size_.clear();
|
||||
input_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); }
|
||||
void InitSizeLists() override {
|
||||
for (size_t i = 0; i < input_size_.size(); i++) {
|
||||
input_size_list_.push_back(input_size_[i] * sizeof(T));
|
||||
}
|
||||
output_size_list_.push_back(sizeof(int));
|
||||
}
|
||||
|
||||
TypeId CheckType(T value) {
|
||||
if (std::is_same<T, bool>::value) {
|
||||
return kNumberTypeBool;
|
||||
} else if (std::is_same<T, int8_t>::value) {
|
||||
return kNumberTypeInt8;
|
||||
} else if (std::is_same<T, int16_t>::value) {
|
||||
return kNumberTypeInt16;
|
||||
} else if (std::is_same<T, int>::value) {
|
||||
return kNumberTypeInt32;
|
||||
} else if (std::is_same<T, int64_t>::value) {
|
||||
return kNumberTypeInt64;
|
||||
} else if (std::is_same<T, uint8_t>::value) {
|
||||
return kNumberTypeUInt8;
|
||||
} else if (std::is_same<T, uint16_t>::value) {
|
||||
return kNumberTypeUInt16;
|
||||
} else if (std::is_same<T, uint32_t>::value) {
|
||||
return kNumberTypeUInt32;
|
||||
} else if (std::is_same<T, uint64_t>::value) {
|
||||
return kNumberTypeUInt64;
|
||||
} else if (std::is_same<T, half>::value) {
|
||||
return kNumberTypeFloat16;
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
return kNumberTypeFloat32;
|
||||
}
|
||||
return kTypeUnknown;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
std::unique_ptr<T *[]> input_device_data_;
|
||||
std::unique_ptr<T *[]> input_host_data_;
|
||||
std::vector<size_t> input_num_;
|
||||
std::vector<size_t> input_size_;
|
||||
std::vector<std::vector<size_t> > input_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
|
|
@ -341,10 +341,11 @@ class Print(PrimitiveWithInfer):
|
|||
In pynative mode, please use python print function.
|
||||
In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print,
|
||||
str remains unchanged.
|
||||
In GPU, all input elements should be the same type and string is not supported.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
|
||||
Supports multiple inputs which are separated by ','.
|
||||
Supports multiple inputs which are separated by ','. GPU does not support string as an input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
class PrintNetOneInput(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PrintNetOneInput, self).__init__()
|
||||
self.op = P.Print()
|
||||
|
||||
def construct(self, x):
|
||||
self.op(x)
|
||||
return x
|
||||
|
||||
|
||||
class PrintNetTwoInputs(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PrintNetTwoInputs, self).__init__()
|
||||
self.op = P.Print()
|
||||
|
||||
def construct(self, x, y):
|
||||
self.op(x, y)
|
||||
return x
|
||||
|
||||
|
||||
def print_testcase(nptype):
|
||||
# large shape
|
||||
x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
|
||||
# small shape
|
||||
y = np.arange(9).reshape(3, 3).astype(nptype)
|
||||
x = Tensor(x)
|
||||
y = Tensor(y)
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
net_1 = PrintNetOneInput()
|
||||
net_2 = PrintNetTwoInputs()
|
||||
net_1(x)
|
||||
net_2(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_bool():
|
||||
print_testcase(np.bool)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_int8():
|
||||
print_testcase(np.int8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_int16():
|
||||
print_testcase(np.int16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_int32():
|
||||
print_testcase(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_int64():
|
||||
print_testcase(np.int64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_uint8():
|
||||
print_testcase(np.uint8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_uint16():
|
||||
print_testcase(np.uint16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_uint32():
|
||||
print_testcase(np.uint32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_uint64():
|
||||
print_testcase(np.uint64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_float16():
|
||||
print_testcase(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_float32():
|
||||
print_testcase(np.float32)
|
Loading…
Reference in New Issue