add reshape type to tensor
This commit is contained in:
parent
0ff1000b24
commit
6760d9976d
|
@ -30,14 +30,6 @@ namespace mindspore {
|
|||
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
|
||||
|
||||
namespace kernel {
|
||||
|
||||
enum Axis : int {
|
||||
N = 0,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
};
|
||||
|
||||
// Supported fusion type
|
||||
enum FusionType {
|
||||
CONVLUTION = 0,
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/kernel_info_dev.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -406,16 +406,16 @@ void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, st
|
|||
for (const auto &c : reshape_type_str) {
|
||||
switch (c) {
|
||||
case 'N':
|
||||
reshape_type_vec->push_back(kernel::N);
|
||||
reshape_type_vec->push_back(N);
|
||||
break;
|
||||
case 'C':
|
||||
reshape_type_vec->push_back(kernel::C);
|
||||
reshape_type_vec->push_back(C);
|
||||
break;
|
||||
case 'H':
|
||||
reshape_type_vec->push_back(kernel::H);
|
||||
reshape_type_vec->push_back(H);
|
||||
break;
|
||||
case 'W':
|
||||
reshape_type_vec->push_back(kernel::W);
|
||||
reshape_type_vec->push_back(W);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
|
||||
|
|
|
@ -55,7 +55,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
CNodePtr trans_data = nullptr;
|
||||
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
|
||||
std::vector<kernel::Axis> padding_axis;
|
||||
std::vector<Axis> padding_axis;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// if insert transdata for input we need to change the input
|
||||
if (is_insert_input) {
|
||||
|
@ -170,7 +170,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|||
}
|
||||
} // namespace
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
|
||||
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
|
||||
const TypeId &type_id) {
|
||||
MS_EXCEPTION_IF_NULL(trans_data);
|
||||
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
||||
|
|
|
@ -86,7 +86,7 @@ class OpFinder {
|
|||
using OpFinderPtr = std::shared_ptr<OpFinder>;
|
||||
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
|
||||
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
|
||||
const TypeId &type_id = kTypeUnknown);
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
|
|
|
@ -418,7 +418,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
|
|||
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
|
||||
std::vector<Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
|
||||
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
@ -483,7 +483,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
|
|||
return trans::TransShapeToDevice(infer_shape, format);
|
||||
}
|
||||
|
||||
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
|
||||
std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (input_idx > GetInputTensorNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index:" << input_idx
|
||||
|
@ -503,7 +503,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
|
|||
return build_info->GetInputReshapeType(input_idx);
|
||||
}
|
||||
|
||||
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
|
||||
std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (output_idx > GetOutputTensorNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "ir/dtype.h"
|
||||
#include "base/base.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "ir/kernel_info_dev.h"
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
|
@ -109,7 +110,7 @@ class AnfRuntimeAlgorithm {
|
|||
// get output format from prev node,input_index is the input index of current node related to prev node
|
||||
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
|
||||
// get reshape_type of from the output of input node.
|
||||
static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
|
||||
static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
|
||||
// get output shapes inferred by ME from input nodes.
|
||||
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
|
||||
// get input shapes inferred by ME from input nodes.
|
||||
|
@ -119,9 +120,9 @@ class AnfRuntimeAlgorithm {
|
|||
// get input shapes which will built and run in device
|
||||
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
|
||||
// Get Input Padding Axis
|
||||
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
||||
static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
||||
// Get Output Padding Axis
|
||||
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
||||
static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
||||
// get output data type inferred by ME of anf node
|
||||
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
|
||||
// get output original data type from prev node,input_index is the input index of current node related to prev node
|
||||
|
|
|
@ -66,12 +66,13 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
|
|||
if (type_id == kTypeUnknown) {
|
||||
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
|
||||
}
|
||||
tensor::TensorPtr tensor;
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
std::vector<int> temp_shape;
|
||||
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
|
||||
temp_shape.emplace_back(1);
|
||||
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
tensor->set_device_address(address);
|
||||
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
||||
tensor->set_dirty(false);
|
||||
return tensor;
|
||||
}
|
||||
|
@ -86,6 +87,7 @@ tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index
|
|||
graph->AddInternalOutputTensor(node, output_index, tensor);
|
||||
}
|
||||
}
|
||||
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
|
||||
// if in paynative mode,data only copyed to host when user want to print data
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -240,6 +242,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
|
|||
} else {
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
|
||||
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
|
||||
// construct abstract of parameter
|
||||
|
|
|
@ -399,7 +399,7 @@ std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
|
||||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo4dByDefault(shape);
|
||||
}
|
||||
|
|
|
@ -51,8 +51,7 @@ size_t TypeIdSize(const TypeId data_type);
|
|||
size_t ShapeSize(const std::vector<size_t> &shape);
|
||||
size_t CubeSizeByType(const TypeId data_type);
|
||||
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
|
||||
const std::vector<kernel::Axis> &padding_axis = {});
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
|
||||
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
|
||||
bool IsNeedPadding(const std::string &format, const size_t shape_size);
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
|
||||
|
|
|
@ -20,6 +20,12 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
enum Axis : int {
|
||||
N = 0,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
};
|
||||
// Interface for device kernel program information.
|
||||
class KernelInfoDevice {
|
||||
public:
|
||||
|
|
|
@ -384,7 +384,8 @@ Tensor::Tensor(const Tensor &tensor)
|
|||
data_(tensor.data_),
|
||||
dirty_(tensor.dirty_),
|
||||
id_(tensor.id_),
|
||||
device_sync_(tensor.device_sync_) {}
|
||||
device_sync_(tensor.device_sync_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
|
||||
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
||||
: MetaTensor(data_type, tensor.shape_),
|
||||
|
@ -392,7 +393,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
|||
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
|
||||
dirty_(tensor.dirty_),
|
||||
id_(tensor.id_),
|
||||
device_sync_(tensor.device_sync_) {}
|
||||
device_sync_(tensor.device_sync_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
|
||||
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
|
||||
|
@ -441,6 +443,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
|||
device_sync_ = tensor.device_sync_;
|
||||
data_ = tensor.data_;
|
||||
id_ = tensor.id_;
|
||||
padding_type_ = tensor.padding_type_;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
|
|
@ -221,6 +221,8 @@ class Tensor : public MetaTensor {
|
|||
|
||||
DeviceSyncPtr device_address() const { return device_sync_; }
|
||||
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
|
||||
void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
|
||||
std::vector<Axis> padding_type() const { return padding_type_; }
|
||||
|
||||
std::string id() const { return id_; }
|
||||
|
||||
|
@ -230,6 +232,7 @@ class Tensor : public MetaTensor {
|
|||
bool dirty_{true};
|
||||
std::string id_{""};
|
||||
DeviceSyncPtr device_sync_{nullptr};
|
||||
std::vector<Axis> padding_type_;
|
||||
};
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||
|
|
|
@ -54,10 +54,10 @@ class A {
|
|||
std::shared_ptr<int> i;
|
||||
};
|
||||
|
||||
class C : public A {
|
||||
class Ca : public A {
|
||||
public:
|
||||
C() {}
|
||||
explicit C(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); }
|
||||
Ca() {}
|
||||
explicit Ca(signals *sigs) : A(sigs) { printf("conn C:%p\n", this); }
|
||||
void FuncA(int v1, float v2, std::string str) { printf("C: --%d--%f--%s--\n", v1, v2, str.c_str()); }
|
||||
};
|
||||
|
||||
|
@ -71,13 +71,13 @@ class B : public A {
|
|||
TEST_F(TestSignal, test_common) {
|
||||
A objA;
|
||||
B objB;
|
||||
C objC;
|
||||
Ca objC;
|
||||
|
||||
Signal<void(int, float, std::string)> signal;
|
||||
|
||||
signal.connect(&objA, &A::FuncA);
|
||||
signal.connect(&objB, &B::FuncA);
|
||||
signal.connect(&objC, &C::FuncA);
|
||||
signal.connect(&objC, &Ca::FuncA);
|
||||
signal(20, 20, "Signal-Slot test");
|
||||
}
|
||||
|
||||
|
@ -85,11 +85,11 @@ TEST_F(TestSignal, test_sigs) {
|
|||
signals sigs;
|
||||
A objA(&sigs);
|
||||
B objB(&sigs);
|
||||
C objC(&sigs);
|
||||
Ca objC(&sigs);
|
||||
|
||||
sigs.signal.connect(&objA, &A::FuncA);
|
||||
sigs.signal.connect(&objB, &B::FuncA);
|
||||
sigs.signal.connect(&objC, &C::FuncA);
|
||||
sigs.signal.connect(&objC, &Ca::FuncA);
|
||||
sigs.signal(20, 20, "sigs Signal-Slot test");
|
||||
}
|
||||
|
||||
|
@ -97,7 +97,7 @@ TEST_F(TestSignal, test_sigs_Named) {
|
|||
signals sigs;
|
||||
A objA(&sigs);
|
||||
B objB(&sigs);
|
||||
C objC(&sigs);
|
||||
Ca objC(&sigs);
|
||||
|
||||
sigs.signal(10, 20, "Signal-Slot test");
|
||||
std::shared_ptr<Named> a;
|
||||
|
|
Loading…
Reference in New Issue