modify interface names for shard: rename in_axes/out_axes to in_strategy/out_strategy

wangjun 2022-03-24 21:12:41 +08:00
parent aba48e688b
commit 789539cbaa
6 changed files with 112 additions and 107 deletions
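For orientation, a minimal before/after sketch of the renamed keyword arguments; the toy cell below is illustrative and not part of this commit, and it assumes the PyNative auto-parallel context required by Shard (see the checks later in this diff).

import mindspore.nn as nn
import mindspore.ops as ops

class ToyMatMul(nn.Cell):  # hypothetical cell, for illustration only
    def __init__(self):
        super().__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        return self.matmul(x, y)

net = ToyMatMul()
# Before this commit:
#   net.shard(in_axes=((8, 1), (1, 1)), out_axes=(None,))
# After this commit (same semantics, new parameter names):
net.shard(in_strategy=((8, 1), (1, 1)), out_strategy=(None,))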

View File

@@ -366,26 +366,26 @@ bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const i
return true;
}
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, const int64_t &device_num) {
auto out_axes_tuple = out_axes->cast<ValueNodePtr>();
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy, const int64_t &device_num) {
auto out_strategy_tuple = out_strategy->cast<ValueNodePtr>();
bool need_default_strategy = false;
size_t out_axes_size = 0;
if (!IsValueNode<ValueTuple>(out_axes_tuple) ||
!CheckLayout(out_axes_tuple, &need_default_strategy, &out_axes_size)) {
MS_LOG(EXCEPTION) << "out_axes should be a two-dimension tuple";
size_t out_strategy_size = 0;
if (!IsValueNode<ValueTuple>(out_strategy_tuple) ||
!CheckLayout(out_strategy_tuple, &need_default_strategy, &out_strategy_size)) {
MS_LOG(EXCEPTION) << "out_strategy should be a two-dimension tuple";
}
std::vector<AnfNodePtr> output_nodes;
GetOutputNodes(func_graph, &output_nodes);
if (output_nodes.size() != out_axes_size) {
if (output_nodes.size() != out_strategy_size) {
MS_LOG(EXCEPTION) << "Output number: " << output_nodes.size()
<< " is not equal to out_axes number: " << out_axes_size;
<< " is not equal to out_strategy number: " << out_strategy_size;
}
std::vector<std::vector<int64_t>> output_strategy;
if (need_default_strategy) {
GenerateDefaultStrategy(out_axes_tuple, output_nodes, device_num, &output_strategy);
GenerateDefaultStrategy(out_strategy_tuple, output_nodes, device_num, &output_strategy);
} else {
output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_axes_tuple->value());
output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_strategy_tuple->value());
}
MS_LOG(WARNING) << "The output strategy will be overwritten as data-parallel";
@@ -394,7 +394,8 @@ void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes,
auto output_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
if (output_shape.size() != output_strategy[i].size()) {
MS_LOG(EXCEPTION) << "Output dimension: " << output_shape.size()
<< " is not equal to out_axes dimension: " << output_strategy[i].size() << " at index " << i;
<< " is not equal to out_strategy dimension: " << output_strategy[i].size() << " at index "
<< i;
}
std::vector<ValuePtr> elements;
elements.push_back(MakeValue(output_strategy[i]));
@@ -430,24 +431,25 @@ std::vector<ValuePtr> GetStrategyElements(const CNodePtr &cnode, const std::vect
return elements;
}
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, const int64_t &device_num) {
auto in_axes_tuple = in_axes->cast<ValueNodePtr>();
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy, const int64_t &device_num) {
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
bool need_default_strategy = false;
size_t in_axes_size = 0;
if (!IsValueNode<ValueTuple>(in_axes_tuple) || !CheckLayout(in_axes_tuple, &need_default_strategy, &in_axes_size)) {
MS_LOG(EXCEPTION) << "in_axes should be a two-dimension tuple";
size_t in_strategy_size = 0;
if (!IsValueNode<ValueTuple>(in_strategy_tuple) ||
!CheckLayout(in_strategy_tuple, &need_default_strategy, &in_strategy_size)) {
MS_LOG(EXCEPTION) << "in_strategy should be a two-dimension tuple";
}
std::vector<AnfNodePtr> input_nodes;
GetInputNodes(func_graph, &input_nodes);
if (input_nodes.size() != in_axes_size) {
if (input_nodes.size() != in_strategy_size) {
MS_LOG(EXCEPTION) << "Input numbers: " << input_nodes.size()
<< " is not equal to in_axes numbers: " << in_axes_size;
<< " is not equal to in_strategy numbers: " << in_strategy_size;
}
std::vector<std::vector<int64_t>> input_strategy;
if (need_default_strategy) {
GenerateDefaultStrategy(in_axes_tuple, input_nodes, device_num, &input_strategy);
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, device_num, &input_strategy);
} else {
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_axes_tuple->value());
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_strategy_tuple->value());
}
if (!CheckDeviceNum(input_strategy, device_num)) {
MS_LOG(EXCEPTION) << "check device number failed";
@@ -463,7 +465,7 @@ void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, c
auto output_shape = common::AnfAlgo::GetOutputInferShape(parameter, 0);
if (output_shape.size() != input_strategy[i].size()) {
MS_LOG(EXCEPTION) << "Input dimension: " << output_shape.size()
<< " is not equal to in_axes dimension: " << input_strategy[i].size() << " at index " << i;
<< " is not equal to in_strategy dimension: " << input_strategy[i].size() << " at index " << i;
}
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
for (auto &param_pair : param_sub_set) {
@@ -492,13 +494,13 @@ void SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
root->set_flag("auto_parallel", true);
auto cnode = node->cast<CNodePtr>();
auto vnode = cnode->input(1)->cast<ValueNodePtr>();
auto in_axes = cnode->input(2);
auto out_axes = cnode->input(3);
auto in_strategy = cnode->input(2);
auto out_strategy = cnode->input(3);
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
MS_EXCEPTION_IF_NULL(func_graph);
SetInputLayout(func_graph, in_axes, device_num);
SetOutputLayout(func_graph, out_axes, device_num);
SetInputLayout(func_graph, in_strategy, device_num);
SetOutputLayout(func_graph, out_strategy, device_num);
}
}
}
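Taken together, these checks constrain the strategy arguments to a tuple of per-tensor tuples (or None), with one entry per input/output, one integer per tensor dimension, and a device count that passes CheckDeviceNum. A rough illustration for a cell with two 2-D inputs and one output on 8 devices; the values are illustrative and mirror the tests at the end of this diff.

# Accepted by the checks above:
in_strategy = ((8, 1), None)   # two inputs: split dim 0 of the first 8 ways, data-parallel second
out_strategy = (None,)         # one output, data-parallel

# Rejected (each triggers one of the exceptions above):
#   ([8, 1], None)             -> "in_strategy should be a two-dimension tuple" (list, not tuple)
#   ((8, 1), None, (1, 8))     -> input number (2) != in_strategy number (3)
#   ((8, 1, 1), None)          -> input dimension (2) != in_strategy dimension (3) at index 0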

View File

@@ -81,6 +81,19 @@ class Cell(Cell_):
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
"""
class CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'."""
def __enter__(self):
"""Enter cell and increase recursion depth count."""
_pynative_executor.set_lazy_build(True)
_pynative_executor.enter_cell()
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit cell and decrease recursion depth count."""
_pynative_executor.exit_cell()
if _pynative_executor.is_top_cell():
_pynative_executor.set_lazy_build(False)
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
@@ -482,11 +495,11 @@ class Cell(Cell_):
for prim in all_prims:
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
def shard(self, in_axes, out_axes, device="Ascend", level=0):
def shard(self, in_strategy, out_strategy, device="Ascend", level=0):
"""
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
generated by sharding propagation. In_axes and out_axes define the input and output layout respectively.
In_axes/Out_axes should be a tuple each element of which corresponds to the desired layout of
generated by sharding propagation. in_strategy and out_strategy define the input and output layout respectively.
in_strategy/out_strategy should be a tuple each element of which corresponds to the desired layout of
this input/output and None represents data_parallel.
Note:
@@ -494,9 +507,9 @@ class Cell(Cell_):
search_mode in auto_parallel_context set as sharding_propagation.
Args:
in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
defines the layout of the corresponding input and None represents a data parallel strategy.
out_axes (tuple): Define the layout of outputs similar with in_axes.
out_strategy (tuple): Define the layout of outputs similar with in_strategy.
device (string): Select a certain device target. It is not in use right now.
Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
@@ -522,30 +535,17 @@ class Cell(Cell_):
... def __init__(self):
... self.block1 = Block()
... self.block2 = Block()
... self.block2.shard(in_axes=((2, 1),), out_axes=(None,))
... def construct(self, x):
... x = self.block1(x)
... x = self.block2(x)
... return x
... self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,))
... def construct(self, x):
... x = self.block1(x)
... x = self.block2(x)
... return x
"""
shard_fn = Shard()
fn = shard_fn(self, in_axes, out_axes, device, level)
fn = shard_fn(self, in_strategy, out_strategy, device, level)
object.__setattr__(self, "_shard_fn", fn)
return self
class CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'."""
def __enter__(self):
"""Enter cell and increase recursion depth count."""
_pynative_executor.set_lazy_build(True)
_pynative_executor.enter_cell()
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit cell and decrease recursion depth count."""
_pynative_executor.exit_cell()
if _pynative_executor.is_top_cell():
_pynative_executor.set_lazy_build(False)
def auto_cast_inputs(self, inputs):
"""Auto cast inputs in mixed precision scenarios."""
cast_inputs = inputs
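Because shard stores the compiled function on the cell and then returns self, the call can also be chained when a layer is constructed; a small sketch, where the Dense layer and its strategy are illustrative only.

import mindspore.nn as nn

# Hypothetical: configure a layer's layout inline; shard(...) returns the cell itself.
dense = nn.Dense(128, 128).shard(in_strategy=((1, 8),), out_strategy=(None,))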

View File

@@ -800,12 +800,12 @@ class Shard(Shard_):
Shard_.__init__(self, 'Shard')
self.shard_fn = None
self.fn = None
self.in_axes = None
self.out_axes = None
self.in_strategy = None
self.out_strategy = None
self.device = None
self.level = None
def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0):
def __call__(self, fn, in_strategy, out_strategy, device="Ascend", level=0):
if context.get_context("mode") != context.PYNATIVE_MODE or \
context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode")
@@ -815,30 +815,30 @@ class Shard(Shard_):
raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False")
if context.get_auto_parallel_context("search_mode") != "sharding_propagation":
raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
if not isinstance(in_axes, tuple):
raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
if not isinstance(out_axes, tuple):
raise TypeError(f"For 'Shard', the 'out_axes' should be a tuple, "
f"but got {type(out_axes).__name__}")
if not isinstance(in_strategy, tuple):
raise TypeError(f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
if not isinstance(out_strategy, tuple):
raise TypeError(f"For 'Shard', the 'out_strategy' should be a tuple, "
f"but got {type(out_strategy).__name__}")
if not isinstance(device, str):
raise TypeError(f"For 'Shard', the 'device' should be a string, "
f"but got {type(device).__name__}")
if not isinstance(level, int):
raise TypeError(f"For 'Shard', the 'level' should be an integer, "
f"but got {type(level).__name__}")
if self.shard_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes and \
self.device == device and self.level == level:
if self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
self.out_strategy == out_strategy and self.device == device and self.level == level:
return self.shard_fn
shard_ = Shard()
@ms_function(obj=fn)
def after_shard(*args):
return shard_(fn, in_axes, out_axes, device, level)(*args)
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
self.shard_fn = after_shard
self.fn = fn
self.in_axes = in_axes
self.out_axes = out_axes
self.in_strategy = in_strategy
self.out_strategy = out_strategy
self.device = device
self.level = level
return self.shard_fn
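The assertions above amount to a fixed context configuration that must be in place before Shard is invoked; a minimal sketch, assuming an 8-device Ascend setup (the device count is illustrative).

from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)  # 'Shard' only works in PyNative mode
context.set_auto_parallel_context(parallel_mode="auto_parallel",       # required parallel_mode
                                  search_mode="sharding_propagation",  # required search_mode
                                  full_batch=False,                    # 'full_batch' must be False
                                  device_num=8)                        # illustrative device count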

View File

@@ -582,8 +582,8 @@ def vjp(fn, inputs, v):
shard_fn = Shard()
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
return shard_fn(fn, in_axes, out_axes, device, level)
def shard(fn, in_strategy, out_strategy, device="Ascend", level=0):
return shard_fn(fn, in_strategy, out_strategy, device, level)
def arange(start=0, stop=None, step=1, rtype=None):
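The functional wrapper keeps the same signature as Cell.shard; a hypothetical call, where dense_cell and inputs are placeholders that are not defined in this diff.

from mindspore.ops import functional as F

# Wrap an existing callable with a parallel layout; the wrapped function is
# then called like the original (dense_cell and inputs are assumed to exist).
sharded_dense = F.shard(dense_cell, in_strategy=((1, 8),), out_strategy=(None,))
output = sharded_dense(inputs)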

View File

@@ -183,7 +183,7 @@ class ResNet(nn.Cell):
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer1.shard(in_axes=(None,), out_axes=(None,))
self.layer1.shard(in_strategy=(None,), out_strategy=(None,))
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
@@ -194,7 +194,7 @@ class ResNet(nn.Cell):
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer3.shard(in_axes=((8, 1, 1, 1),), out_axes=(None,))
self.layer3.shard(in_strategy=((8, 1, 1, 1),), out_strategy=(None,))
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
@@ -205,7 +205,7 @@ class ResNet(nn.Cell):
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
weight_init=weight_variable(),
bias_init=weight_variable()).add_flags_recursive(fp16=True)
self.head = F.shard(self.end_point, in_axes=((1, 8),), out_axes=(None,))
self.head = F.shard(self.end_point, in_strategy=((1, 8),), out_strategy=(None,))
self.squeeze = P.Squeeze()
self.cast = P.Cast()
@@ -376,7 +376,7 @@ def test_train_feed(num_classes=65536):
dataset = ds.GeneratorDataset(dataset, column_names=["image", "label"])
net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
loss.shard(in_axes=(None, None), out_axes=(None,))
loss.shard(in_strategy=(None, None), out_strategy=(None,))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)

View File

@@ -42,21 +42,21 @@ class NetMatMul(nn.Cell):
return self.matmul(x, y)
class Net(nn.Cell):
def __init__(self, in_axes, out_axes):
def __init__(self, in_strategy, out_strategy):
super().__init__()
self.mul_net = NetMul()
self.matmul_net = NetMatMul()
self.mul_net.shard(in_axes=in_axes, out_axes=out_axes)
self.mul_net.shard(in_strategy=in_strategy, out_strategy=out_strategy)
def construct(self, x, y):
out1 = self.matmul_net(x, y)
out2 = self.matmul_net(x, y)
return self.mul_net(out1, out2)
def cell_shard_execution(in_axes, out_axes, error_log):
net = Net(in_axes, out_axes)
def cell_shard_execution(in_strategy, out_strategy, error_log):
net = Net(in_strategy, out_strategy)
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
y = Tensor(np.ones([128, 128]), dtype=ms.float32)
@@ -65,63 +65,66 @@ def cell_shard_execution(in_axes, out_axes, error_log):
assert error_log in str(err.value)
def test_in_axes_numbers_check():
def test_in_strategy_numbers_check():
"""
Feature: shard function for cell
Description: inconsistent input number and in_axes number
Expectation: throw an exception indicating inconsistent input number and in_axes number
Description: inconsistent input number and in_strategy number
Expectation: throw an exception indicating inconsistent input number and in_strategy number
"""
set_context()
in_axes = ((8, 1), None, (1, 8))
out_axes = (None,)
error_log = "Input numbers: 2 is not equal to in_axes numbers: 3"
cell_shard_execution(in_axes, out_axes, error_log)
in_strategy = ((8, 1), None, (1, 8))
out_strategy = (None,)
error_log = "Input numbers: 2 is not equal to in_strategy numbers: 3"
cell_shard_execution(in_strategy, out_strategy, error_log)
def test_out_axes_numbers_check():
def test_out_strategy_numbers_check():
"""
Feature: shard function for cell
Description: inconsistent output number and out_axes number
Expectation: throw an exception indicating inconsistent output number and out_axes number
Description: inconsistent output number and out_strategy number
Expectation: throw an exception indicating inconsistent output number and out_strategy number
"""
set_context()
in_axes = ((8, 1), None)
out_axes = (None, (8, 1))
error_log = "Output number: 1 is not equal to out_axes number: 2"
cell_shard_execution(in_axes, out_axes, error_log)
in_strategy = ((8, 1), None)
out_strategy = (None, (8, 1))
error_log = "Output number: 1 is not equal to out_strategy number: 2"
cell_shard_execution(in_strategy, out_strategy, error_log)
def test_in_axes_dimension_check():
def test_in_strategy_dimension_check():
"""
Feature: shard function for cell
Description: inconsistent input dimension and in_axes dimension
Expectation: throw an exception indicating inconsistent input_dimension and in_axes dimension
Description: inconsistent input dimension and in_strategy dimension
Expectation: throw an exception indicating inconsistent input_dimension and in_strategy dimension
"""
set_context()
in_axes = ((8, 1, 1), None)
out_axes = (None, (8, 1))
error_log = "Input dimension: 2 is not equal to in_axes dimension: 3 at index 0"
cell_shard_execution(in_axes, out_axes, error_log)
in_strategy = ((8, 1, 1), None)
out_strategy = (None, (8, 1))
error_log = "Input dimension: 2 is not equal to in_strategy dimension: 3 at index 0"
cell_shard_execution(in_strategy, out_strategy, error_log)
def test_out_axes_dimension_check():
def test_out_strategy_dimension_check():
"""
Feature: shard function for cell
Description: inconsistent output dimension and out_axes dimension
Expectation: throw an exception indicating inconsistent output_dimension and out_axes dimension
Description: inconsistent output dimension and out_strategy dimension
Expectation: throw an exception indicating inconsistent output_dimension and out_strategy dimension
"""
set_context()
in_axes = ((8, 1), None)
out_axes = ((8,),)
error_log = "Output dimension: 2 is not equal to out_axes dimension: 1 at index 0"
cell_shard_execution(in_axes, out_axes, error_log)
in_strategy = ((8, 1), None)
out_strategy = ((8,),)
error_log = "Output dimension: 2 is not equal to out_strategy dimension: 1 at index 0"
cell_shard_execution(in_strategy, out_strategy, error_log)
def test_in_axes_format_check():
def test_in_strategy_format_check():
"""
Feature: shard function for cell
Description: unsupported in_axes format
Expectation: throw an exception indicating an unsupported in_axes format
Description: unsupported in_strategy format
Expectation: throw an exception indicating an unsupported in_strategy format
"""
set_context()
in_axes = ([8, 1], None)
out_axes = (None,)
error_log = "in_axes should be a two-dimension tuple"
cell_shard_execution(in_axes, out_axes, error_log)
in_strategy = ([8, 1], None)
out_strategy = (None,)
error_log = "in_strategy should be a two-dimension tuple"
cell_shard_execution(in_strategy, out_strategy, error_log)