forked from mindspore-Ecosystem/mindspore
!13262 Support string in GPU Print
From: @TFbunny Reviewed-by: Signed-off-by:
This commit is contained in:
commit
be797d821f
|
@ -48,7 +48,7 @@ class PrintGpuKernel : public GpuKernel {
|
|||
}
|
||||
int *output_address = GetDeviceAddress<int>(outputs, 0);
|
||||
// host initialization
|
||||
std::vector<std::unique_ptr<T[]> > input_host_data;
|
||||
std::vector<std::unique_ptr<T[]>> input_host_data;
|
||||
for (size_t i = 0; i < input_size_.size(); i++) {
|
||||
std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
|
||||
input_host_data.push_back(std::move(value));
|
||||
|
@ -60,19 +60,25 @@ class PrintGpuKernel : public GpuKernel {
|
|||
MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
|
||||
}
|
||||
// print core function
|
||||
for (size_t i = 0; i < input_host_data.size(); i++) {
|
||||
std::string error_msg = "cudaMemcpy print loop failed at input_device_data[";
|
||||
error_msg.append(std::to_string(i));
|
||||
error_msg.append("].");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost),
|
||||
error_msg);
|
||||
ShapeVector shape;
|
||||
(void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape),
|
||||
[](const size_t &value) { return static_cast<int64_t>(value); });
|
||||
Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T));
|
||||
std::cout << current_tensor.ToString() << std::endl;
|
||||
size_t string_idx = 0;
|
||||
for (size_t i = 0; i < input_flag_.size(); i++) {
|
||||
if (input_flag_[i] == -1) {
|
||||
std::cout << string_value_[string_idx] << std::endl;
|
||||
string_idx++;
|
||||
} else {
|
||||
size_t tensor_idx = LongToSize(input_flag_[i]);
|
||||
std::string error_msg = "cudaMemcpyAsync print loop failed at input_device_data[";
|
||||
error_msg.append(std::to_string(tensor_idx));
|
||||
error_msg.append("].");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_host_data[tensor_idx].get(), input_device_data_[tensor_idx],
|
||||
input_size_[tensor_idx] * sizeof(T), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
error_msg);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - Print");
|
||||
auto current_string = GetTensorString(&input_shape_, tensor_idx, type_id, &input_host_data, &input_size_);
|
||||
std::cout << current_string << std::endl;
|
||||
}
|
||||
}
|
||||
int output = 1;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
|
@ -84,7 +90,12 @@ class PrintGpuKernel : public GpuKernel {
|
|||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) {
|
||||
string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value");
|
||||
string_pos_ = GetAttr<std::vector<int64_t>>(kernel_node, "string_pos");
|
||||
}
|
||||
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
input_flag_ = SetInputFlag(&string_pos_, input_tensor_num);
|
||||
input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
|
||||
std::vector<size_t> value_shape;
|
||||
for (size_t i = 0; i < input_tensor_num; i++) {
|
||||
|
@ -103,6 +114,9 @@ class PrintGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
string_value_.clear();
|
||||
string_pos_.clear();
|
||||
input_flag_.clear();
|
||||
input_device_data_ = nullptr;
|
||||
input_size_.clear();
|
||||
input_shape_.clear();
|
||||
|
@ -146,14 +160,52 @@ class PrintGpuKernel : public GpuKernel {
|
|||
return kTypeUnknown;
|
||||
}
|
||||
|
||||
std::vector<int64_t> SetInputFlag(std::vector<int64_t> *string_pos, size_t input_tensor_num) {
|
||||
// -1 -> string position
|
||||
// others -> input tensor position
|
||||
std::vector<int64_t> res(string_pos->size() + input_tensor_num);
|
||||
// without string inputs
|
||||
int64_t value = 0;
|
||||
if (res.size() == input_tensor_num) {
|
||||
std::generate(res.begin(), res.end(), [&value]() { return value++; });
|
||||
return res;
|
||||
}
|
||||
for (size_t i = 0; i < string_pos->size(); i++) {
|
||||
if ((*string_pos)[i] < 0) {
|
||||
MS_LOG(EXCEPTION) << "string_pos cannot be a negative value";
|
||||
}
|
||||
auto index = IntToSize((*string_pos)[i]);
|
||||
res[index] = -1;
|
||||
}
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
if (res[i] != -1) {
|
||||
res[i] += value;
|
||||
value++;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string GetTensorString(std::vector<std::vector<size_t>> *input_shape, size_t index, TypeId type_id,
|
||||
std::vector<std::unique_ptr<T[]>> *input_host_data, std::vector<size_t> *input_size) {
|
||||
ShapeVector shape;
|
||||
(void)std::transform((*input_shape)[index].begin(), (*input_shape)[index].end(), std::back_inserter(shape),
|
||||
[](const size_t &value) { return static_cast<int64_t>(value); });
|
||||
Tensor current_tensor(type_id, shape, (*input_host_data)[index].get(), (*input_size)[index] * sizeof(T));
|
||||
return current_tensor.ToStringNoLimit();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::string> string_value_;
|
||||
std::vector<int64_t> string_pos_;
|
||||
std::vector<int64_t> input_flag_;
|
||||
std::unique_ptr<T *[]> input_device_data_;
|
||||
std::vector<size_t> input_size_;
|
||||
std::vector<std::vector<size_t> > input_shape_;
|
||||
std::vector<std::vector<size_t>> input_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
}; // namespace kernel
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/print_reduce_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Creates the kernel build info for `node`: every input and output uses
// kOpFormat_DEFAULT, input device types come from the inferred output type of
// the producing node, and output device types from the node's own inferred
// output types.
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  // Inputs: default format, type taken from the previous node's inferred output.
  const size_t input_count = AnfAlgo::GetInputTensorNum(node);
  std::vector<std::string> in_formats(input_count, kOpFormat_DEFAULT);
  std::vector<TypeId> in_types;
  in_types.reserve(input_count);
  for (size_t idx = 0; idx < input_count; ++idx) {
    in_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, idx));
  }

  // Outputs: default format, type taken from this node's inferred output.
  const size_t output_count = AnfAlgo::GetOutputTensorNum(node);
  std::vector<std::string> out_formats(output_count, kOpFormat_DEFAULT);
  std::vector<TypeId> out_types;
  out_types.reserve(output_count);
  for (size_t idx = 0; idx < output_count; ++idx) {
    out_types.push_back(AnfAlgo::GetOutputInferDataType(node, idx));
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(in_formats);
  info_builder.SetOutputsFormat(out_formats);
  info_builder.SetInputsDeviceType(in_types);
  info_builder.SetOutputsDeviceType(out_types);
  return info_builder.Build();
}
|
||||
|
||||
bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr> *opt_list,
|
||||
std::vector<std::vector<int64_t>> *string_pos_vec,
|
||||
std::vector<std::vector<std::string>> *string_value_vec) {
|
||||
for (auto &node : node_list) {
|
||||
// {prim::kPrimPrint} only print with string will be reduced
|
||||
std::vector<int64_t> string_pos;
|
||||
std::vector<std::string> string_value;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPrint)) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto current_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
|
||||
// not a string
|
||||
if (current_node->cast<ValueNodePtr>() == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto value_node = current_node->cast<ValueNodePtr>()->value();
|
||||
if (value_node->type()->generic_type_id() == kObjectTypeString) {
|
||||
auto current_string_value = GetValue<std::string>(value_node);
|
||||
string_pos.push_back(i);
|
||||
string_value.push_back(std::string(current_string_value));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Current value node is not string or tensor";
|
||||
}
|
||||
}
|
||||
if (string_pos.size() != 0) {
|
||||
opt_list->push_back(node);
|
||||
string_pos_vec->push_back(string_pos);
|
||||
string_value_vec->push_back(string_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (opt_list->size() == 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PrintReduceFusion::Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
|
||||
std::vector<AnfNodePtr> opt_list;
|
||||
std::vector<std::vector<int64_t>> string_pos_vec;
|
||||
std::vector<std::vector<std::string>> string_value_vec;
|
||||
if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec)) {
|
||||
return false;
|
||||
}
|
||||
for (size_t idx = 0; idx < opt_list.size(); idx++) {
|
||||
auto node = opt_list[idx];
|
||||
CNodePtr cnode = utils::cast<CNodePtr>(node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
auto prim = std::make_shared<Primitive>("Print");
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
auto string_pos = string_pos_vec[idx];
|
||||
std::vector<int64_t> input_flag(input_num);
|
||||
for (size_t i = 0; i < string_pos.size(); i++) {
|
||||
if (string_pos[i] < 0) {
|
||||
MS_LOG(EXCEPTION) << "string_pos cannot be a negative value";
|
||||
}
|
||||
size_t index = LongToSize(string_pos[i]);
|
||||
input_flag[index] = -1;
|
||||
}
|
||||
for (size_t i = 0; i < input_flag.size(); i++) {
|
||||
if (input_flag[i] == -1) {
|
||||
continue;
|
||||
}
|
||||
auto input_tensor = AnfAlgo::GetInputNode(cnode, i);
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
inputs.push_back(input_tensor);
|
||||
}
|
||||
// add monad
|
||||
auto monad_node = AnfAlgo::GetInputNode(cnode, input_flag.size());
|
||||
MS_EXCEPTION_IF_NULL(monad_node);
|
||||
inputs.push_back(monad_node);
|
||||
auto string_value = string_value_vec[idx];
|
||||
// create new cnode
|
||||
auto print_fused = graph->NewCNode(inputs);
|
||||
// hand over the attrs to new print
|
||||
AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused);
|
||||
AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused);
|
||||
// set output type and shape
|
||||
std::vector<TypeId> types;
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
types.push_back(AnfAlgo::GetOutputInferDataType(cnode, i));
|
||||
shapes.push_back(AnfAlgo::GetOutputInferShape(cnode, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, print_fused.get());
|
||||
// add build info
|
||||
auto build_info = GenerateKernelBuildInfo(print_fused);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, print_fused.get());
|
||||
if (!manager->Replace(cnode, print_fused)) {
|
||||
MS_LOG(EXCEPTION) << "manager replace node failed in print reduce fusion.";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_
|
||||
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PrintReduceFusion : public Pass {
|
||||
public:
|
||||
explicit PrintReduceFusion(const std::string &name) : Pass("print_reduce") {}
|
||||
~PrintReduceFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_
|
|
@ -37,6 +37,7 @@
|
|||
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
||||
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_addn_fusion.h"
|
||||
#include "backend/optimizer/gpu/print_reduce_fusion.h"
|
||||
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
|
||||
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
|
||||
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
|
||||
|
@ -141,6 +142,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -346,11 +346,10 @@ class Print(PrimitiveWithInfer):
|
|||
In pynative mode, please use python print function.
|
||||
In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print,
|
||||
str remains unchanged.
|
||||
In GPU, all input elements should be the same type and string is not supported.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
|
||||
Supports multiple inputs which are separated by ','. GPU does not support string as an input.
|
||||
Supports multiple inputs which are separated by ','.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
|
|
@ -71,6 +71,27 @@ def print_testcase(nptype):
|
|||
net_2(x, y)
|
||||
net_3(x)
|
||||
|
||||
class PrintNetString(nn.Cell):
    """Cell that calls the Print operator with string arguments, alone and mixed with tensors."""

    def __init__(self):
        super().__init__()
        self.op = P.Print()

    def construct(self, x, y):
        # string followed by one tensor
        self.op("The first Tensor is", x)
        self.op("The second Tensor is", y)
        # strings only
        self.op("This line only prints string", "Another line")
        # strings surrounding two tensors
        self.op("The first Tensor is", x, y, "is the second Tensor")
        return x
|
||||
|
||||
def print_testcase_string(nptype):
    """Run PrintNetString in GPU graph mode on two tensors of dtype ``nptype``.

    The first tensor is 18 ones; the second is the values 0..8 arranged as 3x3.
    """
    ones_input = Tensor(np.ones(18).astype(nptype))
    grid_input = Tensor(np.arange(9).reshape(3, 3).astype(nptype))
    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    string_net = PrintNetString()
    string_net(ones_input, grid_input)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -147,3 +168,10 @@ def test_print_float16():
|
|||
@pytest.mark.env_onecard
|
||||
def test_print_float32():
|
||||
print_testcase(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_string():
    """Print strings together with float32 tensors via print_testcase_string."""
    dtype = np.float32
    print_testcase_string(dtype)
|
||||
|
|
Loading…
Reference in New Issue