forked from mindspore-Ecosystem/mindspore
commit
149285b6f2
|
@ -144,7 +144,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
||||||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
if (mode_ == "AvgPool") {
|
if (mode_ == "AvgPool") {
|
||||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||||
pad_value_ = 0.0;
|
pad_value_ = 0.0;
|
||||||
} else {
|
} else {
|
||||||
pooling_mode_ = CUDNN_POOLING_MAX;
|
pooling_mode_ = CUDNN_POOLING_MAX;
|
||||||
|
|
|
@ -207,7 +207,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
||||||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
if (mode_ == "AvgPoolGradGpu") {
|
if (mode_ == "AvgPoolGradGpu") {
|
||||||
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
|
||||||
pad_value_ = 0.0;
|
pad_value_ = 0.0;
|
||||||
} else {
|
} else {
|
||||||
pooling_mode_ = CUDNN_POOLING_MAX;
|
pooling_mode_ = CUDNN_POOLING_MAX;
|
||||||
|
|
|
@ -37,13 +37,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
||||||
const EquivPtr &equiv) const {
|
const EquivPtr &equiv) const {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(equiv);
|
|
||||||
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||||
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
|
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
|
||||||
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
|
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
|
||||||
MS_EXCEPTION_IF_NULL(fbn2);
|
MS_EXCEPTION_IF_NULL(fbn2);
|
||||||
MS_EXCEPTION_IF_NULL(x_after);
|
MS_EXCEPTION_IF_NULL(x_after);
|
||||||
MS_EXCEPTION_IF_NULL(x_before);
|
MS_EXCEPTION_IF_NULL(x_before);
|
||||||
|
// only deal with x_after with fp32: x 16->32->bn->16->32
|
||||||
|
if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
std::vector<TypeId> outputs_type;
|
std::vector<TypeId> outputs_type;
|
||||||
std::vector<std::vector<size_t>> outputs_shape;
|
std::vector<std::vector<size_t>> outputs_shape;
|
||||||
auto manager = graph->manager();
|
auto manager = graph->manager();
|
||||||
|
|
|
@ -68,8 +68,9 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
|
||||||
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
|
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
|
||||||
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
|
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
|
||||||
MS_EXCEPTION_IF_NULL(x_);
|
MS_EXCEPTION_IF_NULL(x_);
|
||||||
// if x_type is fp32, the cast is necessary.
|
// if x_type is fp32, the cast is necessary; only deal with dy_after with fp32: dy 16->32->bng->16->32.
|
||||||
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) {
|
if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
|
||||||
|
AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(fbn2g);
|
MS_EXCEPTION_IF_NULL(fbn2g);
|
||||||
|
|
Loading…
Reference in New Issue