forked from mindspore-Ecosystem/mindspore
fix bug of hccl kernel info
This commit is contained in:
parent
82b4cadad2
commit
ea9b5468bb
|
@ -16,12 +16,30 @@
|
|||
|
||||
#include "kernel/hccl/hccl_kernel_metadata.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "utils/utils.h"
|
||||
#include "kernel/hccl/hcom_util.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
||||
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
|
||||
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
||||
if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
|
||||
return format;
|
||||
}
|
||||
if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
return format;
|
||||
}
|
||||
} // namespace
|
||||
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
||||
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
|
||||
kNumberTypeFloat32, kNumberTypeInt16};
|
||||
|
@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
|
||||
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
|
||||
inputs_type.push_back(type);
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
|
||||
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
|
||||
outputs_type.push_back(type);
|
||||
}
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
|
|
|
@ -428,5 +428,5 @@ def test_pynative_resnet50():
|
|||
cost_time = end_time - start_time
|
||||
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
|
||||
if step > 1:
|
||||
assert cost_time < 0.5
|
||||
assert cost_time < 0.3
|
||||
|
Loading…
Reference in New Issue