forked from mindspore-Ecosystem/mindspore
commit
076d8ae530
|
@ -86,7 +86,7 @@ class MatMulGpuKernel : public GpuKernel {
|
||||||
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
|
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
|
||||||
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
|
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
|
||||||
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
|
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
|
||||||
MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible";
|
MS_LOG(INFO) << "input and output type is float16, allow to use Tensor Core operations if possible";
|
||||||
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||||
}
|
}
|
||||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||||
|
|
|
@ -86,6 +86,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
||||||
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
||||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||||
|
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||||
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
|
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
|
||||||
if (is_null_input_) {
|
if (is_null_input_) {
|
||||||
MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
|
MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
|
||||||
|
@ -204,7 +205,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
||||||
"cudnnSetPoolingNdDescriptor failed");
|
"cudnnSetPoolingNdDescriptor failed");
|
||||||
}
|
}
|
||||||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
|
||||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
if (mode_ == "AvgPoolGradGpu") {
|
if (mode_ == "AvgPoolGradGpu") {
|
||||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||||
|
|
|
@ -345,7 +345,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
SetGraphKernelInfo(kernel_node, func_graph);
|
SetGraphKernelInfo(kernel_node, func_graph);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<TypeId> inputs_type;
|
std::vector<TypeId> inputs_type;
|
||||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||||
|
@ -368,12 +367,10 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
builder->SetInputsDeviceType(inputs_type);
|
builder->SetInputsDeviceType(inputs_type);
|
||||||
builder->SetOutputsFormat(outputs_format);
|
builder->SetOutputsFormat(outputs_format);
|
||||||
builder->SetOutputsDeviceType(outputs_type);
|
builder->SetOutputsDeviceType(outputs_type);
|
||||||
|
|
||||||
bool result = false;
|
bool result = false;
|
||||||
if (kernel_type == UNKNOWN_KERNEL_TYPE) {
|
if (kernel_type == UNKNOWN_KERNEL_TYPE) {
|
||||||
result =
|
result =
|
||||||
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
|
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
result = SelectAkgKernel(kernel_node, builder->Build());
|
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||||
kernel_type = AKG_KERNEL;
|
kernel_type = AKG_KERNEL;
|
||||||
|
@ -381,7 +378,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
} else if (kernel_type == AKG_KERNEL) {
|
} else if (kernel_type == AKG_KERNEL) {
|
||||||
result = SelectAkgKernel(kernel_node, builder->Build());
|
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
|
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue