change feature map flag

This commit is contained in:
LianLiguang 2020-11-26 16:53:40 +08:00
parent 1321483749
commit 687cd623f9
5 changed files with 31 additions and 9 deletions

View File

@ -187,7 +187,7 @@ class CNodeDecoder {
if ((node.first)->isa<Parameter>()) {
auto parameter = (node.first)->cast<ParameterPtr>();
bool is_weight = AnfAlgo::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight);
kernel_info->set_feature_map_flag(!is_weight);
if (!is_weight) {
feature_map_input_indexs.push_back(index - 1);
}
@ -200,7 +200,7 @@ class CNodeDecoder {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);

View File

@ -31,6 +31,8 @@ namespace session {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
prim::kPrimAssignSub->name()};
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node);
@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
}
}
void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const {
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
"flag but got the node :"
<< cnode->DebugString();
}
auto input_node = AnfAlgo::GetInputNode(cnode, 0);
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
if (AnfAlgo::IsFeatureMapOutput(input_node)) {
return;
}
if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
auto kernel_info = static_cast<device::KernelInfo *>(input_node->kernel_info());
kernel_info->set_feature_map_flag(true);
}
}
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info);
if (node->isa<CNode>()) {
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>());
}
std::vector<size_t> feature_map_input_indexs;
kernel_info->SetFeatureMapFlag(false);
kernel_info->set_feature_map_flag(false);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
if (AnfAlgo::IsFeatureMapInput(node, index)) {
kernel_info->SetFeatureMapFlag(true);
kernel_info->set_feature_map_flag(true);
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetInputTensorNum(node) == 0) {
kernel_info->SetFeatureMapFlag(true);
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealKernel(node)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
std::vector<TypeId> types;
std::vector<std::string> formats = {kOpFormat_DEFAULT};
if (node->isa<ValueNode>()) {
kernel_info->SetFeatureMapFlag(false);
kernel_info->set_feature_map_flag(false);
types.emplace_back(kTypeUnknown);
auto value_node = node->cast<ValueNodePtr>();
SyncDeviceInfoToValueNode(value_node, &formats, &types);
@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
auto parameter = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight);
kernel_info->set_feature_map_flag(!is_weight);
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
}
// set parameter initaial device data type

View File

@ -100,6 +100,7 @@ class KernelGraph : public FuncGraph {
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode);
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;
ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);

View File

@ -837,7 +837,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
new_value_node->set_abstract(value_node->abstract());
// create new kernel_info of new value_node
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->SetFeatureMapFlag(false);
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();

View File

@ -48,7 +48,7 @@ class KernelInfo : public KernelInfoDevice {
void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
select_kernel_build_info_ = select_kernel_build_info;
}
void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
void set_feature_map_flag(bool flag) { is_feature_map_ = flag; }
const DeviceAddress *GetOutputAddr(size_t index) const;
DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
bool OutputAddrExist(size_t index) const;