add reshape type to tensor

This commit is contained in:
WilliamLian 2020-08-13 16:42:02 +08:00
parent 0ff1000b24
commit 6760d9976d
14 changed files with 43 additions and 35 deletions

View File

@ -30,14 +30,6 @@ namespace mindspore {
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
namespace kernel { namespace kernel {
enum Axis : int {
N = 0,
C,
H,
W,
};
// Supported fusion type // Supported fusion type
enum FusionType { enum FusionType {
CONVLUTION = 0, CONVLUTION = 0,

View File

@ -22,6 +22,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/kernel_info_dev.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
namespace mindspore { namespace mindspore {

View File

@ -406,16 +406,16 @@ void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, st
for (const auto &c : reshape_type_str) { for (const auto &c : reshape_type_str) {
switch (c) { switch (c) {
case 'N': case 'N':
reshape_type_vec->push_back(kernel::N); reshape_type_vec->push_back(N);
break; break;
case 'C': case 'C':
reshape_type_vec->push_back(kernel::C); reshape_type_vec->push_back(C);
break; break;
case 'H': case 'H':
reshape_type_vec->push_back(kernel::H); reshape_type_vec->push_back(H);
break; break;
case 'W': case 'W':
reshape_type_vec->push_back(kernel::W); reshape_type_vec->push_back(W);
break; break;
default: default:
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";

View File

@ -55,7 +55,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr trans_data = nullptr; CNodePtr trans_data = nullptr;
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
std::vector<kernel::Axis> padding_axis; std::vector<Axis> padding_axis;
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
// if insert transdata for input we need to change the input // if insert transdata for input we need to change the input
if (is_insert_input) { if (is_insert_input) {
@ -170,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
} }
} // namespace } // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type, const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
const TypeId &type_id) { const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);

View File

@ -86,7 +86,7 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>; using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}, const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
const TypeId &type_id = kTypeUnknown); const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,

View File

@ -418,7 +418,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
} }
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { std::vector<Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
} }
@ -483,7 +483,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
return trans::TransShapeToDevice(infer_shape, format); return trans::TransShapeToDevice(infer_shape, format);
} }
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) { if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index:" << input_idx MS_LOG(EXCEPTION) << "The index:" << input_idx
@ -503,7 +503,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
return build_info->GetInputReshapeType(input_idx); return build_info->GetInputReshapeType(input_idx);
} }
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) { if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "

View File

@ -27,6 +27,7 @@
#include "ir/dtype.h" #include "ir/dtype.h"
#include "base/base.h" #include "base/base.h"
#include "ir/primitive.h" #include "ir/primitive.h"
#include "ir/kernel_info_dev.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/kernel_build_info.h" #include "backend/kernel_compiler/kernel_build_info.h"
@ -109,7 +110,7 @@ class AnfRuntimeAlgorithm {
// get output format from prev node,input_index is the input index of current node related to prev node // get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type of from the output of input node. // get reshape_type of from the output of input node.
static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes. // get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes inferred by ME from input nodes. // get input shapes inferred by ME from input nodes.
@ -119,9 +120,9 @@ class AnfRuntimeAlgorithm {
// get input shapes which will built and run in device // get input shapes which will built and run in device
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
// Get Input Padding Axis // Get Input Padding Axis
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
// Get Output Padding Axis // Get Output Padding Axis
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node // get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node // get output original data type from prev node,input_index is the input index of current node related to prev node

View File

@ -66,12 +66,13 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
if (type_id == kTypeUnknown) { if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index); type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
} }
tensor::TensorPtr tensor; tensor::TensorPtr tensor = nullptr;
std::vector<int> temp_shape; std::vector<int> temp_shape;
if (graph->IsUniqueTargetInternalOutput(node, output_index)) { if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1); temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address); tensor->set_device_address(address);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_dirty(false); tensor->set_dirty(false);
return tensor; return tensor;
} }
@ -86,6 +87,7 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
graph->AddInternalOutputTensor(node, output_index, tensor); graph->AddInternalOutputTensor(node, output_index, tensor);
} }
} }
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
// if in paynative mode,data only copyed to host when user want to print data // if in paynative mode,data only copyed to host when user want to print data
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
@ -240,6 +242,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
} else { } else {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()}); kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()}); kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
} }
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// construct abstract of parameter // construct abstract of parameter

View File

