forked from mindspore-Ecosystem/mindspore
change feature map flag
This commit is contained in:
parent
1321483749
commit
687cd623f9
|
@ -187,7 +187,7 @@ class CNodeDecoder {
|
|||
if ((node.first)->isa<Parameter>()) {
|
||||
auto parameter = (node.first)->cast<ParameterPtr>();
|
||||
bool is_weight = AnfAlgo::IsParameterWeight(parameter);
|
||||
kernel_info->SetFeatureMapFlag(!is_weight);
|
||||
kernel_info->set_feature_map_flag(!is_weight);
|
||||
if (!is_weight) {
|
||||
feature_map_input_indexs.push_back(index - 1);
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ class CNodeDecoder {
|
|||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_);
|
||||
}
|
||||
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
}
|
||||
if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);
|
||||
|
|
|
@ -31,6 +31,8 @@ namespace session {
|
|||
namespace {
|
||||
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
|
||||
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
||||
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
|
||||
prim::kPrimAssignSub->name()};
|
||||
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const {
|
||||
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
|
||||
"flag but got the node :"
|
||||
<< cnode->DebugString();
|
||||
}
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode, 0);
|
||||
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
|
||||
if (AnfAlgo::IsFeatureMapOutput(input_node)) {
|
||||
return;
|
||||
}
|
||||
if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
|
||||
auto kernel_info = static_cast<device::KernelInfo *>(input_node->kernel_info());
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
node->set_kernel_info(kernel_info);
|
||||
if (node->isa<CNode>()) {
|
||||
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
|
||||
ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>());
|
||||
}
|
||||
std::vector<size_t> feature_map_input_indexs;
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
kernel_info->set_feature_map_flag(false);
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
|
||||
if (AnfAlgo::IsFeatureMapInput(node, index)) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
feature_map_input_indexs.push_back(index);
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetInputTensorNum(node) == 0) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
}
|
||||
if (AnfAlgo::IsRealKernel(node)) {
|
||||
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||
|
@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
std::vector<TypeId> types;
|
||||
std::vector<std::string> formats = {kOpFormat_DEFAULT};
|
||||
if (node->isa<ValueNode>()) {
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
kernel_info->set_feature_map_flag(false);
|
||||
types.emplace_back(kTypeUnknown);
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
SyncDeviceInfoToValueNode(value_node, &formats, &types);
|
||||
|
@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
auto parameter = node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
bool is_weight = AnfAlgo ::IsParameterWeight(parameter);
|
||||
kernel_info->SetFeatureMapFlag(!is_weight);
|
||||
kernel_info->set_feature_map_flag(!is_weight);
|
||||
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
|
||||
}
|
||||
// set parameter initaial device data type
|
||||
|
|
|
@ -100,6 +100,7 @@ class KernelGraph : public FuncGraph {
|
|||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
|
||||
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
|
||||
CNodePtr NewCNode(const CNodePtr &cnode);
|
||||
void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;
|
||||
ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr);
|
||||
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
|
||||
ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);
|
||||
|
|
|
@ -837,7 +837,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
|
|||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create new kernel_info of new value_node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
|
|
|
@ -48,7 +48,7 @@ class KernelInfo : public KernelInfoDevice {
|
|||
void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
select_kernel_build_info_ = select_kernel_build_info;
|
||||
}
|
||||
void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
|
||||
void set_feature_map_flag(bool flag) { is_feature_map_ = flag; }
|
||||
const DeviceAddress *GetOutputAddr(size_t index) const;
|
||||
DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
|
||||
bool OutputAddrExist(size_t index) const;
|
||||
|
|
Loading…
Reference in New Issue