!27606 acc infershape in runtime

Merge pull request !27606 from fangzehua/acc_infer
i-robot 2021-12-16 02:05:19 +00:00 committed by Gitee
commit 481e87d143
14 changed files with 273 additions and 118 deletions

View File

@@ -737,12 +737,33 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << " is not a CNode." << trace::DumpSourceLines(anf_node);
}
if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
return VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
auto kernel_info = anf_node->kernel_info();
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (runtime_cache->is_valid()) {
auto output = runtime_cache->get_prev_node_output(input_idx);
if (output.first != nullptr) {
return output;
}
}
}
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
return VisitKernelWithReturnType(input_node, 0, skip_nop_node);
KernelWithIndex res;
if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
} else {
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
res = VisitKernelWithReturnType(input_node, 0, skip_nop_node);
}
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (runtime_cache->is_valid()) {
runtime_cache->set_prev_node_output(input_idx, res);
}
}
return res;
}
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
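Reviewer note: the hunk above turns GetPrevNodeOutput into a read-through/write-through memoization keyed by input index, where a default-constructed pair (first == nullptr) means "not cached". A minimal self-contained sketch of the same idiom, with hypothetical toy names rather than the MindSpore types:

#include <cstddef>
#include <iostream>
#include <map>
#include <utility>

// Toy stand-ins for KernelWithIndex and the runtime cache.
using Result = std::pair<int, size_t>;

struct CacheSketch {
  bool valid = false;
  std::map<size_t, Result> prev_output;
};

Result ExpensiveResolve(size_t idx) {
  std::cout << "resolving input " << idx << "\n";  // visible only on a miss
  return {42, idx};
}

Result GetPrevOutputCached(CacheSketch *cache, size_t idx) {
  if (cache != nullptr && cache->valid) {
    auto it = cache->prev_output.find(idx);
    if (it != cache->prev_output.end()) {
      return it->second;  // cache hit: skip the graph traversal entirely
    }
  }
  Result res = ExpensiveResolve(idx);  // cache miss: do the real work
  if (cache != nullptr && cache->valid) {
    (void)cache->prev_output.emplace(idx, res);  // populate for later calls
  }
  return res;
}

int main() {
  CacheSketch cache;
  cache.valid = true;                    // corresponds to set_valid() after compile
  (void)GetPrevOutputCached(&cache, 0);  // prints "resolving input 0"
  (void)GetPrevOutputCached(&cache, 0);  // second call is served from the map
  return 0;
}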
@@ -2192,9 +2213,9 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
for (size_t i = 0; i < input_size; ++i) {
auto input_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
auto real_input = input_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
auto cnode_input = node->input(i + 1);
MS_EXCEPTION_IF_NULL(cnode_input);
MS_EXCEPTION_IF_NULL(real_input);
if (depend_tensors != nullptr) {
auto iter_tensor = depend_tensors->find(i);
if (iter_tensor != depend_tensors->end()) {
@@ -2214,28 +2235,32 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, te
}
}
}
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
auto base_shape = real_input->Shape();
if (!base_shape->isa<abstract::TupleShape>()) {
MS_LOG(EXCEPTION) << "Node:" << node->DebugString()
<< " input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
<< trace::DumpSourceLines(real_input);
}
auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs);
auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
auto abs_i = abs->elements()[tuple_get_item_index];
(void)args_spec_list.emplace_back(abs_i);
} else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
(void)args_spec_list.emplace_back(cnode_input->abstract());
} else {
(void)args_spec_list.emplace_back(real_input->abstract());
}
AddArgList(&args_spec_list, cnode_input, real_input, i);
}
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result);
}
void AnfRuntimeAlgorithm::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index) {
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
auto base_shape = real_input->Shape();
if (!base_shape->isa<abstract::TupleShape>()) {
MS_LOG(EXCEPTION) << "Node input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
<< trace::DumpSourceLines(real_input);
}
auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs);
auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
auto abs_i = abs->elements()[tuple_get_item_index];
(void)args_spec_list->emplace_back(abs_i);
} else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
(void)args_spec_list->emplace_back(cnode_input->abstract());
} else {
(void)args_spec_list->emplace_back(real_input->abstract());
}
}
void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
auto return_node = root_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);

View File

@@ -298,6 +298,8 @@ class AnfRuntimeAlgorithm {
static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
static bool IsNodeDynamicShape(const AnfNodePtr &node);
static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);
static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
const AnfNodePtr &real_input, size_t index);
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
// Find real input nodes.

View File

@@ -1001,20 +1001,21 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape},
{kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
{kOpFormat_NCDHW, NcdhwDeviceShape},
{kOpFormat_ChannelLast, ChannelLastDeviceShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
{kOpFormat_FRAC_NZ, FracNZDeviceShape},
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
static const std::map<std::string, DeviceShapeTransfer> device_shape_map{
{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape},
{kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
{kOpFormat_NCDHW, NcdhwDeviceShape},
{kOpFormat_ChannelLast, ChannelLastDeviceShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
{kOpFormat_FRAC_NZ, FracNZDeviceShape},
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
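Reviewer note: the only change in this hunk is declaring the format-to-transfer map `static const`, so it becomes a function-local static constructed once on first call instead of being rebuilt on every TransShapeToDevice invocation. A tiny standalone sketch of the difference (hypothetical names):

#include <functional>
#include <iostream>
#include <map>
#include <string>

std::string Transform(const std::string &format) {
  // Function-local static: the initializer runs exactly once, on first call.
  static const std::map<std::string, std::function<std::string()>> table = [] {
    std::cout << "building table\n";
    return std::map<std::string, std::function<std::string()>>{
        {"NCHW", [] { return std::string("nchw-device"); }},
        {"NHWC", [] { return std::string("nhwc-device"); }}};
  }();
  auto it = table.find(format);
  return it == table.end() ? std::string("unknown") : it->second();
}

int main() {
  (void)Transform("NCHW");  // prints "building table", then looks up
  (void)Transform("NHWC");  // reuses the already-built table
  return 0;
}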

View File

@@ -47,61 +47,66 @@ void DynamicKernel::Initialize() {
return;
}
MS_LOG(INFO) << "Have depends";
(void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_),
(void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
[](const int64_t &value) { return static_cast<int>(value); });
MS_LOG(INFO) << "Init End";
}
int DynamicKernel::GetKernelType() const { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); }
void DynamicKernel::RebuildDependTensor() {
depend_tensor_map_.clear();
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
for (auto depend : depend_list_) {
auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend);
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, skip_nop_node);
std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
MS_EXCEPTION_IF_NULL(out_tensor);
// The second parameter must be false; otherwise the device address cannot be released and re-allocated, and
// the address size will be wrong in the dynamic shape scenario.
out_tensor->set_device_address(output_addr, false);
auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
if (!ret.second) {
MS_LOG(EXCEPTION) << "Insert map failed";
}
}
}
void DynamicKernel::InferShape() {
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
InferShapeRecursive();
depend_tensor_map_.clear();
auto inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";
}
// Rebuild the depend tensor map for GPU dynamic memory allocation.
RebuildDependTensor();
AnfAlgo::InferShape(cnode, &depend_tensor_map_);
}
void DynamicKernel::InferShapeRecursive() {
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
AbstractBasePtrList args_spec_list;
auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
auto input_node = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
InferShapeForNopNode(&input_node);
auto real_input = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
auto cnode_input = cnode->input(i + 1);
MS_EXCEPTION_IF_NULL(cnode_input);
InferShapeForNopNode(&real_input);
if (depend_list_.find(i) != depend_list_.end()) {
auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
std::vector<int64_t> shapes =
trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
MS_EXCEPTION_IF_NULL(out_tensor);
// The second parameter must be false; otherwise the device address cannot be released and re-allocated, and
// the address size will be wrong in the dynamic shape scenario.
out_tensor->set_device_address(output_addr, false);
auto ret = depend_tensor_map_.try_emplace(i, out_tensor);
if (!ret.second) {
MS_LOG(EXCEPTION) << "Insert map failed";
}
out_tensor->data_sync();
auto real_abs = real_input->abstract();
if (real_abs->isa<abstract::AbstractTensor>()) {
real_input->abstract()->set_value(out_tensor);
} else if (real_abs->isa<abstract::AbstractTuple>()) {
auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
tuple_elements->set_value(out_tensor);
}
}
AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
}
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode->set_abstract(eval_result);
}
void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) {
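Reviewer note: two container details in this file are worth calling out. depend_list_ changes from std::vector to std::set, so std::back_inserter (which requires push_back) is replaced by std::inserter, duplicates collapse, and the membership test depend_list_.find(i) in InferShape becomes a logarithmic lookup. A standalone sketch with toy data:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

int main() {
  std::vector<int64_t> ret = {2, 1, 2, 3};  // raw depend indices, may repeat
  std::set<uint32_t> depend_list;
  // std::set has no push_back, so std::back_inserter would not compile here;
  // std::inserter calls insert(hint, value) instead.
  (void)std::transform(ret.begin(), ret.end(),
                       std::inserter(depend_list, depend_list.begin()),
                       [](const int64_t &value) { return static_cast<uint32_t>(value); });
  std::cout << depend_list.size() << "\n";                          // 3: duplicates collapsed
  std::cout << (depend_list.find(2) != depend_list.end()) << "\n";  // 1: O(log n) lookup
  return 0;
}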

View File

@@ -21,6 +21,7 @@
#include <string>
#include <vector>
#include <map>
#include <set>
#include "ir/anf.h"
#include "ir/tensor.h"
#include "abstract/primitive_infer_map.h"
@@ -46,8 +47,6 @@ class DynamicKernel {
[[nodiscard]] int GetKernelType() const;
protected:
void RebuildDependTensor();
void InferShapeRecursive();
static void InferShapeForNopNode(AnfNodePtr *input_node);
void *stream_;
@@ -55,7 +54,7 @@
bool is_dynamic_shape_;
bool is_input_dynamic_shape_;
bool is_output_dynamic_shape_;
std::vector<uint32_t> depend_list_;
std::set<uint32_t> depend_list_;
std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
};
using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;

View File

@@ -479,6 +479,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
(void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
#endif
device_context->EnableRuntimeCache(graph);
session_->DumpGraph(graph);
return graph->graph_id();
}

View File

@@ -156,6 +156,20 @@ class DeviceContext {
// Dump all graphs.
virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
void EnableRuntimeCache(const KernelGraphPtr &graph) const {
auto node_list = graph->TopoSort(graph->get_return());
for (auto &node : node_list) {
auto kernel_info = node->kernel_info();
if (!kernel_info) {
continue;
}
MS_EXCEPTION_IF_NULL(kernel_info);
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
runtime_cache->set_valid();
}
}
protected:
DeviceContextKey device_context_key_;
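Reviewer note: together with the GraphCompiler hunk above, this gives the cache a clear life cycle — every node's runtime cache starts invalid and is flipped valid exactly once, after compilation has frozen the graph, so memoized lookups can never go stale. A toy sketch of that enable-after-freeze pattern (ToyNode/ToyCache are hypothetical stand-ins for kernel_info()->runtime_cache()):

#include <cassert>
#include <vector>

struct ToyCache {
  bool valid = false;
};
struct ToyNode {
  ToyCache cache;
};

void EnableRuntimeCacheSketch(std::vector<ToyNode> *graph) {
  for (auto &node : *graph) {
    node.cache.valid = true;  // from here on, memoized lookups are trusted
  }
}

int main() {
  std::vector<ToyNode> graph(3);
  // While the graph is still being compiled/rewritten, every cache is
  // invalid, so all lookups fall through to the slow path.
  EnableRuntimeCacheSketch(&graph);  // freeze point (see CompileGraphImpl above)
  assert(graph[0].cache.valid);
  return 0;
}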

View File

@@ -370,7 +370,7 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(ms_context);
bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
bool is_pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
std::vector<int64_t> dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
auto dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
if ((is_pynative_infer || is_pynative_mode) && dynamic_shape_depends.empty()) {
return;
}

View File

@@ -19,6 +19,7 @@
#include "abstract/primitive_infer_map.h"
#include <string>
#include <vector>
#include <set>
#include "ops/exp.h"
#include "ops/log.h"
#include "ops/reciprocal.h"
@@ -39,9 +40,9 @@
#include "ops/grad/slice_grad.h"
namespace mindspore {
namespace abstract {
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
using ShapeVec = std::vector<int64_t>;
using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeVec>;
std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
using ShapeSet = std::set<int64_t>;
using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeSet>;
static const auto &kOneHot = prim::kPrimOneHot->name();
static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name();
static const auto &kTranspose = prim::kPrimTranspose->name();
@@ -63,24 +64,24 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
static const auto &kReshape = prim::kPrimReshape->name();
static const auto &kDynamicReshape = prim::kPrimDynamicReshape->name();
// Common dynamic shape depends.
static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeVec{2}},
{kUnsortedSegmentMin, ShapeVec{2}},
{kUnsortedSegmentMax, ShapeVec{2}},
{kGather, ShapeVec{2}},
{kGatherV2, ShapeVec{2}},
{kRange, ShapeVec{0, 1, 2}},
{kConv2DBackpropFilter, ShapeVec{2}},
{kConv2DBackpropInput, ShapeVec{2}},
{kOneHot, ShapeVec{1, 3}},
{kDropoutGenMask, ShapeVec{0}},
{kStridedSlice, ShapeVec{1, 2, 3}},
{kStridedSliceGrad, ShapeVec{1, 2, 3, 4}},
{kTile, ShapeVec{1}},
{kReshape, ShapeVec{1}},
{kDynamicReshape, ShapeVec{1}},
{kSlice, ShapeVec{1, 2}},
{kSliceGrad, ShapeVec{2, 3}},
{kDynamicBroadcastTo, ShapeVec{1}}};
static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
{kUnsortedSegmentMin, ShapeSet{2}},
{kUnsortedSegmentMax, ShapeSet{2}},
{kGather, ShapeSet{2}},
{kGatherV2, ShapeSet{2}},
{kRange, ShapeSet{0, 1, 2}},
{kConv2DBackpropFilter, ShapeSet{2}},
{kConv2DBackpropInput, ShapeSet{2}},
{kOneHot, ShapeSet{1, 3}},
{kDropoutGenMask, ShapeSet{0}},
{kStridedSlice, ShapeSet{1, 2, 3}},
{kStridedSliceGrad, ShapeSet{1, 2, 3, 4}},
{kTile, ShapeSet{1}},
{kReshape, ShapeSet{1}},
{kDynamicReshape, ShapeSet{1}},
{kSlice, ShapeSet{1, 2}},
{kSliceGrad, ShapeSet{2, 3}},
{kDynamicBroadcastTo, ShapeSet{1}}};
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
@@ -101,9 +102,9 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
auto iter = dynamic_shape_depends.find(prim_name);
if (iter != dynamic_shape_depends.end()) {
int64_t cnode_input_size = SizeToLong(cnode->inputs().size());
std::vector<int64_t> res;
ShapeSet res;
auto ori = iter->second;
(void)std::copy_if(ori.begin(), ori.end(), std::back_inserter(res),
(void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
[&](auto idx) { return idx < cnode_input_size - 1; });
return res;
}
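Reviewer note: the return-type change from std::vector<int64_t> to std::set<int64_t> propagates into the filter step above, which now copies through std::inserter. A self-contained sketch of the lookup-and-filter with toy data (map contents trimmed to two hypothetical entries):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <unordered_map>

int main() {
  using ShapeSet = std::set<int64_t>;
  const std::unordered_map<std::string, ShapeSet> depends{
      {"StridedSlice", ShapeSet{1, 2, 3}}, {"Reshape", ShapeSet{1}}};
  // Input count at this call site: primitive + two real inputs.
  int64_t cnode_input_size = 3;
  auto iter = depends.find("StridedSlice");
  ShapeSet res;
  // Keep only depend indices that actually exist on this node.
  (void)std::copy_if(iter->second.begin(), iter->second.end(),
                     std::inserter(res, res.begin()),
                     [&](auto idx) { return idx < cnode_input_size - 1; });
  for (auto idx : res) {
    std::cout << idx << " ";  // prints "1"
  }
  return 0;
}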

View File

@@ -19,6 +19,7 @@
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#include <vector>
#include <set>
#include <memory>
#include "utils/hash_map.h"
#include "ir/primitive.h"
@@ -50,7 +51,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode);
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);

View File

@@ -622,14 +622,36 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
}
std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
auto target = GetOriginNodeTarget(node);
if (target != kTargetUnDefined) {
return target;
auto kernel_info = node->kernel_info();
if (kernel_info != nullptr) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (runtime_cache->is_valid()) {
auto tmp_target = runtime_cache->device_target();
if (!tmp_target.empty()) {
return tmp_target;
}
}
}
return default_target;
std::string target;
auto ori_target = GetOriginNodeTarget(node);
if (ori_target != kTargetUnDefined) {
target = ori_target;
} else {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
if (kernel_info != nullptr) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (runtime_cache->is_valid()) {
runtime_cache->set_device_target(target);
}
}
return target;
}
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
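Reviewer note: the rewritten GetCNodeTarget uses the same runtime cache, with an empty string doubling as the "not cached" sentinel, so no separate flag is needed. A minimal sketch of that convention (hypothetical names; the real code also handles the kTargetUnDefined fallback and the validity check):

#include <cassert>
#include <string>

std::string cached_target;  // "" means "not cached yet"

std::string GetTargetCached() {
  if (!cached_target.empty()) {
    return cached_target;  // hit: skip target resolution
  }
  std::string target = "GPU";  // stand-in for the origin-target/context lookup
  cached_target = target;      // memoize; later calls short-circuit above
  return target;
}

int main() {
  assert(GetTargetCached() == "GPU");  // miss: resolves and caches
  assert(GetTargetCached() == "GPU");  // hit: returned from cached_target
  return 0;
}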

View File

@@ -18,6 +18,10 @@
#define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
#include <memory>
#include <map>
#include <utility>
#include <string>
#include "utils/info.h"
namespace mindspore {
enum Axis : int {
@@ -26,11 +30,51 @@ enum Axis : int {
H,
W,
};
// Cache some runtime information that will not change.
class RuntimeCache {
public:
std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) {
auto it = prev_node_output_map_.find(index);
if (it != prev_node_output_map_.end()) {
return it->second;
} else {
return std::pair<AnfNodePtr, size_t>();
}
}
void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) {
auto pr = std::make_pair(index, output);
(void)prev_node_output_map_.insert(pr);
}
std::string device_target() { return device_target_; }
void set_device_target(const std::string &target) { device_target_ = target; }
bool is_valid() { return is_valid_; }
void set_valid() { is_valid_ = true; }
void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; }
ssize_t output_tensor_num() const { return output_tensor_num_; }
void set_real_kernel(enum CacheBool b) { is_real_kernel_ = b; }
enum CacheBool is_real_kernel() { return is_real_kernel_; }
private:
bool is_valid_{false};
std::map<size_t, std::pair<AnfNodePtr, size_t>> prev_node_output_map_;
std::string device_target_;
ssize_t output_tensor_num_ = -1;
enum CacheBool is_real_kernel_ = CacheBool::UNCACHED;
};
// Interface for device kernel program information.
class KernelInfoDevice {
public:
// Whether the kernel program has been built and its build info set.
virtual bool has_build_info() const = 0;
RuntimeCache *runtime_cache() { return &runtime_cache_; }
private:
RuntimeCache runtime_cache_;
};
using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
} // namespace mindspore
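Reviewer note: one subtlety of RuntimeCache::get_prev_node_output above is that a miss is reported as a default-constructed pair, so callers test output.first != nullptr rather than carrying a separate "found" flag (exactly what GetPrevNodeOutput does in the first file). A self-contained analog of that convention, with toy types in place of AnfNodePtr:

#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <utility>

// Toy stand-ins; in the real header the mapped type is std::pair<AnfNodePtr, size_t>.
struct Node {};
using NodePtr = std::shared_ptr<Node>;
using KernelWithIndex = std::pair<NodePtr, size_t>;

std::map<size_t, KernelWithIndex> prev_node_output_map;

KernelWithIndex get_prev_node_output(size_t index) {
  auto it = prev_node_output_map.find(index);
  if (it != prev_node_output_map.end()) {
    return it->second;
  }
  return KernelWithIndex();  // {nullptr, 0}: the caller reads this as a miss
}

int main() {
  assert(get_prev_node_output(0).first == nullptr);  // nothing cached yet
  prev_node_output_map[0] = {std::make_shared<Node>(), 1};
  assert(get_prev_node_output(0).first != nullptr);  // now a hit
  return 0;
}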

View File

@@ -108,7 +108,28 @@ bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
if (cnode->size() == 0) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString() << trace::DumpSourceLines(node);
}
return !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
auto kernel_info = cnode->kernel_info();
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (runtime_cache->is_real_kernel() != CacheBool::UNCACHED) {
return (runtime_cache->is_real_kernel() == CacheBool::TRUE);
}
}
bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
MS_EXCEPTION_IF_NULL(runtime_cache);
if (res) {
runtime_cache->set_real_kernel(CacheBool::TRUE);
} else {
runtime_cache->set_real_kernel(CacheBool::FALSE);
}
}
return res;
}
bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
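Reviewer note: IsRealKernel's memoization needs three states, because a plain bool cannot distinguish "cached as false" from "never computed" — that is what CacheBool's UNCACHED/FALSE/TRUE encodes (see the last file). A standalone sketch with hypothetical names (enum class is used here only to avoid the FALSE/TRUE macro clash a toy program might hit):

#include <cassert>

enum class Tri { UNCACHED = -1, FALSE_, TRUE_ };

struct Memo {
  Tri is_real = Tri::UNCACHED;
};

bool IsRealKernelCached(Memo *memo, bool (*compute)()) {
  if (memo != nullptr && memo->is_real != Tri::UNCACHED) {
    return memo->is_real == Tri::TRUE_;  // hit: skip the primitive check
  }
  bool res = compute();  // the expensive virtual-primitive inspection
  if (memo != nullptr) {
    memo->is_real = res ? Tri::TRUE_ : Tri::FALSE_;  // negative results cache too
  }
  return res;
}

int main() {
  Memo memo;
  bool first = IsRealKernelCached(&memo, [] { return false; });
  assert(!first && memo.is_real == Tri::FALSE_);  // "false" is now cached
  bool second = IsRealKernelCached(&memo, [] { return true; });
  assert(!second);  // served from the cache; the new lambda never runs
  return 0;
}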
@@ -183,19 +204,38 @@ size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
if (runtime_cache->is_valid()) {
ssize_t output_tensor_num = runtime_cache->output_tensor_num();
if (output_tensor_num >= 0) {
return static_cast<size_t>(output_tensor_num);
}
}
}
size_t res;
TypePtr type = node->Type();
if (type == nullptr) {
return 0;
}
if (type->isa<Tuple>()) {
res = 0;
} else if (type->isa<Tuple>()) {
auto tuple_type = type->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_type);
return tuple_type->size();
res = tuple_type->size();
} else if (type->isa<TypeNone>()) {
res = 0;
} else {
res = 1;
}
if (type->isa<TypeNone>()) {
return 0;
if (kernel_info) {
auto runtime_cache = kernel_info->runtime_cache();
if (runtime_cache->is_valid()) {
runtime_cache->set_output_tensor_num(static_cast<ssize_t>(res));
}
}
return 1;
return res;
}
std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {

View File

@@ -29,7 +29,7 @@
namespace mindspore {
enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };
typedef enum CacheBool { UNCACHED = -1, FALSE, TRUE } CacheBool;
// Location class records the location in source code.
class Location {
public: