forked from mindspore-Ecosystem/mindspore
add output filter for BatchNorm operator
Add some comments addressed John's comments CI check CI check part2
This commit is contained in:
parent
094e701e50
commit
ea8c8361d6
|
@ -76,6 +76,24 @@ bool GPUKernelRuntime::Init() {
|
|||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<int> CheckRealOutput(const std::string &node_name, const size_t &output_size) {
|
||||
// define a vector containing real output number
|
||||
std::vector<int> real_outputs;
|
||||
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
|
||||
// can add the filter list for more operators here....
|
||||
if (node_name == "FusedBatchNorm") {
|
||||
MS_LOG(INFO) << "loading node named FusedBatchNorm.";
|
||||
real_outputs.insert(real_outputs.end(), {0, 3, 4});
|
||||
} else {
|
||||
// by default, TensorLoader will load all outputs
|
||||
for (size_t j = 0; j < output_size; ++j) {
|
||||
real_outputs.push_back(j);
|
||||
}
|
||||
}
|
||||
return real_outputs;
|
||||
}
|
||||
|
||||
void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
|
||||
const std::vector<mindspore::kernel::AddressPtr> &kernel_inputs,
|
||||
const std::vector<mindspore::kernel::AddressPtr> &kernel_workspaces,
|
||||
|
@ -120,7 +138,13 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
|
|||
|
||||
// get outputs
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
|
||||
for (size_t j = 0; j < output_size; ++j) {
|
||||
auto node_name = AnfAlgo::GetCNodeName(kernel);
|
||||
|
||||
std::vector<int> real_outputs;
|
||||
real_outputs = CheckRealOutput(node_name, output_size);
|
||||
|
||||
for (std::vector<int>::iterator it = real_outputs.begin(); it != real_outputs.end(); ++it) {
|
||||
auto j = *it;
|
||||
auto addr = kernel_outputs[j];
|
||||
auto type = AnfAlgo::GetOutputInferDataType(kernel, j);
|
||||
auto format = kOpFormat_DEFAULT;
|
||||
|
|
Loading…
Reference in New Issue