forked from mindspore-Ecosystem/mindspore
fix two cast bug in auto parallel
This commit is contained in:
parent
b27129c9da
commit
07449cd1cc
|
@ -350,6 +350,8 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
auto attrs = prim->attrs();
|
auto attrs = prim->attrs();
|
||||||
std::vector<Shapes> shape_list = ExtractShape(cnode);
|
std::vector<Shapes> shape_list = ExtractShape(cnode);
|
||||||
if (shape_list.empty()) {
|
if (shape_list.empty()) {
|
||||||
|
|
|
@ -374,7 +374,6 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto attrs = prim->attrs();
|
|
||||||
if (IsInBlackList(prim)) {
|
if (IsInBlackList(prim)) {
|
||||||
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
||||||
return false;
|
return false;
|
||||||
|
@ -1971,11 +1970,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(current_value);
|
MS_EXCEPTION_IF_NULL(current_value);
|
||||||
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(current_prim);
|
MS_EXCEPTION_IF_NULL(current_prim);
|
||||||
<<<<<<< HEAD
|
|
||||||
<<<<<<< HEAD
|
|
||||||
|
|
||||||
=======
|
|
||||||
>>>>>>> fix_cast_bug
|
|
||||||
// return -> cast
|
// return -> cast
|
||||||
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
||||||
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
||||||
|
@ -1983,8 +1978,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
|
||||||
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
=======
|
|
||||||
>>>>>>> 回退 'Pull Request !17 : [AutoParallel]Fix bug in the case of two cast'
|
|
||||||
// notice: the GetNext op has not input
|
// notice: the GetNext op has not input
|
||||||
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
||||||
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
||||||
|
|
|
@ -268,3 +268,32 @@ def test_cast_before_mirror3():
|
||||||
y = Tensor(np.ones([32, 64]), dtype=ms.float16)
|
y = Tensor(np.ones([32, 64]), dtype=ms.float16)
|
||||||
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
_executor.compile(net, x, y, b)
|
_executor.compile(net, x, y, b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mul_two_cast():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, strategy1, strategy2, strategy3):
|
||||||
|
super().__init__()
|
||||||
|
self.mul = P.Mul().set_strategy(strategy1)
|
||||||
|
self.mul2 = P.Mul().set_strategy(strategy2)
|
||||||
|
self.cast = P.Cast().set_strategy(strategy3)
|
||||||
|
self.cast2 = P.Cast().set_strategy(strategy3)
|
||||||
|
|
||||||
|
def construct(self, x, y, b):
|
||||||
|
out = self.mul(x, y)
|
||||||
|
out = self.mul2(out, b)
|
||||||
|
out = self.cast(out, ms.int32)
|
||||||
|
out = self.cast2(out, ms.bool_)
|
||||||
|
return out
|
||||||
|
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((2, 2), (2, 2))
|
||||||
|
strategy2 = ((8, 1), (8, 1))
|
||||||
|
strategy3 = ((8, 1), )
|
||||||
|
net = GradWrap(Net(strategy1, strategy2, strategy3))
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||||
|
|
||||||
|
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||||
|
b = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y, b)
|
||||||
|
|
Loading…
Reference in New Issue