!58 fix two cast bug in auto parallel

Merge pull request !58 from lichen/fix_two_cast_bug_in_auto_parallel
This commit is contained in:
mindspore-ci-bot 2020-04-01 09:54:10 +08:00 committed by Gitee
commit 87040483ee
3 changed files with 51 additions and 11 deletions

View File

@ -346,6 +346,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
} }
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) { OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(cnode);
auto attrs = prim->attrs(); auto attrs = prim->attrs();
std::vector<Shapes> shape_list = ExtractShape(cnode); std::vector<Shapes> shape_list = ExtractShape(cnode);
if (shape_list.empty()) { if (shape_list.empty()) {
@ -381,8 +383,8 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode); operator_info->set_cnode(cnode);
// If no strategy has been configured for this operator, then candidate strategies are generated for // If no strategy has been configured for this operator, then candidate strategies are generated for
// auto-strategy searching // auto-strategy searchingm if this primitive is Cast, we ignore the user-specified strategy
if (!StrategyFound(attrs)) { if (!StrategyFound(attrs) || prim->name() == CAST) {
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
// BatchParallelInfo operator // BatchParallelInfo operator
operator_info->ComputeBatchSplitFlagList(); operator_info->ComputeBatchSplitFlagList();

View File

@ -371,7 +371,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
if (prim == nullptr) { if (prim == nullptr) {
return false; return false;
} }
auto attrs = prim->attrs();
if (IsInBlackList(prim)) { if (IsInBlackList(prim)) {
MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
return false; return false;
@ -380,10 +379,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
if (prim->name() == GET_NEXT) { if (prim->name() == GET_NEXT) {
return true; return true;
} }
if ((prim->name() == CAST)) { if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) { return false;
return false;
}
} }
return cnode->in_forward_flag(); return cnode->in_forward_flag();
@ -654,6 +651,14 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
LossNodeInfo node_info; LossNodeInfo node_info;
// return -> cast
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
pre_node = pre_cnode->input(1);
}
// return -> loss // return -> loss
if (pre_node == loss_node) { if (pre_node == loss_node) {
node_info.has_tuple_getitem = false; node_info.has_tuple_getitem = false;
@ -1948,6 +1953,14 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
MS_EXCEPTION_IF_NULL(current_value); MS_EXCEPTION_IF_NULL(current_value);
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>(); PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(current_prim); MS_EXCEPTION_IF_NULL(current_prim);
// return -> cast
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
}
// notice: the GetNext op has not input // notice: the GetNext op has not input
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
MS_LOG(INFO) << "The loss is: " << current_prim->name(); MS_LOG(INFO) << "The loss is: " << current_prim->name();

View File

@ -192,7 +192,6 @@ def test_cast_before_mirror():
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32]), dtype=ms.float32) x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32) y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float16) b = Tensor(np.ones([64, 64]), dtype=ms.float16)
@ -217,7 +216,6 @@ def test_cast_before_mirror1():
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32]), dtype=ms.float16) x = Tensor(np.ones([128, 32]), dtype=ms.float16)
y = Tensor(np.ones([32, 64]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16)
b = Tensor(np.ones([64, 64]), dtype=ms.float32) b = Tensor(np.ones([64, 64]), dtype=ms.float32)
@ -242,7 +240,6 @@ def test_cast_before_mirror2():
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32]), dtype=ms.float16) x = Tensor(np.ones([128, 32]), dtype=ms.float16)
y = Tensor(np.ones([32, 64]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16)
b = Tensor(np.ones([64, 64]), dtype=ms.float32) b = Tensor(np.ones([64, 64]), dtype=ms.float32)
@ -267,8 +264,36 @@ def test_cast_before_mirror3():
net = GradWrap(NetWithLoss(Net(strategy1))) net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32]), dtype=ms.float16) x = Tensor(np.ones([128, 32]), dtype=ms.float16)
y = Tensor(np.ones([32, 64]), dtype=ms.float16) y = Tensor(np.ones([32, 64]), dtype=ms.float16)
b = Tensor(np.ones([64, 64]), dtype=ms.float32) b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b) _executor.compile(net, x, y, b)
def test_mul_two_cast():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, strategy3):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.mul2 = P.Mul().set_strategy(strategy2)
self.cast = P.Cast().set_strategy(strategy3)
self.cast2 = P.Cast().set_strategy(strategy3)
def construct(self, x, y, b):
out = self.mul(x, y)
out = self.mul2(out, b)
out = self.cast(out, ms.int32)
out = self.cast2(out, ms.bool_)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((8, 1), (8, 1))
strategy3 = ((8, 1), )
net = GradWrap(Net(strategy1, strategy2, strategy3))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
b = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y, b)