fix refkey bug for auto parallel

lichenever 2020-04-02 11:14:45 +08:00
parent a44b5293de
commit 5240b1f603
2 changed files with 50 additions and 7 deletions


@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap: if input i of a CNode B is a RefKey that refers to Parameter C,
+// the map stores one entry with key C and value (B, i).
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
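In plain terms, g_RefMap is simple bookkeeping. A minimal standalone sketch of the idea, using std::string as a stand-in for the real AnfNodePtr handles (toy types, not MindSpore's API):

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  // Stand-in for g_RefMap: parameter -> (user CNode, input index).
  std::map<std::string, std::pair<std::string, int>> ref_map;

  // CNode "B" reads Parameter "C" through a RefKey at input index 2,
  // so the map records key C with value (B, 2).
  ref_map["C"] = std::make_pair("B", 2);

  auto iter = ref_map.find("C");
  if (iter != ref_map.end()) {
    std::cout << "Parameter C feeds node " << iter->second.first
              << " at input " << iter->second.second << "\n";
  }
  return 0;
}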
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
-  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
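The range-based loop is replaced by an indexed one because the input position i now matters: it is half of the (node, i) value stored in g_RefMap. A hedged sketch of that record phase, with the same toy string handles; the "ref:" prefix is a made-up convention standing in for the IsValueNode<RefKey> check:

#include <map>
#include <string>
#include <utility>
#include <vector>

using RefMap = std::map<std::string, std::pair<std::string, int>>;

// Record phase: walk the inputs by index and, for each RefKey-like input,
// remember which (node, index) slot the underlying parameter feeds.
void RecordRefInputs(const std::string& node, const std::vector<std::string>& inputs,
                     RefMap* ref_map) {
  // Start at 1: input 0 of a CNode is the primitive itself, not a data input.
  for (size_t i = 1; i < inputs.size(); ++i) {
    const std::string& input = inputs[i];
    if (input.rfind("ref:", 0) == 0) {  // toy stand-in for IsValueNode<RefKey>
      std::string parameter = input.substr(4);
      (*ref_map)[parameter] = std::make_pair(node, static_cast<int>(i));
    }
  }
}

int main() {
  RefMap ref_map;
  RecordRefInputs("AssignSub", {"prim::AssignSub", "ref:assignsub_weight", "out"}, &ref_map);
  // ref_map now holds: assignsub_weight -> (AssignSub, 1)
  return 0;
}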
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }
 
 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
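Taken together, the fix is a two-phase handshake through the global map: ExtractShape records which (node, input) slot each RefKey parameter feeds, and CoverSliceShape consumes those entries before falling back to FindSubGraph, clearing the map at the end so state cannot leak into the next compilation. A condensed sketch of the consume phase, using the same toy types as above rather than the real pass:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using RefMap = std::map<std::string, std::pair<std::string, int>>;

// Consume phase: prefer the recorded (node, index) pair when one exists;
// otherwise fall back to a search, then drop the stale state.
void CoverShapes(const std::vector<std::string>& parameters, RefMap* ref_map) {
  for (const auto& parameter : parameters) {
    auto iter = ref_map->find(parameter);
    if (iter != ref_map->end()) {
      std::cout << parameter << ": slice shape from recorded pair ("
                << iter->second.first << ", " << iter->second.second << ")\n";
      continue;  // mirrors the early `continue` in CoverSliceShape
    }
    std::cout << parameter << ": no ref entry, search the graph instead\n";
  }
  ref_map->clear();  // the map is global, so reset it per compile
}

int main() {
  RefMap ref_map{{"assignsub_weight", {"AssignSub", 1}}};
  CoverShapes({"assignsub_weight", "mul_weight"}, &ref_map);
  return 0;
}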


@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
-from mindspore.common.api import _executor
 from mindspore.ops import composite as C
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
 
 
 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32],
+                                                       0.5, dtype=np.float32)),
+                                        name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
+                                                             1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)