forked from mindspore-Ecosystem/mindspore
!1911 add a function to charge the node input and output is a scalar
Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalar
This commit is contained in:
commit
beb714d2d0
|
@ -37,11 +37,11 @@ class SupportedChecker {
|
|||
public:
|
||||
SupportedChecker() = default;
|
||||
virtual ~SupportedChecker() = default;
|
||||
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
||||
virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
|
||||
}
|
||||
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
|
||||
virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
|
||||
}
|
||||
|
|
|
@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
|
|||
return nullptr;
|
||||
}
|
||||
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
|
||||
return node;
|
||||
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
|
||||
if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) {
|
||||
return nullptr;
|
||||
} else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) {
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
|
||||
builder->SetKernelType(AICPU_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
|
@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
|
|||
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
|
||||
<< node->DebugString() << "]";
|
||||
}
|
||||
return node;
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
|||
auto indices_const = CreateValueNode(new_cnode);
|
||||
new_cnode->add_input(indices_const);
|
||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||
MS_LOG(INFO) << "split topk failed, check to aicpu.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
|
|||
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
|
||||
|
||||
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
|
||||
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
|
||||
if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
|
||||
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||
auto new_node = func_graph->NewCNode(inputs);
|
||||
|
|
|
@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
|
|||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||
if (shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
return shape.size() == kShape1dDims && shape[0] == 1;
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||
if (shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
return shape.size() == kShape1dDims && shape[0] == 1;
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm {
|
|||
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
||||
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
|
||||
static bool IsSwitchCall(const CNodePtr &call_node);
|
||||
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
|
||||
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other";
|
|||
|
||||
// some size
|
||||
const size_t kShape4dDims = 4;
|
||||
const size_t kShape2dDims = 2;
|
||||
const size_t kShape5dDims = 5;
|
||||
const size_t kShape1dDims = 1;
|
||||
const size_t kCubeSize = 16;
|
||||
const size_t kMemAlignSize = 512;
|
||||
const int kParameterDataTensorMask = 0;
|
||||
|
|
|
@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker {
|
|||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
}; // namespace opt
|
||||
|
|
|
@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker {
|
|||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue