!45801 Support printing tuple constants with the Print operator on CPU.
Merge pull request !45801 from Margaret_wangrui/print_cpu_tuple
commit 10b054a811
@@ -11,6 +11,7 @@ mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::prop
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveToEvalImplMap
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetHostDependsMap
mindspore/mindspore/core/ir/tensor.cc:mindspore::tensor::MakeTensorData
mindspore/mindspore/ccsrc/kernel/common_utils.cc:mindspore::kernel::UnitSizeInBytes
mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::OptimizeIRPassLib::OptimizeIRPassLib
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
@@ -1113,6 +1113,9 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t) {
    case kNumberTypeComplex128:
      bytes = sizeof(double) * complex_factor;
      break;
    case kObjectTypeString:
      bytes = sizeof(std::string);
      break;
    case kNumberTypeInt4:
    default:
      MS_LOG(EXCEPTION) << "Invalid types for UnitSizeInBytes : " << TypeIdToString(t);
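For orientation: UnitSizeInBytes returns the per-element byte width for a TypeId, and the added kObjectTypeString case lets Print size string inputs (sizeof(std::string) is the size of the string handle, not of the characters). A minimal standalone sketch of the same dispatch, with a made-up TypeId enum standing in for MindSpore's:

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for MindSpore's TypeId enum.
enum class TypeId { kFloat32, kComplex128, kString };

// Mirrors the switch in UnitSizeInBytes: map a type to its element width.
size_t UnitSize(TypeId t) {
  const size_t complex_factor = 2;  // real part + imaginary part
  switch (t) {
    case TypeId::kFloat32:
      return sizeof(float);
    case TypeId::kComplex128:
      return sizeof(double) * complex_factor;
    case TypeId::kString:
      return sizeof(std::string);  // handle size, not character count
    default:
      throw std::invalid_argument("Invalid type for UnitSize");
  }
}

int main() {
  std::cout << UnitSize(TypeId::kString) << "\n";  // e.g. 32 with libstdc++
}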
@@ -20,6 +20,7 @@
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/hal/device/cpu_memory_manager.h"
#include "plugin/device/cpu/optimizer/reg_cpu_const_input_to_attr.h"
#include "plugin/device/cpu/optimizer/print_value_type.h"
#include "plugin/device/cpu/hal/hardware/cpu_somas.h"
#ifdef ENABLE_AKG
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
@@ -195,6 +196,7 @@ void CPUKernelExecutor::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
  pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
  pm->AddPass(std::make_shared<opt::InsertTensorMoveForCommunication>());
  pm->AddPass(std::make_shared<opt::AddTrainingAttr>());
  pm->AddPass(std::make_shared<opt::PrintValueType>("print_value_type"));
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(graph);
  graph->SetExecOrderByDefault();
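The new pass is appended through the backend's usual pass-manager flow: passes run in registration order over the kernel graph. A toy sketch of that pattern, with hypothetical Pass/PassManager types rather than the real optimizer classes:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a kernel graph.
struct Graph { std::string name; };

// Minimal pass interface: a pass may rewrite the graph, returning whether it changed it.
struct Pass {
  virtual ~Pass() = default;
  virtual bool Run(Graph *g) = 0;
};

struct PrintValueTypePass : Pass {
  bool Run(Graph *g) override {
    std::cout << "tagging Print nodes in " << g->name << "\n";
    return true;
  }
};

// Runs registered passes in insertion order, like PassManager/GraphOptimizer above.
struct PassManager {
  std::vector<std::shared_ptr<Pass>> passes;
  void AddPass(std::shared_ptr<Pass> p) { passes.push_back(std::move(p)); }
  void Optimize(Graph *g) {
    for (auto &p : passes) (void)p->Run(g);
  }
};

int main() {
  Graph g{"kernel_graph_0"};
  PassManager pm;
  pm.AddPass(std::make_shared<PrintValueTypePass>());
  pm.Optimize(&g);
}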
@@ -27,6 +27,7 @@
#include "plugin/factory/ms_factory.h"
#include "runtime/device/kernel_runtime.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/optimizer/print_value_type.h"
#ifdef ENABLE_AKG
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
#endif
@@ -105,6 +106,7 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
  pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
  pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
  pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
  pm->AddPass(std::make_shared<opt::PrintValueType>("print_value_type"));
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();
@@ -15,11 +15,13 @@
 */

#include "plugin/device/cpu/kernel/print_cpu_kernel.h"
#include <functional>
#include <algorithm>
#include <utility>
#include <string>
#include <complex>
#include "ir/tensor.h"
#include "ops/print.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
@@ -36,6 +38,14 @@ bool PrintCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
    TypeId type = inputs[i]->GetDtype();
    (void)data_types_.emplace_back(type);
  }
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Print>(base_operator);
  if (kernel_ptr->HasAttr("value_type")) {
    auto value_type = kernel_ptr->get_value_type();
    auto value_type_pos = kernel_ptr->get_value_type_pos();
    for (size_t i = 0; i < value_type.size(); i++) {
      value_type_[value_type_pos[i]] = value_type[i];
    }
  }
  return true;
}
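Init reads back the two attributes the pass attached: value_type_pos lists which Print inputs are non-tensor constants, value_type gives each one's kind (0 = scalar, 1 = tuple), and zipping them into the value_type_ map makes the per-input lookup in LaunchKernel constant-time. A standalone sketch of that zip with made-up data:

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // As attached by the PrintValueType pass: input 1 is a tuple, input 3 a scalar.
  std::vector<int64_t> value_type_pos{1, 3};
  std::vector<int64_t> value_type{1, 0};  // 0 = scalar, 1 = tuple

  std::unordered_map<int64_t, int64_t> value_type_;
  for (size_t i = 0; i < value_type.size(); i++) {
    value_type_[value_type_pos[i]] = value_type[i];
  }

  // LaunchKernel-style lookup: is input 1 a non-tensor constant, and of which kind?
  if (value_type_.count(1) > 0) {
    std::cout << "input 1 kind: " << value_type_[1] << "\n";  // prints 1 (tuple)
  }
}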
@@ -46,9 +56,9 @@ int PrintCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
  if (ret != KRET_OK) {
    return ret;
  }

  input_sizes_.clear();
  input_shapes_.clear();
  input_info_.clear();
  for (size_t i = 0; i < inputs.size(); ++i) {
    MS_EXCEPTION_IF_NULL(inputs[i]);
    auto input_shape = inputs[i]->GetShapeVector();
@@ -57,7 +67,11 @@ int PrintCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
    for (size_t j = 0; j < input_shape.size(); ++j) {
      size *= input_shape[j];
    }
    auto type_id = inputs[i]->GetDtype();
    size_t unit_size = UnitSizeInBytes(type_id);
    auto size_in_byte = std::accumulate(input_shape.begin(), input_shape.end(), unit_size, std::multiplies<size_t>());
    (void)input_sizes_.emplace_back(LongToSize(size));
    input_info_.push_back(std::make_tuple(size_in_byte, type_id));
  }
  return ret;
}
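The size_in_byte fold is worth a second look: seeding std::accumulate with unit_size multiplies the element width through every dimension, i.e. unit_size * dim0 * dim1 * .... A self-contained check of that identity:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> shape{3, 4, 5};
  size_t unit_size = sizeof(float);  // 4 bytes per element

  // Same pattern as Resize: the initial value is multiplied through the dims.
  size_t size_in_byte =
      std::accumulate(shape.begin(), shape.end(), unit_size, std::multiplies<size_t>());

  std::cout << size_in_byte << "\n";  // 4 * 3 * 4 * 5 = 240
}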
@@ -91,8 +105,20 @@ void PrintCpuKernelMod::LaunchKernel(size_t index, const std::vector<kernel::Add
      std::cout << *num << std::endl;
    }
  } else {
    TypeId type_id = std::get<1>(input_info_[index]);
    Tensor tensor(data_types_[index], input_shapes_[index], inputs[index]->addr, input_sizes_[index] * sizeof(T));
    if (value_type_.count(index) > 0) {
      // not a tensor
      auto out = tensor.data().ToString(type_id, input_shapes_[index], true);
      if (value_type_[index] != 0) {
        // tuple, not scalar
        (void)std::replace(out.begin(), out.end(), '[', '(');
        (void)std::replace(out.begin(), out.end(), ']', ')');
      }
      std::cout << out << std::endl;
    } else {
      std::cout << tensor.ToStringNoLimit() << std::endl;
    }
  }
}
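The tuple path reuses the tensor formatter and then rewrites its brackets, so a tuple constant prints as (...) instead of [...]. The rewrite is just two std::replace passes over the formatted string; a standalone sketch (the sample string is illustrative, not the formatter's exact output):

#include <algorithm>
#include <iostream>
#include <string>

int main() {
  // What the tensor formatter might produce for a 1-D constant (illustrative).
  std::string out = "[1 2 3 4 5]";

  // Same two calls as LaunchKernel: turn tensor brackets into tuple parentheses.
  (void)std::replace(out.begin(), out.end(), '[', '(');
  (void)std::replace(out.begin(), out.end(), ']', ')');

  std::cout << out << "\n";  // (1 2 3 4 5)
}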
@@ -21,6 +21,8 @@
#include <vector>
#include <utility>
#include <map>
#include <tuple>
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@@ -54,6 +56,9 @@ class PrintCpuKernelMod : public NativeCpuKernelMod {
  std::vector<ShapeVector> input_shapes_;
  std::vector<size_t> input_sizes_;
  std::vector<TypeId> data_types_;

  std::unordered_map<int64_t, int64_t> value_type_;
  std::vector<std::tuple<size_t, TypeId>> input_info_;
};
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,144 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/optimizer/print_value_type.h"

#include <memory>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "ir/primitive.h"
#include "include/common/utils/utils.h"
#include "backend/common/optimizer/helper.h"

namespace mindspore {
namespace opt {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t input_index = 0; input_index < input_num; input_index++) {
    inputs_format.push_back(kOpFormat_DEFAULT);
    inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
  }
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(node);
  for (size_t output_index = 0; output_index < output_num; output_index++) {
    outputs_format.push_back(kOpFormat_DEFAULT);
    outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
  }

  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsFormat(outputs_format);
  builder.SetInputsDeviceType(inputs_type);
  builder.SetOutputsDeviceType(outputs_type);
  return builder.Build();
}

bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr> *opt_list,
                std::vector<std::vector<std::pair<int64_t, int64_t>>> *not_tensor_pos_vec) {
  MS_EXCEPTION_IF_NULL(opt_list);

  for (auto &node : node_list) {
    // {prim::kPrimPrint}: this reduction only applies to Print with string or non-tensor (scalar or tuple) inputs.
    MS_EXCEPTION_IF_NULL(node);
    std::vector<std::pair<int64_t, int64_t>> value_type;
    if (!IsPrimitiveCNode(node, prim::kPrimPrint)) {
      continue;
    }
    size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
    for (size_t i = 0; i < input_num; i++) {
      auto current_node = common::AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
      // not a tensor (tuple, scalar, or string)
      if (current_node->cast<ValueNodePtr>() == nullptr) {
        continue;
      }

      auto value_node = current_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto shape = value_node->abstract();
      MS_EXCEPTION_IF_NULL(shape);
      auto shape_node = dyn_cast<abstract::Shape>(shape->GetShapeTrack());
      if (shape_node != nullptr) {
        // a scalar or tuple
        auto shape_size = shape_node->shape().size();
        if (shape_size != 0) {
          value_type.push_back(std::make_pair(i, 1));
        } else {
          value_type.push_back(std::make_pair(i, 0));
        }
      }
    }
    if (value_type.size() != 0) {
      opt_list->push_back(node);
      not_tensor_pos_vec->push_back(value_type);
    }
  }
  if (opt_list->size() == 0) {
    return false;
  }
  return true;
}

bool PrintValueType::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
  std::vector<AnfNodePtr> opt_list;
  // first is pos, second is type: 0 is Scalar, 1 is ValueTuple
  std::vector<std::vector<std::pair<int64_t, int64_t>>> not_tensor_pos_vec;
  if (!GetOptList(node_list, &opt_list, &not_tensor_pos_vec)) {
    return false;
  }
  for (size_t idx = 0; idx < opt_list.size(); idx++) {
    auto node = opt_list[idx];
    CNodePtr cnode = utils::cast<CNodePtr>(node);
    MS_EXCEPTION_IF_NULL(cnode);
    auto value_type_vec = not_tensor_pos_vec[idx];
    // split value type and pos
    std::vector<int64_t> value_type_pos;
    std::vector<int64_t> value_type;
    (void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type_pos),
                         [](const std::pair<int64_t, int64_t> &value) { return value.first; });
    (void)std::transform(value_type_vec.begin(), value_type_vec.end(), std::back_inserter(value_type),
                         [](const std::pair<int64_t, int64_t> &value) { return value.second; });

    // hand the attrs over to the new Print
    common::AnfAlgo::SetNodeAttr("value_type", MakeValue<std::vector<int64_t>>(value_type), cnode);
    common::AnfAlgo::SetNodeAttr("value_type_pos", MakeValue<std::vector<int64_t>>(value_type_pos), cnode);
    // set output type and shape
    std::vector<TypeId> types;
    std::vector<BaseShapePtr> shapes;
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t i = 0; i < output_num; i++) {
      types.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, i));
      shapes.push_back(common::AnfAlgo::GetOutputDetailShape(cnode, i));
    }
    common::AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
    // add build info
    auto build_info = GenerateKernelBuildInfo(cnode);
    AnfAlgo::SetSelectKernelBuildInfo(build_info, cnode.get());
  }
  return true;
}
}  // namespace opt
}  // namespace mindspore
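GetOptList records each non-tensor input as a (position, kind) pair, and Run splits the pairs into two flat int64_t vectors because node attributes carry plain value lists. The split is two std::transform calls with field projections, sketched standalone here:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

int main() {
  // (input position, kind): 0 = Scalar, 1 = ValueTuple.
  std::vector<std::pair<int64_t, int64_t>> value_type_vec{{1, 1}, {3, 0}};

  std::vector<int64_t> value_type_pos;
  std::vector<int64_t> value_type;
  (void)std::transform(value_type_vec.begin(), value_type_vec.end(),
                       std::back_inserter(value_type_pos),
                       [](const std::pair<int64_t, int64_t> &v) { return v.first; });
  (void)std::transform(value_type_vec.begin(), value_type_vec.end(),
                       std::back_inserter(value_type),
                       [](const std::pair<int64_t, int64_t> &v) { return v.second; });

  for (auto p : value_type_pos) std::cout << p << " ";  // 1 3
  std::cout << "\n";
}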
@@ -0,0 +1,32 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_OPTIMIZER_PRINT_VALUE_TYPE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_OPTIMIZER_PRINT_VALUE_TYPE_H_

#include <string>
#include "backend/common/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
class PrintValueType : public Pass {
 public:
  explicit PrintValueType(const std::string &name) : Pass("print_value_type") {}
  ~PrintValueType() override = default;
  bool Run(const FuncGraphPtr &graph) override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_OPTIMIZER_PRINT_VALUE_TYPE_H_
@@ -91,3 +91,24 @@ def test_print_op_functional(mode):
    net = PrintFunc()
    x = Tensor(np.random.randn(3, 4, 5).astype(np.float32))
    net(x)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_print_op_tuple():
    """
    Feature: cpu Print op.
    Description: test Print with tuple input.
    Expectation: success.
    """
    class PrintTupleNet(nn.Cell):
        def construct(self, x):
            tuple_x = tuple((1, 2, 3, 4, 5))
            ops.print_("tuple_x:", tuple_x, x, "print success!")
            return x

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = PrintTupleNet()
    x = Tensor([6, 7, 8, 9, 10])
    net(x)
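Assuming the bracket rewrite in LaunchKernel kicks in, the tuple constant above should print with parentheses, roughly tuple_x: (1, 2, 3, 4, 5), rather than as a bracketed tensor dump.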
@@ -177,3 +177,25 @@ def test_print_dynamic_shape(mode):
    x_dyn = Tensor(shape=[None, None, None], dtype=ms.float32)
    net.set_inputs(x_dyn)
    net(x)


@security_off_wrap
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_print_op_tuple():
    """
    Feature: gpu Print op.
    Description: test Print with tuple input.
    Expectation: success.
    """
    class PrintTupleNet(nn.Cell):
        def construct(self, x):
            tuple_x = tuple((1, 2, 3, 4, 5))
            print("tuple_x:", tuple_x, x, "print success!")
            return x

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = PrintTupleNet()
    x = Tensor([6, 7, 8, 9, 10])
    net(x)