From 4d35303265e866228a3f54d2f0e2406090f5941e Mon Sep 17 00:00:00 2001 From: TFBunny Date: Fri, 12 Mar 2021 14:12:19 -0500 Subject: [PATCH] support string in GPU print --- .../gpu/debug/print_gpu_kernel.h | 84 ++++++++-- .../optimizer/gpu/print_reduce_fusion.cc | 154 ++++++++++++++++++ .../optimizer/gpu/print_reduce_fusion.h | 32 ++++ .../ccsrc/backend/session/gpu_session.cc | 2 + mindspore/ops/operations/debug_ops.py | 3 +- tests/st/ops/gpu/test_print_op.py | 28 ++++ 6 files changed, 285 insertions(+), 18 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.cc create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h index d409193f983..b17734dcdc9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h @@ -48,7 +48,7 @@ class PrintGpuKernel : public GpuKernel { } int *output_address = GetDeviceAddress(outputs, 0); // host initialization - std::vector > input_host_data; + std::vector> input_host_data; for (size_t i = 0; i < input_size_.size(); i++) { std::unique_ptr value = std::make_unique(input_size_[i]); input_host_data.push_back(std::move(value)); @@ -60,19 +60,25 @@ class PrintGpuKernel : public GpuKernel { MS_LOG(EXCEPTION) << "GPU print does not support the input type."; } // print core function - for (size_t i = 0; i < input_host_data.size(); i++) { - std::string error_msg = "cudaMemcpy print loop failed at input_device_data["; - error_msg.append(std::to_string(i)); - error_msg.append("]."); - CHECK_CUDA_RET_WITH_EXCEPT( - kernel_node_, - cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost), - error_msg); - ShapeVector shape; - (void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape), - [](const size_t &value) { return static_cast(value); }); - Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T)); - std::cout << current_tensor.ToString() << std::endl; + size_t string_idx = 0; + for (size_t i = 0; i < input_flag_.size(); i++) { + if (input_flag_[i] == -1) { + std::cout << string_value_[string_idx] << std::endl; + string_idx++; + } else { + size_t tensor_idx = LongToSize(input_flag_[i]); + std::string error_msg = "cudaMemcpyAsync print loop failed at input_device_data["; + error_msg.append(std::to_string(tensor_idx)); + error_msg.append("]."); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(input_host_data[tensor_idx].get(), input_device_data_[tensor_idx], + input_size_[tensor_idx] * sizeof(T), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), + error_msg); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - Print"); + auto current_string = GetTensorString(&input_shape_, tensor_idx, type_id, &input_host_data, &input_size_); + std::cout << current_string << std::endl; + } } int output = 1; CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, @@ -84,7 +90,12 @@ class PrintGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { kernel_node_ = kernel_node; + if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) { + string_value_ = GetAttr>(kernel_node, "string_value"); + string_pos_ = GetAttr>(kernel_node, "string_pos"); + } size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node); + input_flag_ = SetInputFlag(&string_pos_, input_tensor_num); input_device_data_ = std::make_unique(input_tensor_num); std::vector value_shape; for (size_t i = 0; i < input_tensor_num; i++) { @@ -103,6 +114,9 @@ class PrintGpuKernel : public GpuKernel { } void ResetResource() noexcept override { + string_value_.clear(); + string_pos_.clear(); + input_flag_.clear(); input_device_data_ = nullptr; input_size_.clear(); input_shape_.clear(); @@ -146,14 +160,52 @@ class PrintGpuKernel : public GpuKernel { return kTypeUnknown; } + std::vector SetInputFlag(std::vector *string_pos, size_t input_tensor_num) { + // -1 -> string position + // others -> input tensor position + std::vector res(string_pos->size() + input_tensor_num); + // without string inputs + int64_t value = 0; + if (res.size() == input_tensor_num) { + std::generate(res.begin(), res.end(), [&value]() { return value++; }); + return res; + } + for (size_t i = 0; i < string_pos->size(); i++) { + if ((*string_pos)[i] < 0) { + MS_LOG(EXCEPTION) << "string_pos cannot be a negative value"; + } + auto index = IntToSize((*string_pos)[i]); + res[index] = -1; + } + for (size_t i = 0; i < res.size(); i++) { + if (res[i] != -1) { + res[i] += value; + value++; + } + } + return res; + } + + std::string GetTensorString(std::vector> *input_shape, size_t index, TypeId type_id, + std::vector> *input_host_data, std::vector *input_size) { + ShapeVector shape; + (void)std::transform((*input_shape)[index].begin(), (*input_shape)[index].end(), std::back_inserter(shape), + [](const size_t &value) { return static_cast(value); }); + Tensor current_tensor(type_id, shape, (*input_host_data)[index].get(), (*input_size)[index] * sizeof(T)); + return current_tensor.ToStringNoLimit(); + } + private: + std::vector string_value_; + std::vector string_pos_; + std::vector input_flag_; std::unique_ptr input_device_data_; std::vector input_size_; - std::vector > input_shape_; + std::vector> input_shape_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; -}; // namespace kernel +}; } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.cc new file mode 100644 index 00000000000..9c6bcf42a0d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/print_reduce_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { + std::vector inputs_format; + std::vector outputs_format; + std::vector inputs_type; + std::vector outputs_type; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; input_index++) { + inputs_format.push_back(kOpFormat_DEFAULT); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); + } + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; output_index++) { + outputs_format.push_back(kOpFormat_DEFAULT); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); + } + + builder.SetInputsFormat(inputs_format); + builder.SetOutputsFormat(outputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsDeviceType(outputs_type); + return builder.Build(); +} + +bool GetOptList(const std::vector &node_list, std::vector *opt_list, + std::vector> *string_pos_vec, + std::vector> *string_value_vec) { + for (auto &node : node_list) { + // {prim::kPrimPrint} only print with string will be reduced + std::vector string_pos; + std::vector string_value; + if (IsPrimitiveCNode(node, prim::kPrimPrint)) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < input_num; i++) { + auto current_node = AnfAlgo::GetInputNode(utils::cast(node), i); + // not a string + if (current_node->cast() == nullptr) { + continue; + } + auto value_node = current_node->cast()->value(); + if (value_node->type()->generic_type_id() == kObjectTypeString) { + auto current_string_value = GetValue(value_node); + string_pos.push_back(i); + string_value.push_back(std::string(current_string_value)); + } else { + MS_LOG(EXCEPTION) << "Current value node is not string or tensor"; + } + } + if (string_pos.size() != 0) { + opt_list->push_back(node); + string_pos_vec->push_back(string_pos); + string_value_vec->push_back(string_value); + } + } + } + if (opt_list->size() == 0) { + return false; + } + return true; +} + +bool PrintReduceFusion::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector node_list = TopoSort(graph->get_return()); + std::vector opt_list; + std::vector> string_pos_vec; + std::vector> string_value_vec; + if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec)) { + return false; + } + for (size_t idx = 0; idx < opt_list.size(); idx++) { + auto node = opt_list[idx]; + CNodePtr cnode = utils::cast(node); + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + auto prim = std::make_shared("Print"); + std::vector inputs = {NewValueNode(prim)}; + auto string_pos = string_pos_vec[idx]; + std::vector input_flag(input_num); + for (size_t i = 0; i < string_pos.size(); i++) { + if (string_pos[i] < 0) { + MS_LOG(EXCEPTION) << "string_pos cannot be a negative value"; + } + size_t index = LongToSize(string_pos[i]); + input_flag[index] = -1; + } + for (size_t i = 0; i < input_flag.size(); i++) { + if (input_flag[i] == -1) { + continue; + } + auto input_tensor = AnfAlgo::GetInputNode(cnode, i); + MS_EXCEPTION_IF_NULL(input_tensor); + inputs.push_back(input_tensor); + } + // add monad + auto monad_node = AnfAlgo::GetInputNode(cnode, input_flag.size()); + MS_EXCEPTION_IF_NULL(monad_node); + inputs.push_back(monad_node); + auto string_value = string_value_vec[idx]; + // create new cnode + auto print_fused = graph->NewCNode(inputs); + // hand over the attrs to new print + AnfAlgo::SetNodeAttr("string_pos", MakeValue>(string_pos), print_fused); + AnfAlgo::SetNodeAttr("string_value", MakeValue>(string_value), print_fused); + // set output type and shape + std::vector types; + std::vector> shapes; + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t i = 0; i < output_num; i++) { + types.push_back(AnfAlgo::GetOutputInferDataType(cnode, i)); + shapes.push_back(AnfAlgo::GetOutputInferShape(cnode, i)); + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, print_fused.get()); + // add build info + auto build_info = GenerateKernelBuildInfo(print_fused); + AnfAlgo::SetSelectKernelBuildInfo(build_info, print_fused.get()); + if (!manager->Replace(cnode, print_fused)) { + MS_LOG(EXCEPTION) << "manager replace node failed in print reduce fusion."; + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.h new file mode 100644 index 00000000000..2c4f2152955 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/print_reduce_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class PrintReduceFusion : public Pass { + public: + explicit PrintReduceFusion(const std::string &name) : Pass("print_reduce") {} + ~PrintReduceFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 501dadf2c74..f8b9fba49ad 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -37,6 +37,7 @@ #include "backend/optimizer/gpu/insert_format_transform_op.h" #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" #include "backend/optimizer/gpu/replace_addn_fusion.h" +#include "backend/optimizer/gpu/print_reduce_fusion.h" #include "backend/optimizer/gpu/remove_format_transform_pair.h" #include "backend/optimizer/gpu/remove_redundant_format_transform.h" #include "backend/optimizer/gpu/reduce_precision_fusion.h" @@ -141,6 +142,7 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { pm->AddPass(std::make_shared("combine_momentum")); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared("print_reduce")); optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 3d65a2ada4f..5484437b983 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -346,11 +346,10 @@ class Print(PrimitiveWithInfer): In pynative mode, please use python print function. In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print, str remains unchanged. - In GPU, all input elements should be the same type and string is not supported. Inputs: - **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to. - Supports multiple inputs which are separated by ','. GPU does not support string as an input. + Supports multiple inputs which are separated by ','. Supported Platforms: ``Ascend`` ``GPU`` diff --git a/tests/st/ops/gpu/test_print_op.py b/tests/st/ops/gpu/test_print_op.py index 2c3f2effc7e..17f3f6c5dc6 100644 --- a/tests/st/ops/gpu/test_print_op.py +++ b/tests/st/ops/gpu/test_print_op.py @@ -71,6 +71,27 @@ def print_testcase(nptype): net_2(x, y) net_3(x) +class PrintNetString(nn.Cell): + def __init__(self): + super(PrintNetString, self).__init__() + self.op = P.Print() + + def construct(self, x, y): + self.op("The first Tensor is", x) + self.op("The second Tensor is", y) + self.op("This line only prints string", "Another line") + self.op("The first Tensor is", x, y, "is the second Tensor") + return x + +def print_testcase_string(nptype): + x = np.ones(18).astype(nptype) + y = np.arange(9).reshape(3, 3).astype(nptype) + x = Tensor(x) + y = Tensor(y) + # graph mode + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = PrintNetString() + net(x, y) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -147,3 +168,10 @@ def test_print_float16(): @pytest.mark.env_onecard def test_print_float32(): print_testcase(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_print_string(): + print_testcase_string(np.float32)