!96 fix refkey bug for auto parallel

Merge pull request !96 from lichen/fix_ref_key_bug_for_auto_parallel
mindspore-ci-bot 2020-04-02 17:26:58 +08:00 committed by Gitee
commit d8b460c780
2 changed files with 50 additions and 7 deletions

@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap: when input i of a CNode B is a RefKey referring to Parameter C,
+// the map holds one entry with key C and value (B, i).
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
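
In plain terms, the new g_RefMap caches, for each Parameter reached through a RefKey, which CNode consumes it and at which input slot, so a later pass can recover that pair with one map lookup instead of re-walking the graph. A minimal standalone sketch of the bookkeeping, where Node and NodePtr are hypothetical stand-ins for the real AnfNodePtr machinery:

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Stand-in for an ANF node; the real map is keyed by AnfNodePtr.
struct Node {
  std::string name;
};
using NodePtr = std::shared_ptr<Node>;

// key: the Parameter behind a RefKey; value: (consuming node, input slot).
static std::map<NodePtr, std::pair<NodePtr, int>> g_ref_map;

int main() {
  auto param = std::make_shared<Node>(Node{"assignsub_weight"});
  auto user = std::make_shared<Node>(Node{"AssignSub"});
  g_ref_map[param] = std::make_pair(user, 1);  // param feeds input slot 1

  auto iter = g_ref_map.find(param);
  if (iter != g_ref_map.end()) {
    std::cout << iter->first->name << " -> (" << iter->second.first->name
              << ", " << iter->second.second << ")\n";
  }
  return 0;
}

Keying by the parameter is what lets CoverSliceShape, further down, treat a RefKey-fed parameter exactly like one found by graph search.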
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
-  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
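
The hunk above replaces the range-for over node_inputs with an index loop because the fix needs the slot number i itself: it is stored in g_RefMap next to the node. A compilable sketch of the same pattern under simplified, assumed types (Input and IsRefKey are placeholders, not the MindSpore API):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Input {
  std::string kind;  // "Primitive", "Tensor", "RefKey", ...
};

bool IsRefKey(const Input& input) { return input.kind == "RefKey"; }

int main() {
  // Slot 0 of a CNode holds the primitive, so real inputs start at 1.
  std::vector<Input> all_inputs = {{"Primitive"}, {"Tensor"}, {"RefKey"}};
  for (size_t i = 1; i < all_inputs.size(); ++i) {
    const Input& input = all_inputs[i];
    if (IsRefKey(input)) {
      // A range-for would yield the input but lose the slot index i.
      std::cout << "RefKey at input slot " << i << "\n";
    }
  }
  return 0;
}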
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }
 
 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
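
CoverSliceShape now tries the cached map first and falls back to the FindSubGraph traversal only when the parameter has no recorded RefKey use; clearing the map at the end keeps state from leaking into the next compilation. A sketch of that lookup-then-fallback flow, with hypothetical helper names standing in for SetParallelShape and FindSubGraph:

#include <map>
#include <string>
#include <utility>
#include <vector>

using Slot = std::pair<std::string, int>;  // (node name, input slot)
static std::map<std::string, Slot> g_ref_map;

Slot FindSubGraphSlow(const std::string& /*param*/) { return {"", -1}; }  // traversal stand-in
void SetShape(const std::string& /*param*/, const Slot& /*slot*/) {}

void CoverShapes(const std::vector<std::string>& parameters) {
  for (const auto& param : parameters) {
    auto iter = g_ref_map.find(param);
    if (iter != g_ref_map.end()) {
      SetShape(param, iter->second);  // cached (node, slot) from ExtractShape
      continue;
    }
    Slot res = FindSubGraphSlow(param);  // fallback: search the graph
    if (res.second >= 0) {
      SetShape(param, res);
    }
  }
  g_ref_map.clear();  // reset so the next compilation starts fresh
}

int main() {
  g_ref_map["assignsub_weight"] = {"AssignSub", 1};
  CoverShapes({"assignsub_weight", "mul_weight"});
  return 0;
}

One small aside: after a successful find, the committed line indexes g_RefMap[parameter] again, which repeats the lookup; iter->second yields the same pair directly.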

@@ -13,14 +13,13 @@
 # limitations under the License.
 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
-from mindspore.common.api import _executor
 from mindspore.ops import composite as C
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
 
 
 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32],
+                                                       0.5, dtype=np.float32)),
+                                        name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
+                                                             1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)