forked from mindspore-Ecosystem/mindspore
!852 add warning info to statistics how much nodes using raise or reduce to selected kernel info
Merge pull request !852 from lianliguang/master
This commit is contained in:
commit
ea0fb2ccae
|
@ -342,7 +342,7 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
|
|||
std::vector<int> *node_mix_precision_datatype_index) {
|
||||
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
||||
bool add_node_datatype_flag = false;
|
||||
if (node_mix_precision_datatype->size() == 0) {
|
||||
if (node_mix_precision_datatype->empty()) {
|
||||
add_node_datatype_flag = true;
|
||||
}
|
||||
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
|
||||
|
@ -464,8 +464,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
int status = kStatusAllMatched;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
bool precision_reduce = false;
|
||||
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
|
||||
|
@ -486,11 +487,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
|
|||
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
|
||||
} else {
|
||||
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
||||
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
||||
// Set format and data type for input tensor.
|
||||
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
||||
return status;
|
||||
}
|
||||
|
||||
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
void SelectKernelInfo(const CNodePtr &kernel_node);
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node);
|
||||
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info);
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -325,10 +325,25 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
|
|||
// compile graph steps
|
||||
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
size_t raise_precision_count = 0;
|
||||
size_t reduce_precision_count = 0;
|
||||
for (const auto &cnode : kernel_graph.execution_order()) {
|
||||
device::ascend::SelectKernelInfo(cnode);
|
||||
auto status = device::ascend::SelectKernelInfo(cnode);
|
||||
if (status == kStatusRaisePrecision) {
|
||||
raise_precision_count++;
|
||||
} else if (status == kStatusReducePrecision) {
|
||||
reduce_precision_count++;
|
||||
}
|
||||
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
|
||||
}
|
||||
if (raise_precision_count > 0) {
|
||||
MS_LOG(WARNING) << "There has " << raise_precision_count
|
||||
<< " node/nodes used raise precision to selected the kernel!";
|
||||
}
|
||||
if (reduce_precision_count > 0) {
|
||||
MS_LOG(WARNING) << "There has " << reduce_precision_count
|
||||
<< " node/nodes used reduce precision to selected the kernel!";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
|
|
|
@ -186,7 +186,10 @@ constexpr auto kControlDependBehindIndex = 2;
|
|||
// index define of depend
|
||||
constexpr auto kRealInputIndexInDepend = 1;
|
||||
constexpr auto kDependAttachNodeIndex = 2;
|
||||
|
||||
// status of kernel select result
|
||||
const int kStatusReducePrecision = -1;
|
||||
const int kStatusRaisePrecision = 1;
|
||||
const int kStatusAllMatched = 0;
|
||||
// format
|
||||
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
|
||||
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
|
||||
|
|
Loading…
Reference in New Issue