!6458 GPU codex fix

Merge pull request !6458 from VectorSL/codex-fix
This commit is contained in:
mindspore-ci-bot 2020-09-18 15:20:22 +08:00 committed by Gitee
commit 076d8ae530
3 changed files with 2 additions and 6 deletions

View File

@ -86,7 +86,7 @@ class MatMulGpuKernel : public GpuKernel {
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible"; MS_LOG(INFO) << "input and output type is float16, allow to use Tensor Core operations if possible";
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
} }
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);

View File

@ -86,6 +86,7 @@ class PoolingGradGpuKernel : public GpuKernel {
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
if (is_null_input_) { if (is_null_input_) {
MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
@ -204,7 +205,6 @@ class PoolingGradGpuKernel : public GpuKernel {
"cudnnSetPoolingNdDescriptor failed"); "cudnnSetPoolingNdDescriptor failed");
} }
void SetPoolingMode(const CNodePtr &kernel_node) { void SetPoolingMode(const CNodePtr &kernel_node) {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
mode_ = AnfAlgo::GetCNodeName(kernel_node); mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPoolGradGpu") { if (mode_ == "AvgPoolGradGpu") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;

View File

@ -345,7 +345,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
SetGraphKernelInfo(kernel_node, func_graph); SetGraphKernelInfo(kernel_node, func_graph);
return; return;
} }
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type; std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
@ -368,12 +367,10 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
builder->SetInputsDeviceType(inputs_type); builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format); builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type); builder->SetOutputsDeviceType(outputs_type);
bool result = false; bool result = false;
if (kernel_type == UNKNOWN_KERNEL_TYPE) { if (kernel_type == UNKNOWN_KERNEL_TYPE) {
result = result =
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
if (!result) { if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AKG_KERNEL; kernel_type = AKG_KERNEL;
@ -381,7 +378,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
} else if (kernel_type == AKG_KERNEL) { } else if (kernel_type == AKG_KERNEL) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
} }
if (!result) { if (!result) {
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type); PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
return; return;