@ -399,7 +399,7 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
return shape; return shape;
} }
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) { std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
if (padding_axis.empty() || shape.size() != padding_axis.size()) { if (padding_axis.empty() || shape.size() != padding_axis.size()) {
return PaddingShapeTo4dByDefault(shape); return PaddingShapeTo4dByDefault(shape);
} }

View File

@ -51,8 +51,7 @@ size_t TypeIdSize(const TypeId data_type);
size_t ShapeSize(const std::vector<size_t> &shape); size_t ShapeSize(const std::vector<size_t> &shape);
size_t CubeSizeByType(const TypeId data_type); size_t CubeSizeByType(const TypeId data_type);
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
const std::vector<kernel::Axis> &padding_axis = {});
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size); bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format); std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);

View File

@ -20,6 +20,12 @@
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
enum Axis : int {
N = 0,
C,
H,
W,
};
// Interface for device kernel program information. // Interface for device kernel program information.
class KernelInfoDevice { class KernelInfoDevice {
public: public:

View File

@ -384,7 +384,8 @@ Tensor::Tensor(const Tensor &tensor)
data_(tensor.data_), data_(tensor.data_),
dirty_(tensor.dirty_), dirty_(tensor.dirty_),
id_(tensor.id_), id_(tensor.id_),
device_sync_(tensor.device_sync_) {} device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type) Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_), : MetaTensor(data_type, tensor.shape_),
@ -392,7 +393,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_), dirty_(tensor.dirty_),
id_(tensor.id_), id_(tensor.id_),
device_sync_(tensor.device_sync_) {} device_sync_(tensor.device_sync_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data) Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
@ -441,6 +443,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
device_sync_ = tensor.device_sync_; device_sync_ = tensor.device_sync_;
data_ = tensor.data_; data_ = tensor.data_;
id_ = tensor.id_; id_ = tensor.id_;
padding_type_ = tensor.padding_type_;
} }
return *this; return *this;
} }

View File

@ -221,6 +221,8 @@ class Tensor : public MetaTensor {
DeviceSyncPtr device_address() const { return device_sync_; } DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
std::vector<Axis> padding_type() const { return padding_type_; }
std::string id() const { return id_; } std::string id() const { return id_; }
@ -230,6 +232,7 @@ class Tensor : public MetaTensor {
bool dirty_{true}; bool dirty_{true};
std::string id_{""}; std::string id_{""};
DeviceSyncPtr device_sync_{nullptr}; DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
}; };
using TensorPtr = std::shared_ptr<Tensor>; using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

View File

@ -54,10 +54,10 @@ class A {
std::shared_ptr<int> i; std::shared_ptr<int> i;
}; };
class C : public A { class Ca : public A {
public: public:
C() {} Ca() {}
explicit C(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); } explicit Ca(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); }
void FuncA(int v1, float v2, std::string str) { printf("C: --%d--%f--%s--\n", v1, v2, str.c_str()); } void FuncA(int v1, float v2, std::string str) { printf("C: --%d--%f--%s--\n", v1, v2, str.c_str()); }
}; };
@ -71,13 +71,13 @@ class B : public A {
TEST_F(TestSignal, test_common) { TEST_F(TestSignal, test_common) {
A objA; A objA;
B objB; B objB;
C objC; Ca objC;
Signal<void(int, float, std::string)> signal; Signal<void(int, float, std::string)> signal;
signal.connect(&objA, &A::FuncA); signal.connect(&objA, &A::FuncA);
signal.connect(&objB, &B::FuncA); signal.connect(&objB, &B::FuncA);
signal.connect(&objC, &C::FuncA); signal.connect(&objC, &Ca::FuncA);
signal(20, 20, "Signal-Slot test"); signal(20, 20, "Signal-Slot test");
} }
@ -85,11 +85,11 @@ TEST_F(TestSignal, test_sigs) {
signals sigs; signals sigs;
A objA(&sigs); A objA(&sigs);
B objB(&sigs); B objB(&sigs);
C objC(&sigs); Ca objC(&sigs);
sigs.signal.connect(&objA, &A::FuncA); sigs.signal.connect(&objA, &A::FuncA);
sigs.signal.connect(&objB, &B::FuncA); sigs.signal.connect(&objB, &B::FuncA);
sigs.signal.connect(&objC, &C::FuncA); sigs.signal.connect(&objC, &Ca::FuncA);
sigs.signal(20, 20, "sigs Signal-Slot test"); sigs.signal(20, 20, "sigs Signal-Slot test");
} }
@ -97,7 +97,7 @@ TEST_F(TestSignal, test_sigs_Named) {
signals sigs; signals sigs;
A objA(&sigs); A objA(&sigs);
B objB(&sigs); B objB(&sigs);
C objC(&sigs); Ca objC(&sigs);
sigs.signal(10, 20, "Signal-Slot test"); sigs.signal(10, 20, "Signal-Slot test");
std::shared_ptr<Named> a; std::shared_ptr<Named> a;