!8578 dynamic shape check

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2020-11-16 09:38:27 +08:00 committed by Gitee
commit 24d04b1cb1
3 changed files with 34 additions and 22 deletions

View File

@ -251,7 +251,6 @@ class GpuKernel : public KernelMod {
device::DynamicKernelPtr dynamic_kernel_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -110,6 +110,37 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
return;
}
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return false;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
if (relu_users->size() != 2) {
return false;
}
// process pattern as Relu(TensorAdd(BN#0, BN#1))
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
return false;
}
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
return false;
}
return true;
}
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
@ -123,31 +154,13 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
if (!PatternCheck(graph, node)) {
return nullptr;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
if (relu_users->size() != 2) {
return nullptr;
}
// process pattern as Relu(TensorAdd(BN#0, BN#1))
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
return nullptr;
}
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
return nullptr;
}
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
MS_EXCEPTION_IF_NULL(dy);
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);

View File

@ -1432,7 +1432,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
}
bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
}
bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {