!5722 fix semi auto parallel parameter of reshape has another user
Merge pull request !5722 from yao_yf/semi_auto_parallel_reshape_parameter_has_another_user
This commit is contained in:
commit
7786adc3aa
|
@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
|
||||
FuncGraphManagerPtr manager = node->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodeIndexSet node_set = manager->node_users()[node];
|
||||
for (auto &node_pair : node_set) {
|
||||
CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
|
||||
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
|
||||
continue;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_anf_node);
|
||||
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node_prim);
|
||||
if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
|
||||
auto layout = GetInputLayoutFromCNode(node_pair);
|
||||
return std::make_shared<TensorLayout>(layout);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
|
||||
// Create DataParallel tensor layout for parameter(support WideDeep).
|
||||
auto next_layout = FindParameterNextLayout(node);
|
||||
if (next_layout != nullptr) {
|
||||
return next_layout;
|
||||
}
|
||||
CheckGlobalDeviceManager();
|
||||
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
|
||||
TensorLayout input_tensor_layout;
|
||||
|
|
|
@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeI
|
|||
|
||||
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
|
||||
|
||||
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
|
||||
|
||||
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -292,3 +292,25 @@ def test_reshape_auto_6():
|
|||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
def test_reshape_auto_7():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
|
||||
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, x):
|
||||
weight = self.reshape(self.mul_weight, (1, 128, 96))
|
||||
out = self.mul(weight, self.mul_weight)
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([128, 28]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x)
|
||||
|
|
Loading…
Reference in New Issue