forked from mindspore-Ecosystem/mindspore
fix refkey bug for auto parallel
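As the hunks below show: ExtractShape now records, for every RefKey input, which Parameter it resolves to, keyed in g_RefMap as parameter -> (consuming CNode, input index); CoverSliceShape consults this map before falling back to FindSubGraph, so a parameter reached only through a RefKey still gets its parallel slice shape set, and the map is cleared at the end of the pass. A semi_auto_parallel test built around P.AssignSub covers the path.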
This commit is contained in:
parent
a44b5293de
commit
5240b1f603
@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap: for a CNode B whose input i is a RefKey[Parameter C],
+// the map holds one entry with key C and value (B, i).
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
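A minimal standalone sketch of the contract stated in the comment above, using toy std::string stand-ins for AnfNodePtr (not MindSpore API):

#include <map>
#include <string>
#include <utility>

int main() {
  // For a CNode B whose input 2 is RefKey[Parameter C], the pass stores
  // one entry keyed by C with value (B, 2).
  std::map<std::string, std::pair<std::string, int>> ref_map;
  ref_map["C"] = std::make_pair("B", 2);
  // A later pass that meets parameter C can recover its consumer and the
  // exact input slot without walking the graph again.
  return ref_map.at("C").second == 2 ? 0 : 1;
}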
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
 
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
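Why the range-for over node_inputs became an indexed loop, shown as a toy sketch (names hypothetical, not MindSpore API): building the (node, index) pair for g_RefMap needs the input position i, which the range-for did not expose.

#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Toy stand-in for node->inputs(): slot 0 is the primitive, real
  // inputs start at 1, mirroring `all_inputs` in ExtractShape.
  std::vector<std::string> all_inputs = {"prim::AssignSub", "RefKey[C]", "out"};
  std::map<std::string, std::pair<std::string, int>> ref_map;
  size_t inputs_size = all_inputs.size();
  for (size_t i = 1; i < inputs_size; ++i) {
    // The index i is what `for (auto& input : node_inputs)` could not
    // provide; it becomes the second half of the map value.
    if (all_inputs[i].rfind("RefKey[", 0) == 0) {
      ref_map["C"] = std::make_pair("B", static_cast<int>(i));
    }
  }
  return ref_map.count("C") == 1 ? 0 : 1;
}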
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }
 
 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
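A sketch of the lookup order CoverSliceShape now uses (toy types; FindSubGraph replaced by a placeholder): a ref-map hit wins, graph search is the fallback, and the global map is cleared so no stale entries survive into the next compilation.

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::map<std::string, std::pair<std::string, int>> ref_map = {{"C", {"B", 2}}};
  std::vector<std::string> parameters = {"C", "D"};
  for (const auto& parameter : parameters) {
    auto iter = ref_map.find(parameter);
    if (iter != ref_map.end()) {
      // Hit: the parameter is reached through a RefKey; use the
      // recorded (consumer, input index) pair directly.
      std::cout << parameter << " -> (" << iter->second.first << ", "
                << iter->second.second << ")\n";
      continue;
    }
    // Miss: fall back to searching the graph, as FindSubGraph does.
    std::cout << parameter << " -> graph search\n";
  }
  ref_map.clear();  // reset the global map between compilations
  return 0;
}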
@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
+from mindspore.common.api import _executor
 from mindspore.ops import composite as C
-from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
 
 
 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32],
+                                                       0.5, dtype=np.float32)),
+                                        name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
+                                                             1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)
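Design note on the test: P.AssignSub consumes assignsub_weight by reference, so the parameter presumably reaches ExtractShape as a RefKey input rather than a plain Parameter; compiling this net under semi_auto_parallel with device_num=64 therefore exercises exactly the g_RefMap path added above, with CoverSliceShape resolving the parameter through the recorded (node, index) pair.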