forked from mindspore-Ecosystem/mindspore
!1911 add a function to charge the node input and output is a scalar
Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalar
This commit is contained in:
commit
beb714d2d0
|
@ -37,11 +37,11 @@ class SupportedChecker {
|
||||||
public:
|
public:
|
||||||
SupportedChecker() = default;
|
SupportedChecker() = default;
|
||||||
virtual ~SupportedChecker() = default;
|
virtual ~SupportedChecker() = default;
|
||||||
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
|
||||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||||
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
|
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
|
||||||
}
|
}
|
||||||
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
|
virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
|
||||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||||
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
|
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||||
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
|
if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) {
|
||||||
return node;
|
return nullptr;
|
||||||
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
|
} else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) {
|
||||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
|
||||||
builder->SetKernelType(AICPU_KERNEL);
|
builder->SetKernelType(AICPU_KERNEL);
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||||
|
@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
|
||||||
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
|
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
|
||||||
<< node->DebugString() << "]";
|
<< node->DebugString() << "]";
|
||||||
}
|
}
|
||||||
return node;
|
return nullptr;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
||||||
auto indices_const = CreateValueNode(new_cnode);
|
auto indices_const = CreateValueNode(new_cnode);
|
||||||
new_cnode->add_input(indices_const);
|
new_cnode->add_input(indices_const);
|
||||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||||
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||||
MS_LOG(INFO) << "split topk failed, check to aicpu.";
|
MS_LOG(INFO) << "split topk failed, check to aicpu.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
|
||||||
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
|
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
|
||||||
|
|
||||||
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
|
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
|
||||||
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
|
if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
|
||||||
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
|
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
|
||||||
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||||
auto new_node = func_graph->NewCNode(inputs);
|
auto new_node = func_graph->NewCNode(inputs);
|
||||||
|
|
|
@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
|
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
|
||||||
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||||
|
if (shape.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return shape.size() == kShape1dDims && shape[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
|
||||||
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||||
|
if (shape.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return shape.size() == kShape1dDims && shape[0] == 1;
|
||||||
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm {
|
||||||
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
||||||
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
|
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
|
||||||
static bool IsSwitchCall(const CNodePtr &call_node);
|
static bool IsSwitchCall(const CNodePtr &call_node);
|
||||||
|
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
|
||||||
|
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
|
|
|
@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other";
|
||||||
|
|
||||||
// some size
|
// some size
|
||||||
const size_t kShape4dDims = 4;
|
const size_t kShape4dDims = 4;
|
||||||
|
const size_t kShape2dDims = 2;
|
||||||
const size_t kShape5dDims = 5;
|
const size_t kShape5dDims = 5;
|
||||||
|
const size_t kShape1dDims = 1;
|
||||||
const size_t kCubeSize = 16;
|
const size_t kCubeSize = 16;
|
||||||
const size_t kMemAlignSize = 512;
|
const size_t kMemAlignSize = 512;
|
||||||
const int kParameterDataTensorMask = 0;
|
const int kParameterDataTensorMask = 0;
|
||||||
|
|
|
@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker {
|
||||||
public:
|
public:
|
||||||
MockSupportedChecker() = default;
|
MockSupportedChecker() = default;
|
||||||
~MockSupportedChecker() override = default;
|
~MockSupportedChecker() override = default;
|
||||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}; // namespace opt
|
}; // namespace opt
|
||||||
|
|
|
@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker {
|
||||||
public:
|
public:
|
||||||
MockSupportedChecker() = default;
|
MockSupportedChecker() = default;
|
||||||
~MockSupportedChecker() override = default;
|
~MockSupportedChecker() override = default;
|
||||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue