!6532 GPU fix BnCast

Merge pull request !6532 from VectorSL/bncast
This commit is contained in:
mindspore-ci-bot 2020-09-19 15:58:10 +08:00 committed by Gitee
commit 149285b6f2
4 changed files with 9 additions and 5 deletions

View File

@@ -144,7 +144,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
void SetPoolingMode(const CNodePtr &kernel_node) { void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node); mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPool") { if (mode_ == "AvgPool") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0; pad_value_ = 0.0;
} else { } else {
pooling_mode_ = CUDNN_POOLING_MAX; pooling_mode_ = CUDNN_POOLING_MAX;

View File

@@ -207,7 +207,7 @@ class PoolingGradGpuKernel : public GpuKernel {
void SetPoolingMode(const CNodePtr &kernel_node) { void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node); mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPoolGradGpu") { if (mode_ == "AvgPoolGradGpu") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0; pad_value_ = 0.0;
} else { } else {
pooling_mode_ = CUDNN_POOLING_MAX; pooling_mode_ = CUDNN_POOLING_MAX;

View File

@@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
const EquivPtr &equiv) const { const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0); auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0); auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
MS_EXCEPTION_IF_NULL(fbn2); MS_EXCEPTION_IF_NULL(fbn2);
MS_EXCEPTION_IF_NULL(x_after); MS_EXCEPTION_IF_NULL(x_after);
MS_EXCEPTION_IF_NULL(x_before); MS_EXCEPTION_IF_NULL(x_before);
// only deal with x_after with fp32: x 16->32->bn->16->32
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
return nullptr;
}
std::vector<TypeId> outputs_type; std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape; std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager(); auto manager = graph->manager();

View File

@@ -68,8 +68,9 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0); auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1); auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
MS_EXCEPTION_IF_NULL(x_); MS_EXCEPTION_IF_NULL(x_);
// if x_type is fp32, the cast is necessary. // if x_type is fp32, the cast is necessary or dy_after is fp32: dy 16->32->bng->16->32.
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) { if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
return nullptr; return nullptr;
} }
MS_EXCEPTION_IF_NULL(fbn2g); MS_EXCEPTION_IF_NULL(fbn2g);