forked from mindspore-Ecosystem/mindspore
!27606 acc infershape in runtime
Merge pull request !27606 from fangzehua/acc_infer
commit 481e87d143
@@ -737,12 +737,33 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node
   if (!anf_node->isa<CNode>()) {
     MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(anf_node);
   }
-  if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
-    return VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
+  auto kernel_info = anf_node->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      auto output = runtime_cache->get_prev_node_output(input_idx);
+      if (output.first != nullptr) {
+        return output;
+      }
+    }
   }
-  auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
-  MS_EXCEPTION_IF_NULL(input_node);
-  return VisitKernelWithReturnType(input_node, 0, skip_nop_node);
+  KernelWithIndex res;
+  if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
+    res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
+  } else {
+    auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
+    MS_EXCEPTION_IF_NULL(input_node);
+    res = VisitKernelWithReturnType(input_node, 0, skip_nop_node);
+  }
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_prev_node_output(input_idx, res);
+    }
+  }
+  return res;
 }
 
 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
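Note: this hunk establishes the lookup-then-populate idiom the whole PR repeats: consult the node's RuntimeCache only while it is marked valid, fall back to the real graph traversal on a miss, then store the result. Below is a minimal self-contained sketch of the same memoization pattern; Cache, Node and compute are hypothetical stand-ins, not the MindSpore API.

#include <cstddef>
#include <map>
#include <utility>

// Hypothetical stand-ins for AnfNodePtr and the expensive traversal.
using Node = void *;
using KernelWithIndex = std::pair<Node, size_t>;

class Cache {
 public:
  bool is_valid() const { return is_valid_; }
  void set_valid() { is_valid_ = true; }
  KernelWithIndex get(size_t idx) const {
    auto it = map_.find(idx);
    return it == map_.end() ? KernelWithIndex() : it->second;
  }
  void set(size_t idx, KernelWithIndex v) { (void)map_.insert({idx, v}); }

 private:
  bool is_valid_{false};
  std::map<size_t, KernelWithIndex> map_;
};

KernelWithIndex GetPrevOutput(Cache *cache, size_t idx,
                              KernelWithIndex (*compute)(size_t)) {
  // Fast path: only trust the cache after the graph is frozen (set_valid()).
  if (cache != nullptr && cache->is_valid()) {
    auto hit = cache->get(idx);
    if (hit.first != nullptr) {
      return hit;
    }
  }
  auto res = compute(idx);  // Slow path: recompute by graph traversal.
  if (cache != nullptr && cache->is_valid()) {
    cache->set(idx, res);  // Populate so the next query is a hit.
  }
  return res;
}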
@@ -2192,9 +2213,9 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
   for (size_t i = 0; i < input_size; ++i) {
     auto input_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
     auto real_input = input_with_index.first;
+    MS_EXCEPTION_IF_NULL(real_input);
     auto cnode_input = node->input(i + 1);
     MS_EXCEPTION_IF_NULL(cnode_input);
-    MS_EXCEPTION_IF_NULL(real_input);
     if (depend_tensors != nullptr) {
       auto iter_tensor = depend_tensors->find(i);
       if (iter_tensor != depend_tensors->end()) {
@@ -2214,28 +2235,32 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
         }
       }
     }
-    if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
-      auto base_shape = real_input->Shape();
-      if (!base_shape->isa<abstract::TupleShape>()) {
-        MS_LOG(EXCEPTION) << "Node:" << node->DebugString()
-                          << " input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
-                          << trace::DumpSourceLines(real_input);
-      }
-      auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
-      MS_EXCEPTION_IF_NULL(abs);
-      auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
-      auto abs_i = abs->elements()[tuple_get_item_indexk];
-      (void)args_spec_list.emplace_back(abs_i);
-    } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
-      (void)args_spec_list.emplace_back(cnode_input->abstract());
-    } else {
-      (void)args_spec_list.emplace_back(real_input->abstract());
-    }
+    AddArgList(&args_spec_list, cnode_input, real_input, i);
   }
   auto eval_result = opt::CppInferShape(primitive, args_spec_list);
   node->set_abstract(eval_result);
 }
 
+void AnfRuntimeAlgorithm::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
+                                     const AnfNodePtr &real_input, size_t index) {
+  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
+    auto base_shape = real_input->Shape();
+    if (!base_shape->isa<abstract::TupleShape>()) {
+      MS_LOG(EXCEPTION) << "Node input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
+                        << trace::DumpSourceLines(real_input);
+    }
+    auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
+    MS_EXCEPTION_IF_NULL(abs);
+    auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
+    auto abs_i = abs->elements()[tuple_get_item_indexk];
+    (void)args_spec_list->emplace_back(abs_i);
+  } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
+    (void)args_spec_list->emplace_back(cnode_input->abstract());
+  } else {
+    (void)args_spec_list->emplace_back(real_input->abstract());
+  }
+}
+
 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
   auto return_node = root_graph->get_return();
   MS_EXCEPTION_IF_NULL(return_node);
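Note: AddArgList centralizes the argument-abstract selection that was previously inlined in InferShape (and is reused by DynamicKernel below): a tuple_get_item input contributes one element of its producer's tuple abstract, a Reshape contributes its own abstract, and everything else contributes the real input's abstract. A reduced sketch of the tuple branch, with AbstractLike as a hypothetical stand-in for AbstractBasePtr:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Sketch of the selection rule AddArgList centralizes: for a tuple_get_item
// input take the tuple element's abstract, otherwise take the node's own.
struct AbstractLike {
  std::vector<std::shared_ptr<AbstractLike>> elements;  // non-empty => tuple
};

std::shared_ptr<AbstractLike> PickArg(const std::shared_ptr<AbstractLike> &real_input,
                                      bool input_is_tuple_get_item, size_t item_index) {
  if (input_is_tuple_get_item) {
    // Mirrors the TupleShape check: the producer must really be a tuple.
    assert(item_index < real_input->elements.size());
    return real_input->elements[item_index];  // one element of the tuple
  }
  return real_input;
}

int main() {
  auto elem = std::make_shared<AbstractLike>();
  auto tuple = std::make_shared<AbstractLike>();
  tuple->elements.push_back(elem);
  return PickArg(tuple, true, 0) == elem ? 0 : 1;
}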
@@ -298,6 +298,8 @@ class AnfRuntimeAlgorithm {
   static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
   static bool IsNodeDynamicShape(const AnfNodePtr &node);
   static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);
+  static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
+                         const AnfNodePtr &real_input, size_t index);
   static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
   static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
   // Find real input nodes.
@@ -1001,20 +1001,21 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
-  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
-                                                                    {kOpFormat_NHWC, NhwcDeviceShape},
-                                                                    {kOpFormat_HWCN, HwchDeviceShape},
-                                                                    {kOpFormat_FRAC_Z, FracZDeviceShape},
-                                                                    {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
-                                                                    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
-                                                                    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
-                                                                    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
-                                                                    {kOpFormat_NCDHW, NcdhwDeviceShape},
-                                                                    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
-                                                                    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
-                                                                    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
-                                                                    {kOpFormat_FRAC_NZ, FracNZDeviceShape},
-                                                                    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
+  static const std::map<std::string, DeviceShapeTransfer> device_shape_map{
+    {kOpFormat_NCHW, NchwDeviceShape},
+    {kOpFormat_NHWC, NhwcDeviceShape},
+    {kOpFormat_HWCN, HwchDeviceShape},
+    {kOpFormat_FRAC_Z, FracZDeviceShape},
+    {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
+    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
+    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
+    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
+    {kOpFormat_NCDHW, NcdhwDeviceShape},
+    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
+    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
+    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
+    {kOpFormat_FRAC_NZ, FracNZDeviceShape},
+    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
 
   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
     return shape;
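Note: the only change in this hunk is `static const` on the function-local table, so the map is built once on first call (thread-safe since C++11) rather than rebuilt on every TransShapeToDevice call. A toy sketch of the pattern, with made-up table contents:

#include <map>
#include <string>
#include <vector>

std::vector<int> TransShape(const std::vector<int> &shape, const std::string &format) {
  using Transfer = std::vector<int> (*)(const std::vector<int> &);
  // Constructed once, on the first call; later calls only pay for the lookup.
  static const std::map<std::string, Transfer> kTable{
    {"identity", [](const std::vector<int> &s) { return s; }},
  };
  auto it = kTable.find(format);
  return it == kTable.end() ? shape : it->second(shape);
}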
@@ -47,61 +47,66 @@ void DynamicKernel::Initialize() {
     return;
   }
   MS_LOG(INFO) << "Have depends";
-  (void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_),
+  (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
                        [](const int64_t &value) { return static_cast<int>(value); });
   MS_LOG(INFO) << "Init End";
 }
 
 int DynamicKernel::GetKernelType() const { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); }
 
-void DynamicKernel::RebuildDependTensor() {
-  depend_tensor_map_.clear();
-  auto cnode = cnode_ptr_.lock();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  for (auto depend : depend_list_) {
-    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend);
-    bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
-    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, skip_nop_node);
-    std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
-    auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
-    auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
-    MS_EXCEPTION_IF_NULL(out_tensor);
-    // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
-    // address size will be wrong in the dynamic shape scenario.
-    out_tensor->set_device_address(output_addr, false);
-    auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
-    if (!ret.second) {
-      MS_LOG(EXCEPTION) << "Insert map failed";
-    }
-  }
-}
-
 void DynamicKernel::InferShape() {
   auto cnode = cnode_ptr_.lock();
   MS_EXCEPTION_IF_NULL(cnode);
   MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
-  InferShapeRecursive();
+  depend_tensor_map_.clear();
+  auto inputs = cnode->inputs();
+  if (inputs.empty()) {
+    MS_LOG(EXCEPTION) << "Invalid inputs";
+  }
-  // rebuild depend tensor map for gpu dynamic memory allocation.
-  RebuildDependTensor();
-  AnfAlgo::InferShape(cnode, &depend_tensor_map_);
-}
-
-void DynamicKernel::InferShapeRecursive() {
-  auto cnode = cnode_ptr_.lock();
-  MS_EXCEPTION_IF_NULL(cnode);
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  AbstractBasePtrList args_spec_list;
+  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
   auto input_size = AnfAlgo::GetInputTensorNum(cnode);
   for (size_t i = 0; i < input_size; i++) {
     auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
-    auto input_node = input_node_with_index.first;
-    MS_EXCEPTION_IF_NULL(input_node);
-    InferShapeForNopNode(&input_node);
+    auto real_input = input_node_with_index.first;
+    MS_EXCEPTION_IF_NULL(real_input);
+    auto cnode_input = cnode->input(i + 1);
+    MS_EXCEPTION_IF_NULL(cnode_input);
+    InferShapeForNopNode(&real_input);
+    if (depend_list_.find(i) != depend_list_.end()) {
+      auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
+      auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
+      std::vector<int64_t> shapes =
+        trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
+      auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
+      auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
+      MS_EXCEPTION_IF_NULL(out_tensor);
+      // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
+      // address size will be wrong in the dynamic shape scenario.
+      out_tensor->set_device_address(output_addr, false);
+      auto ret = depend_tensor_map_.try_emplace(i, out_tensor);
+      if (!ret.second) {
+        MS_LOG(EXCEPTION) << "Insert map failed";
+      }
+      out_tensor->data_sync();
+      auto real_abs = real_input->abstract();
+      if (real_abs->isa<abstract::AbstractTensor>()) {
+        real_input->abstract()->set_value(out_tensor);
+      } else if (real_abs->isa<abstract::AbstractTuple>()) {
+        auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
+        auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
+        MS_EXCEPTION_IF_NULL(abstract_tuple);
+        auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
+        tuple_elements->set_value(out_tensor);
+      }
+    }
+    AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
   }
+  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
+  cnode->set_abstract(eval_result);
 }
 
 void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) {
@@ -21,6 +21,7 @@
 #include <string>
 #include <vector>
 #include <map>
+#include <set>
 #include "ir/anf.h"
 #include "ir/tensor.h"
 #include "abstract/primitive_infer_map.h"
@@ -46,8 +47,6 @@ class DynamicKernel {
   [[nodiscard]] int GetKernelType() const;
 
  protected:
-  void RebuildDependTensor();
-  void InferShapeRecursive();
   static void InferShapeForNopNode(AnfNodePtr *input_node);
 
   void *stream_;
@@ -55,7 +54,7 @@ class DynamicKernel {
   bool is_dynamic_shape_;
   bool is_input_dynamic_shape_;
   bool is_output_dynamic_shape_;
-  std::vector<uint32_t> depend_list_;
+  std::set<uint32_t> depend_list_;
   std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
 };
 using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;
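Note: switching depend_list_ from std::vector<uint32_t> to std::set<uint32_t> is what makes the `depend_list_.find(i)` membership test in the new InferShape cheap (O(log n)), deduplicates depend indices, and forces std::inserter in Initialize, since a set has no push_back for std::back_inserter to call. Self-contained sketch of both points:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>
#include <vector>

int main() {
  std::vector<int64_t> ret{2, 1, 2};  // raw depend indices, possibly duplicated
  std::set<uint32_t> depend_list;
  // std::set has no push_back, so use std::inserter rather than std::back_inserter.
  (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list, depend_list.begin()),
                       [](const int64_t &value) { return static_cast<int>(value); });
  // Membership test as used by InferShape; duplicates were collapsed to {1, 2}.
  bool depends_on_1 = depend_list.find(1) != depend_list.end();
  return depends_on_1 ? 0 : 1;
}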
@@ -479,6 +479,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
 #endif
 
+  device_context->EnableRuntimeCache(graph);
   session_->DumpGraph(graph);
   return graph->graph_id();
 }
@@ -156,6 +156,20 @@ class DeviceContext {
   // Dump all graphs.
   virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
 
+  void EnableRuntimeCache(const KernelGraphPtr &graph) const {
+    auto node_list = graph->TopoSort(graph->get_return());
+    for (auto &node : node_list) {
+      auto kernel_info = node->kernel_info();
+      if (!kernel_info) {
+        continue;
+      }
+      MS_EXCEPTION_IF_NULL(kernel_info);
+      auto runtime_cache = kernel_info->runtime_cache();
+      MS_EXCEPTION_IF_NULL(runtime_cache);
+      runtime_cache->set_valid();
+    }
+  }
+
  protected:
   DeviceContextKey device_context_key_;
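Note: EnableRuntimeCache is the freeze point for all of the caching above. After the graph is compiled and its topology is final, one TopoSort walk marks every node's RuntimeCache valid; only then do the is_valid() gates in GetPrevNodeOutput, GetCNodeTarget and GetOutputTensorNum start trusting cached values. A reduced, runnable sketch of that lifecycle with hypothetical node types:

#include <vector>

struct RuntimeCacheSketch {
  bool valid{false};
  void set_valid() { valid = true; }
};

struct NodeSketch {
  RuntimeCacheSketch cache;
};

// While the graph is still being rewritten the caches stay invalid and every
// query recomputes; flipping them to valid afterwards makes cached reads safe.
void EnableRuntimeCache(std::vector<NodeSketch> *topo_sorted_nodes) {
  for (auto &node : *topo_sorted_nodes) {
    node.cache.set_valid();
  }
}

int main() {
  std::vector<NodeSketch> graph(3);
  EnableRuntimeCache(&graph);  // call once, after compilation finishes
  return graph[0].cache.valid ? 0 : 1;
}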
@@ -370,7 +370,7 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
   MS_EXCEPTION_IF_NULL(ms_context);
   bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
   bool is_pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
-  std::vector<int64_t> dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
+  auto dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
   if ((is_pynative_infer || is_pynative_mode) && dynamic_shape_depends.empty()) {
     return;
   }
@@ -19,6 +19,7 @@
 #include "abstract/primitive_infer_map.h"
 #include <string>
 #include <vector>
+#include <set>
 #include "ops/exp.h"
 #include "ops/log.h"
 #include "ops/reciprocal.h"
@@ -39,9 +40,9 @@
 #include "ops/grad/slice_grad.h"
 namespace mindspore {
 namespace abstract {
-std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
-  using ShapeVec = std::vector<int64_t>;
-  using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeVec>;
+std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
+  using ShapeSet = std::set<int64_t>;
+  using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeSet>;
   static const auto &kOneHot = prim::kPrimOneHot->name();
   static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name();
   static const auto &kTranspose = prim::kPrimTranspose->name();
@@ -63,24 +64,24 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
   static const auto &kReshape = prim::kPrimReshape->name();
   static const auto &kDynamicReshape = prim::kPrimDynamicReshape->name();
   // Common dynamic shape depends.
-  static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeVec{2}},
-                                                        {kUnsortedSegmentMin, ShapeVec{2}},
-                                                        {kUnsortedSegmentMax, ShapeVec{2}},
-                                                        {kGather, ShapeVec{2}},
-                                                        {kGatherV2, ShapeVec{2}},
-                                                        {kRange, ShapeVec{0, 1, 2}},
-                                                        {kConv2DBackpropFilter, ShapeVec{2}},
-                                                        {kConv2DBackpropInput, ShapeVec{2}},
-                                                        {kOneHot, ShapeVec{1, 3}},
-                                                        {kDropoutGenMask, ShapeVec{0}},
-                                                        {kStridedSlice, ShapeVec{1, 2, 3}},
-                                                        {kStridedSliceGrad, ShapeVec{1, 2, 3, 4}},
-                                                        {kTile, ShapeVec{1}},
-                                                        {kReshape, ShapeVec{1}},
-                                                        {kDynamicReshape, ShapeVec{1}},
-                                                        {kSlice, ShapeVec{1, 2}},
-                                                        {kSliceGrad, ShapeVec{2, 3}},
-                                                        {kDynamicBroadcastTo, ShapeVec{1}}};
+  static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
+                                                        {kUnsortedSegmentMin, ShapeSet{2}},
+                                                        {kUnsortedSegmentMax, ShapeSet{2}},
+                                                        {kGather, ShapeSet{2}},
+                                                        {kGatherV2, ShapeSet{2}},
+                                                        {kRange, ShapeSet{0, 1, 2}},
+                                                        {kConv2DBackpropFilter, ShapeSet{2}},
+                                                        {kConv2DBackpropInput, ShapeSet{2}},
+                                                        {kOneHot, ShapeSet{1, 3}},
+                                                        {kDropoutGenMask, ShapeSet{0}},
+                                                        {kStridedSlice, ShapeSet{1, 2, 3}},
+                                                        {kStridedSliceGrad, ShapeSet{1, 2, 3, 4}},
+                                                        {kTile, ShapeSet{1}},
+                                                        {kReshape, ShapeSet{1}},
+                                                        {kDynamicReshape, ShapeSet{1}},
+                                                        {kSlice, ShapeSet{1, 2}},
+                                                        {kSliceGrad, ShapeSet{2, 3}},
+                                                        {kDynamicBroadcastTo, ShapeSet{1}}};
 
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->inputs().empty()) {
@@ -101,9 +102,9 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
   auto iter = dynamic_shape_depends.find(prim_name);
   if (iter != dynamic_shape_depends.end()) {
     int64_t cnode_input_size = SizeToLong(cnode->inputs().size());
-    std::vector<int64_t> res;
+    ShapeSet res;
     auto ori = iter->second;
-    (void)std::copy_if(ori.begin(), ori.end(), std::back_inserter(res),
+    (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
                        [&](auto idx) { return idx < cnode_input_size - 1; });
     return res;
   }
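Note: because the depend container is now a set, the filtered copy must use std::inserter; the predicate drops registered depend indices that exceed the node's actual input count (inputs minus the leading primitive). A self-contained sketch:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>

int main() {
  std::set<int64_t> ori{1, 2, 3, 4};  // depend indices registered for the op
  int64_t cnode_input_size = 4;       // primitive + 3 real inputs
  std::set<int64_t> res;
  // Keep only indices that name an existing real input (idx < inputs - 1).
  (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
                     [&](auto idx) { return idx < cnode_input_size - 1; });
  return static_cast<int>(res.size());  // res == {1, 2}
}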
@@ -19,6 +19,7 @@
 #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
 
 #include <vector>
+#include <set>
 #include <memory>
 #include "utils/hash_map.h"
 #include "ir/primitive.h"
@@ -50,7 +51,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
 
 StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
 
-std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
+std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode);
 
 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
 
@@ -622,14 +622,36 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
 }
 
 std::string GetCNodeTarget(const AnfNodePtr &node) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  auto target = GetOriginNodeTarget(node);
-  if (target != kTargetUnDefined) {
-    return target;
-  }
-  return default_target;
+  auto kernel_info = node->kernel_info();
+  if (kernel_info != nullptr) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      auto tmp_target = runtime_cache->device_target();
+      if (!tmp_target.empty()) {
+        return tmp_target;
+      }
+    }
+  }
+
+  std::string target;
+  auto ori_target = GetOriginNodeTarget(node);
+  if (ori_target != kTargetUnDefined) {
+    target = ori_target;
+  } else {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  }
+
+  if (kernel_info != nullptr) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_device_target(target);
+    }
+  }
+  return target;
 }
 
 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
@@ -18,6 +18,10 @@
 #define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
 
 #include <memory>
+#include <map>
+#include <utility>
+#include <string>
+#include "utils/info.h"
 
 namespace mindspore {
 enum Axis : int {
@@ -26,11 +30,51 @@ enum Axis : int {
   H,
   W,
 };
 
+// Cache some runtime information which not be changed.
+class RuntimeCache {
+ public:
+  std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) {
+    auto it = prev_node_output_map_.find(index);
+    if (it != prev_node_output_map_.end()) {
+      return it->second;
+    } else {
+      return std::pair<AnfNodePtr, size_t>();
+    }
+  }
+
+  void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) {
+    auto pr = std::make_pair(index, output);
+    (void)prev_node_output_map_.insert(pr);
+  }
+
+  std::string device_target() { return device_target_; }
+
+  void set_device_target(const std::string &target) { device_target_ = target; }
+  bool is_valid() { return is_valid_; }
+  void set_valid() { is_valid_ = true; }
+  void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; }
+  ssize_t output_tensor_num() const { return output_tensor_num_; }
+  void set_real_kernel(enum CacheBool b) { is_real_kernel_ = b; }
+  enum CacheBool is_real_kernel() { return is_real_kernel_; }
+
+ private:
+  bool is_valid_{false};
+  std::map<size_t, std::pair<AnfNodePtr, size_t>> prev_node_output_map_;
+  std::string device_target_;
+  ssize_t output_tensor_num_ = -1;
+  enum CacheBool is_real_kernel_ = CacheBool::UNCACHED;
+};
 // Interface for device kernel program information.
 class KernelInfoDevice {
  public:
   // If kernel program was built and build info is set.
   virtual bool has_build_info() const = 0;
 
+  RuntimeCache *runtime_cache() { return &runtime_cache_; }
+
+ private:
+  RuntimeCache runtime_cache_;
 };
 using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
 }  // namespace mindspore
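Note: RuntimeCache is a per-node memo pad hung off KernelInfoDevice: one validity flag plus independently cached fields (previous-output map, device target, output count, real-kernel flag). A default-constructed pair doubles as the miss value in get_prev_node_output, which is why callers test output.first against nullptr. A condensed, runnable copy of that part, with AnfNodePtr replaced by a plain pointer so the sketch stands alone:

#include <cstddef>
#include <map>
#include <utility>

using AnfNodeLike = const char *;  // stand-in for AnfNodePtr

class RuntimeCacheSketch {
 public:
  std::pair<AnfNodeLike, size_t> get_prev_node_output(size_t index) {
    auto it = prev_node_output_map_.find(index);
    // A default-constructed pair ({nullptr, 0}) doubles as the "miss" value,
    // which is why callers check output.first != nullptr.
    return it == prev_node_output_map_.end() ? std::pair<AnfNodeLike, size_t>() : it->second;
  }
  void set_prev_node_output(size_t index, std::pair<AnfNodeLike, size_t> output) {
    // std::map::insert keeps the first stored value; entries are never overwritten.
    (void)prev_node_output_map_.insert({index, output});
  }

 private:
  std::map<size_t, std::pair<AnfNodeLike, size_t>> prev_node_output_map_;
};

int main() {
  RuntimeCacheSketch cache;
  cache.set_prev_node_output(0, {"conv2d", 1});
  auto hit = cache.get_prev_node_output(0);   // {"conv2d", 1}
  auto miss = cache.get_prev_node_output(5);  // {nullptr, 0}
  return (hit.first != nullptr && miss.first == nullptr) ? 0 : 1;
}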
@@ -108,7 +108,28 @@ bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
   if (cnode->size() == 0) {
     MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString() << trace::DumpSourceLines(node);
   }
-  return !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
+
+  auto kernel_info = cnode->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_real_kernel() != CacheBool::UNCACHED) {
+      return (runtime_cache->is_real_kernel() == CacheBool::TRUE);
+    }
+  }
+  bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
+
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (res) {
+      runtime_cache->set_real_kernel(CacheBool::TRUE);
+    } else {
+      runtime_cache->set_real_kernel(CacheBool::FALSE);
+    }
+  }
+
+  return res;
 }
 
 bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
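Note: a cached boolean needs a third state to tell "never computed" apart from a cached false; CacheBool's UNCACHED = -1 (defined in info.h at the end of this diff) provides it. A reduced sketch of the tri-state memoization IsRealKernel now performs, with ComputeIsReal as a hypothetical stand-in for the virtual-prim check:

typedef enum CacheBool { UNCACHED = -1, FALSE, TRUE } CacheBool;

// Hypothetical expensive predicate standing in for IsOneOfPrimitive(...).
static bool ComputeIsReal() { return true; }

bool IsRealCached(CacheBool *slot) {
  if (*slot != CacheBool::UNCACHED) {
    return *slot == CacheBool::TRUE;  // cached answer, no recomputation
  }
  bool res = ComputeIsReal();
  *slot = res ? CacheBool::TRUE : CacheBool::FALSE;
  return res;
}

int main() {
  CacheBool slot = CacheBool::UNCACHED;
  bool first = IsRealCached(&slot);   // computes and stores
  bool second = IsRealCached(&slot);  // served from the cache
  return (first == second) ? 0 : 1;
}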
@@ -183,19 +204,38 @@ size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
 
 size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    if (runtime_cache->is_valid()) {
+      ssize_t output_tensor_num = runtime_cache->output_tensor_num();
+      if (output_tensor_num >= 0) {
+        return static_cast<size_t>(output_tensor_num);
+      }
+    }
+  }
+
+  size_t res;
   TypePtr type = node->Type();
   if (type == nullptr) {
-    return 0;
-  }
-  if (type->isa<Tuple>()) {
+    res = 0;
+  } else if (type->isa<Tuple>()) {
     auto tuple_type = type->cast<TuplePtr>();
     MS_EXCEPTION_IF_NULL(tuple_type);
-    return tuple_type->size();
+    res = tuple_type->size();
+  } else if (type->isa<TypeNone>()) {
+    res = 0;
+  } else {
+    res = 1;
   }
-  if (type->isa<TypeNone>()) {
-    return 0;
-  }
-  return 1;
+
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_output_tensor_num(static_cast<ssize_t>(res));
+    }
+  }
+  return res;
 }
 
 std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
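Note: GetOutputTensorNum applies the same recipe to a count: output_tensor_num_ is signed and starts at -1, so "unset" stays distinguishable from a legitimate cached count of 0 (for example a TypeNone output). Sketch, using long in place of the POSIX ssize_t:

// -1 == not cached yet; 0 is a valid cached count.
static long cached_num = -1;

static unsigned long ComputeOutputNum() { return 0; }  // hypothetical, e.g. TypeNone -> 0

unsigned long GetOutputNum() {
  if (cached_num >= 0) {
    return static_cast<unsigned long>(cached_num);  // hit, even when the answer is 0
  }
  unsigned long res = ComputeOutputNum();
  cached_num = static_cast<long>(res);
  return res;
}

int main() { return static_cast<int>(GetOutputNum()); }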
@@ -29,7 +29,7 @@
 
 namespace mindspore {
 enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };
-
+typedef enum CacheBool { UNCACHED = -1, FALSE, TRUE } CacheBool;
 // Location class record the location in source code.
 class Location {
  public: