forked from mindspore-Ecosystem/mindspore
!8578 dynamic shape check
From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
24d04b1cb1
|
@ -251,7 +251,6 @@ class GpuKernel : public KernelMod {
|
|||
|
||||
device::DynamicKernelPtr dynamic_kernel_;
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -110,6 +110,37 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
|
|||
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
|
||||
return;
|
||||
}
|
||||
|
||||
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(relu_grad);
|
||||
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
|
||||
if (relu_users->size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// process pattern as Relu(TensorAdd(BN#0, BN#1))
|
||||
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
|
||||
return false;
|
||||
}
|
||||
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
|
||||
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
|
||||
|
@ -123,31 +154,13 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
|||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
|
||||
if (!PatternCheck(graph, node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(relu_grad);
|
||||
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
|
||||
if (relu_users->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// process pattern as Relu(TensorAdd(BN#0, BN#1))
|
||||
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
|
||||
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
|
||||
MS_EXCEPTION_IF_NULL(dy);
|
||||
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
|
||||
|
|
|
@ -1432,7 +1432,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
|
|||
}
|
||||
|
||||
bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
|
||||
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
|
||||
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
|
||||
}
|
||||
|
||||
bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
|
||||
|
|
Loading…
Reference in New Issue