forked from mindspore-Ecosystem/mindspore
modify interface name for shard
This commit is contained in:
parent
aba48e688b
commit
789539cbaa
|
@ -366,26 +366,26 @@ bool CheckDeviceNum(const std::vector<std::vector<int64_t>> &strategies, const i
|
|||
return true;
|
||||
}
|
||||
|
||||
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes, const int64_t &device_num) {
|
||||
auto out_axes_tuple = out_axes->cast<ValueNodePtr>();
|
||||
void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_strategy, const int64_t &device_num) {
|
||||
auto out_strategy_tuple = out_strategy->cast<ValueNodePtr>();
|
||||
bool need_default_strategy = false;
|
||||
size_t out_axes_size = 0;
|
||||
if (!IsValueNode<ValueTuple>(out_axes_tuple) ||
|
||||
!CheckLayout(out_axes_tuple, &need_default_strategy, &out_axes_size)) {
|
||||
MS_LOG(EXCEPTION) << "out_axes should be a two-dimension tuple";
|
||||
size_t out_strategy_size = 0;
|
||||
if (!IsValueNode<ValueTuple>(out_strategy_tuple) ||
|
||||
!CheckLayout(out_strategy_tuple, &need_default_strategy, &out_strategy_size)) {
|
||||
MS_LOG(EXCEPTION) << "out_strategy should be a two-dimension tuple";
|
||||
}
|
||||
std::vector<AnfNodePtr> output_nodes;
|
||||
GetOutputNodes(func_graph, &output_nodes);
|
||||
if (output_nodes.size() != out_axes_size) {
|
||||
if (output_nodes.size() != out_strategy_size) {
|
||||
MS_LOG(EXCEPTION) << "Output number: " << output_nodes.size()
|
||||
<< " is not equal to out_axes number: " << out_axes_size;
|
||||
<< " is not equal to out_strategy number: " << out_strategy_size;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> output_strategy;
|
||||
if (need_default_strategy) {
|
||||
GenerateDefaultStrategy(out_axes_tuple, output_nodes, device_num, &output_strategy);
|
||||
GenerateDefaultStrategy(out_strategy_tuple, output_nodes, device_num, &output_strategy);
|
||||
} else {
|
||||
output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_axes_tuple->value());
|
||||
output_strategy = GetValue<std::vector<std::vector<int64_t>>>(out_strategy_tuple->value());
|
||||
}
|
||||
MS_LOG(WARNING) << "The output strategy will be overwritten as data-parallel";
|
||||
|
||||
|
@ -394,7 +394,8 @@ void SetOutputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &out_axes,
|
|||
auto output_shape = common::AnfAlgo::GetOutputInferShape(node, 0);
|
||||
if (output_shape.size() != output_strategy[i].size()) {
|
||||
MS_LOG(EXCEPTION) << "Output dimension: " << output_shape.size()
|
||||
<< " is not equal to out_axes dimension: " << output_strategy[i].size() << " at index " << i;
|
||||
<< " is not equal to out_strategy dimension: " << output_strategy[i].size() << " at index "
|
||||
<< i;
|
||||
}
|
||||
std::vector<ValuePtr> elements;
|
||||
elements.push_back(MakeValue(output_strategy[i]));
|
||||
|
@ -430,24 +431,25 @@ std::vector<ValuePtr> GetStrategyElements(const CNodePtr &cnode, const std::vect
|
|||
return elements;
|
||||
}
|
||||
|
||||
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, const int64_t &device_num) {
|
||||
auto in_axes_tuple = in_axes->cast<ValueNodePtr>();
|
||||
void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_strategy, const int64_t &device_num) {
|
||||
auto in_strategy_tuple = in_strategy->cast<ValueNodePtr>();
|
||||
bool need_default_strategy = false;
|
||||
size_t in_axes_size = 0;
|
||||
if (!IsValueNode<ValueTuple>(in_axes_tuple) || !CheckLayout(in_axes_tuple, &need_default_strategy, &in_axes_size)) {
|
||||
MS_LOG(EXCEPTION) << "in_axes should be a two-dimension tuple";
|
||||
size_t in_strategy_size = 0;
|
||||
if (!IsValueNode<ValueTuple>(in_strategy_tuple) ||
|
||||
!CheckLayout(in_strategy_tuple, &need_default_strategy, &in_strategy_size)) {
|
||||
MS_LOG(EXCEPTION) << "in_strategy should be a two-dimension tuple";
|
||||
}
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
GetInputNodes(func_graph, &input_nodes);
|
||||
if (input_nodes.size() != in_axes_size) {
|
||||
if (input_nodes.size() != in_strategy_size) {
|
||||
MS_LOG(EXCEPTION) << "Input numbers: " << input_nodes.size()
|
||||
<< " is not equal to in_axes numbers: " << in_axes_size;
|
||||
<< " is not equal to in_strategy numbers: " << in_strategy_size;
|
||||
}
|
||||
std::vector<std::vector<int64_t>> input_strategy;
|
||||
if (need_default_strategy) {
|
||||
GenerateDefaultStrategy(in_axes_tuple, input_nodes, device_num, &input_strategy);
|
||||
GenerateDefaultStrategy(in_strategy_tuple, input_nodes, device_num, &input_strategy);
|
||||
} else {
|
||||
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_axes_tuple->value());
|
||||
input_strategy = GetValue<std::vector<std::vector<int64_t>>>(in_strategy_tuple->value());
|
||||
}
|
||||
if (!CheckDeviceNum(input_strategy, device_num)) {
|
||||
MS_LOG(EXCEPTION) << "check device number failed";
|
||||
|
@ -463,7 +465,7 @@ void SetInputLayout(const FuncGraphPtr &func_graph, const AnfNodePtr &in_axes, c
|
|||
auto output_shape = common::AnfAlgo::GetOutputInferShape(parameter, 0);
|
||||
if (output_shape.size() != input_strategy[i].size()) {
|
||||
MS_LOG(EXCEPTION) << "Input dimension: " << output_shape.size()
|
||||
<< " is not equal to in_axes dimension: " << input_strategy[i].size() << " at index " << i;
|
||||
<< " is not equal to in_strategy dimension: " << input_strategy[i].size() << " at index " << i;
|
||||
}
|
||||
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
|
||||
for (auto ¶m_pair : param_sub_set) {
|
||||
|
@ -492,13 +494,13 @@ void SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
|
|||
root->set_flag("auto_parallel", true);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto vnode = cnode->input(1)->cast<ValueNodePtr>();
|
||||
auto in_axes = cnode->input(2);
|
||||
auto out_axes = cnode->input(3);
|
||||
auto in_strategy = cnode->input(2);
|
||||
auto out_strategy = cnode->input(3);
|
||||
ScopeGuard scope_guard(vnode->scope());
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
SetInputLayout(func_graph, in_axes, device_num);
|
||||
SetOutputLayout(func_graph, out_axes, device_num);
|
||||
SetInputLayout(func_graph, in_strategy, device_num);
|
||||
SetOutputLayout(func_graph, out_strategy, device_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,6 +81,19 @@ class Cell(Cell_):
|
|||
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
|
||||
"""
|
||||
|
||||
class CellGuard:
|
||||
"""Detecting whether the cell is a top-level cell with the 'with statement'."""
|
||||
def __enter__(self):
|
||||
"""Enter cell and increase recursion depth count."""
|
||||
_pynative_executor.set_lazy_build(True)
|
||||
_pynative_executor.enter_cell()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit cell and decrease recursion depth count."""
|
||||
_pynative_executor.exit_cell()
|
||||
if _pynative_executor.is_top_cell():
|
||||
_pynative_executor.set_lazy_build(False)
|
||||
|
||||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
|
||||
|
@ -482,11 +495,11 @@ class Cell(Cell_):
|
|||
for prim in all_prims:
|
||||
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
|
||||
|
||||
def shard(self, in_axes, out_axes, device="Ascend", level=0):
|
||||
def shard(self, in_strategy, out_strategy, device="Ascend", level=0):
|
||||
"""
|
||||
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
|
||||
generated by sharding propagation. In_axes and out_axes define the input and output layout respectively.
|
||||
In_axes/Out_axes should be a tuple each element of which corresponds to the desired layout of
|
||||
generated by sharding propagation. in_strategy and out_strategy define the input and output layout respectively.
|
||||
in_strategy/out_strategy should be a tuple each element of which corresponds to the desired layout of
|
||||
this input/output and None represents data_parallel.
|
||||
|
||||
Note:
|
||||
|
@ -494,9 +507,9 @@ class Cell(Cell_):
|
|||
search_mode in auto_parallel_context set as sharding_propagation.
|
||||
|
||||
Args:
|
||||
in_axes (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
|
||||
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
|
||||
defines the layout of the corresponding input and None represents a data parallel strategy.
|
||||
out_axes (tuple): Define the layout of outputs similar with in_axes.
|
||||
out_strategy (tuple): Define the layout of outputs similar with in_strategy.
|
||||
device (string): Select a certain device target. It is not in use right now.
|
||||
Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
|
||||
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
|
||||
|
@ -522,30 +535,17 @@ class Cell(Cell_):
|
|||
... def __init__(self):
|
||||
... self.block1 = Block()
|
||||
... self.block2 = Block()
|
||||
... self.block2.shard(in_axes=((2, 1),), out_axes=(None,))
|
||||
... def construct(self, x):
|
||||
... x = self.block1(x)
|
||||
... x = self.block2(x)
|
||||
... return x
|
||||
... self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,))
|
||||
... def construct(self, x):
|
||||
... x = self.block1(x)
|
||||
... x = self.block2(x)
|
||||
... return x
|
||||
"""
|
||||
shard_fn = Shard()
|
||||
fn = shard_fn(self, in_axes, out_axes, device, level)
|
||||
fn = shard_fn(self, in_strategy, out_strategy, device, level)
|
||||
object.__setattr__(self, "_shard_fn", fn)
|
||||
return self
|
||||
|
||||
class CellGuard:
|
||||
"""Detecting whether the cell is a top-level cell with the 'with statement'."""
|
||||
def __enter__(self):
|
||||
"""Enter cell and increase recursion depth count."""
|
||||
_pynative_executor.set_lazy_build(True)
|
||||
_pynative_executor.enter_cell()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit cell and decrease recursion depth count."""
|
||||
_pynative_executor.exit_cell()
|
||||
if _pynative_executor.is_top_cell():
|
||||
_pynative_executor.set_lazy_build(False)
|
||||
|
||||
def auto_cast_inputs(self, inputs):
|
||||
"""Auto cast inputs in mixed precision scenarios."""
|
||||
cast_inputs = inputs
|
||||
|
|
|
@ -800,12 +800,12 @@ class Shard(Shard_):
|
|||
Shard_.__init__(self, 'Shard')
|
||||
self.shard_fn = None
|
||||
self.fn = None
|
||||
self.in_axes = None
|
||||
self.out_axes = None
|
||||
self.in_strategy = None
|
||||
self.out_strategy = None
|
||||
self.device = None
|
||||
self.level = None
|
||||
|
||||
def __call__(self, fn, in_axes, out_axes, device="Ascend", level=0):
|
||||
def __call__(self, fn, in_strategy, out_strategy, device="Ascend", level=0):
|
||||
if context.get_context("mode") != context.PYNATIVE_MODE or \
|
||||
context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
|
||||
raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode")
|
||||
|
@ -815,30 +815,30 @@ class Shard(Shard_):
|
|||
raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False")
|
||||
if context.get_auto_parallel_context("search_mode") != "sharding_propagation":
|
||||
raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
|
||||
if not isinstance(in_axes, tuple):
|
||||
raise TypeError(f"For 'Shard', the 'in_axes' should be a tuple, but got {type(in_axes).__name__}")
|
||||
if not isinstance(out_axes, tuple):
|
||||
raise TypeError(f"For 'Shard', the 'out_axes' should be a tuple, "
|
||||
f"but got {type(out_axes).__name__}")
|
||||
if not isinstance(in_strategy, tuple):
|
||||
raise TypeError(f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
|
||||
if not isinstance(out_strategy, tuple):
|
||||
raise TypeError(f"For 'Shard', the 'out_strategy' should be a tuple, "
|
||||
f"but got {type(out_strategy).__name__}")
|
||||
if not isinstance(device, str):
|
||||
raise TypeError(f"For 'Shard', the 'device' should be a string, "
|
||||
f"but got {type(device).__name__}")
|
||||
if not isinstance(level, int):
|
||||
raise TypeError(f"For 'Shard', the 'level' should be an integer, "
|
||||
f"but got {type(level).__name__}")
|
||||
if self.shard_fn is not None and self.fn == fn and self.in_axes == in_axes and self.out_axes == out_axes and \
|
||||
self.device == device and self.level == level:
|
||||
if self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
|
||||
self.out_strategy == out_strategy and self.device == device and self.level == level:
|
||||
return self.shard_fn
|
||||
shard_ = Shard()
|
||||
|
||||
@ms_function(obj=fn)
|
||||
def after_shard(*args):
|
||||
return shard_(fn, in_axes, out_axes, device, level)(*args)
|
||||
return shard_(fn, in_strategy, out_strategy, device, level)(*args)
|
||||
|
||||
self.shard_fn = after_shard
|
||||
self.fn = fn
|
||||
self.in_axes = in_axes
|
||||
self.out_axes = out_axes
|
||||
self.in_strategy = in_strategy
|
||||
self.out_strategy = out_strategy
|
||||
self.device = device
|
||||
self.level = level
|
||||
return self.shard_fn
|
||||
|
|
|
@ -582,8 +582,8 @@ def vjp(fn, inputs, v):
|
|||
shard_fn = Shard()
|
||||
|
||||
|
||||
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
|
||||
return shard_fn(fn, in_axes, out_axes, device, level)
|
||||
def shard(fn, in_strategy, out_strategy, device="Ascend", level=0):
|
||||
return shard_fn(fn, in_strategy, out_strategy, device, level)
|
||||
|
||||
|
||||
def arange(start=0, stop=None, step=1, rtype=None):
|
||||
|
|
|
@ -183,7 +183,7 @@ class ResNet(nn.Cell):
|
|||
in_channel=in_channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=strides[0])
|
||||
self.layer1.shard(in_axes=(None,), out_axes=(None,))
|
||||
self.layer1.shard(in_strategy=(None,), out_strategy=(None,))
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
|
@ -194,7 +194,7 @@ class ResNet(nn.Cell):
|
|||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2])
|
||||
self.layer3.shard(in_axes=((8, 1, 1, 1),), out_axes=(None,))
|
||||
self.layer3.shard(in_strategy=((8, 1, 1, 1),), out_strategy=(None,))
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
|
@ -205,7 +205,7 @@ class ResNet(nn.Cell):
|
|||
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
|
||||
weight_init=weight_variable(),
|
||||
bias_init=weight_variable()).add_flags_recursive(fp16=True)
|
||||
self.head = F.shard(self.end_point, in_axes=((1, 8),), out_axes=(None,))
|
||||
self.head = F.shard(self.end_point, in_strategy=((1, 8),), out_strategy=(None,))
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
|
||||
|
@ -376,7 +376,7 @@ def test_train_feed(num_classes=65536):
|
|||
dataset = ds.GeneratorDataset(dataset, column_names=["image", "label"])
|
||||
net = resnet50(num_classes)
|
||||
loss = SoftmaxCrossEntropyExpand(sparse=True)
|
||||
loss.shard(in_axes=(None, None), out_axes=(None,))
|
||||
loss.shard(in_strategy=(None, None), out_strategy=(None,))
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt)
|
||||
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
|
||||
|
|
|
@ -42,21 +42,21 @@ class NetMatMul(nn.Cell):
|
|||
return self.matmul(x, y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, in_axes, out_axes):
|
||||
def __init__(self, in_strategy, out_strategy):
|
||||
super().__init__()
|
||||
self.mul_net = NetMul()
|
||||
self.matmul_net = NetMatMul()
|
||||
self.mul_net.shard(in_axes=in_axes, out_axes=out_axes)
|
||||
self.mul_net.shard(in_strategy=in_strategy, out_strategy=out_strategy)
|
||||
|
||||
def construct(self, x, y):
|
||||
out1 = self.matmul_net(x, y)
|
||||
out2 = self.matmul_net(x, y)
|
||||
return self.mul_net(out1, out2)
|
||||
|
||||
def cell_shard_execution(in_axes, out_axes, error_log):
|
||||
|
||||
net = Net(in_axes, out_axes)
|
||||
def cell_shard_execution(in_strategy, out_strategy, error_log):
|
||||
|
||||
net = Net(in_strategy, out_strategy)
|
||||
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([128, 128]), dtype=ms.float32)
|
||||
|
||||
|
@ -65,63 +65,66 @@ def cell_shard_execution(in_axes, out_axes, error_log):
|
|||
assert error_log in str(err.value)
|
||||
|
||||
|
||||
def test_in_axes_numbers_check():
|
||||
def test_in_strategy_numbers_check():
|
||||
"""
|
||||
Feature: shard function for cell
|
||||
Description: inconsistent input number and in_axes number
|
||||
Expectation: throw an exception indicating inconsistent input number and in_axes number
|
||||
Description: inconsistent input number and in_strategy number
|
||||
Expectation: throw an exception indicating inconsistent input number and in_strategy number
|
||||
"""
|
||||
set_context()
|
||||
in_axes = ((8, 1), None, (1, 8))
|
||||
out_axes = (None,)
|
||||
error_log = "Input numbers: 2 is not equal to in_axes numbers: 3"
|
||||
cell_shard_execution(in_axes, out_axes, error_log)
|
||||
in_strategy = ((8, 1), None, (1, 8))
|
||||
out_strategy = (None,)
|
||||
error_log = "Input numbers: 2 is not equal to in_strategy numbers: 3"
|
||||
cell_shard_execution(in_strategy, out_strategy, error_log)
|
||||
|
||||
|
||||
def test_out_axes_numbers_check():
|
||||
def test_out_strategy_numbers_check():
|
||||
"""
|
||||
Feature: shard function for cell
|
||||
Description: inconsistent output number and out_axes number
|
||||
Expectation: throw an exception indicating inconsistent output number and out_axes number
|
||||
Description: inconsistent output number and out_strategy number
|
||||
Expectation: throw an exception indicating inconsistent output number and out_strategy number
|
||||
"""
|
||||
set_context()
|
||||
in_axes = ((8, 1), None)
|
||||
out_axes = (None, (8, 1))
|
||||
error_log = "Output number: 1 is not equal to out_axes number: 2"
|
||||
cell_shard_execution(in_axes, out_axes, error_log)
|
||||
in_strategy = ((8, 1), None)
|
||||
out_strategy = (None, (8, 1))
|
||||
error_log = "Output number: 1 is not equal to out_strategy number: 2"
|
||||
cell_shard_execution(in_strategy, out_strategy, error_log)
|
||||
|
||||
def test_in_axes_dimension_check():
|
||||
|
||||
def test_in_strategy_dimension_check():
|
||||
"""
|
||||
Feature: shard function for cell
|
||||
Description: inconsistent input dimension and in_axes dimension
|
||||
Expectation: throw an exception indicating inconsistent input_dimension and in_axes dimension
|
||||
Description: inconsistent input dimension and in_strategy dimension
|
||||
Expectation: throw an exception indicating inconsistent input_dimension and in_strategy dimension
|
||||
"""
|
||||
set_context()
|
||||
in_axes = ((8, 1, 1), None)
|
||||
out_axes = (None, (8, 1))
|
||||
error_log = "Input dimension: 2 is not equal to in_axes dimension: 3 at index 0"
|
||||
cell_shard_execution(in_axes, out_axes, error_log)
|
||||
in_strategy = ((8, 1, 1), None)
|
||||
out_strategy = (None, (8, 1))
|
||||
error_log = "Input dimension: 2 is not equal to in_strategy dimension: 3 at index 0"
|
||||
cell_shard_execution(in_strategy, out_strategy, error_log)
|
||||
|
||||
def test_out_axes_dimension_check():
|
||||
|
||||
def test_out_strategy_dimension_check():
|
||||
"""
|
||||
Feature: shard function for cell
|
||||
Description: inconsistent output dimension and out_axes dimension
|
||||
Expectation: throw an exception indicating inconsistent output_dimension and out_axes dimension
|
||||
Description: inconsistent output dimension and out_strategy dimension
|
||||
Expectation: throw an exception indicating inconsistent output_dimension and out_strategy dimension
|
||||
"""
|
||||
set_context()
|
||||
in_axes = ((8, 1), None)
|
||||
out_axes = ((8,),)
|
||||
error_log = "Output dimension: 2 is not equal to out_axes dimension: 1 at index 0"
|
||||
cell_shard_execution(in_axes, out_axes, error_log)
|
||||
in_strategy = ((8, 1), None)
|
||||
out_strategy = ((8,),)
|
||||
error_log = "Output dimension: 2 is not equal to out_strategy dimension: 1 at index 0"
|
||||
cell_shard_execution(in_strategy, out_strategy, error_log)
|
||||
|
||||
def test_in_axes_format_check():
|
||||
|
||||
def test_in_strategy_format_check():
|
||||
"""
|
||||
Feature: shard function for cell
|
||||
Description: unsupported in_axes format
|
||||
Expectation: throw an exception indicating an supported in_axes format
|
||||
Description: unsupported in_strategy format
|
||||
Expectation: throw an exception indicating an supported in_strategy format
|
||||
"""
|
||||
set_context()
|
||||
in_axes = ([8, 1], None)
|
||||
out_axes = (None,)
|
||||
error_log = "in_axes should be a two-dimension tuple"
|
||||
cell_shard_execution(in_axes, out_axes, error_log)
|
||||
in_strategy = ([8, 1], None)
|
||||
out_strategy = (None,)
|
||||
error_log = "in_strategy should be a two-dimension tuple"
|
||||
cell_shard_execution(in_strategy, out_strategy, error_log)
|
||||
|
|
Loading…
Reference in New Issue