forked from mindspore-Ecosystem/mindspore
commit
149285b6f2
|
@ -144,7 +144,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (mode_ == "AvgPool") {
|
||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||
pad_value_ = 0.0;
|
||||
} else {
|
||||
pooling_mode_ = CUDNN_POOLING_MAX;
|
||||
|
|
|
@ -207,7 +207,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (mode_ == "AvgPoolGradGpu") {
|
||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||
pad_value_ = 0.0;
|
||||
} else {
|
||||
pooling_mode_ = CUDNN_POOLING_MAX;
|
||||
|
|
|
@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
|
||||
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
|
||||
MS_EXCEPTION_IF_NULL(fbn2);
|
||||
MS_EXCEPTION_IF_NULL(x_after);
|
||||
MS_EXCEPTION_IF_NULL(x_before);
|
||||
// only deal with x_after with fp32: x 16->32->bn->16->32
|
||||
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto manager = graph->manager();
|
||||
|
|
|
@ -68,8 +68,9 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
|
|||
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
|
||||
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
|
||||
MS_EXCEPTION_IF_NULL(x_);
|
||||
// if x_type is fp32, the cast is necessary.
|
||||
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
|
||||
// if x_type is fp32, the cast is necessary or dy_afer is fp32: dy 16->32->bng->16->32.
|
||||
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
|
||||
AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(fbn2g);
|
||||
|
|
Loading…
Reference in New Issue