forked from mindspore-Ecosystem/mindspore

add parallel mode for cell

parent 7c64048d76
commit 8c9730b3c5
@@ -35,8 +35,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) ||
-      (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
+  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+      (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
     return changes;
   }
 #if defined(_WIN32) || defined(_WIN64)

@@ -121,7 +121,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if ((parallel_mode != AUTO_PARALLEL) || root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY]) {
+  if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
+      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
     return changes;
   }
   // check whether strategy_search_mode is valid

@@ -2220,7 +2220,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
       (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
     return changes;
   }

@@ -281,7 +281,7 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   info_[phase_s]->func_graph = func_graph;
-  if ((func_graph != nullptr) &&
+  if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
       ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
     MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
     func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();
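All four guards above now additionally require the func graph to carry the AUTO_PARALLEL flag, so the parallel passes run only for cells that opted in explicitly. On the Python side the flag comes from the new Cell.set_auto_parallel() introduced later in this commit. A minimal sketch of the opt-in, assuming a MindSpore build of this vintage (the Dense cell and shapes are illustrative, not from the diff):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.common.api import _executor

    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    net = nn.Dense(32, 64)
    # set_auto_parallel() attaches the auto_parallel flag to the cell; the
    # compiled func graph inherits it, which is what the C++ checks
    # root->has_flag(AUTO_PARALLEL) above test for.
    net.set_auto_parallel()
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    _executor.compile(net, x)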
@@ -20,7 +20,6 @@ from collections import OrderedDict
 from functools import wraps
 from mindspore import context
 from mindspore import log as logger
-from mindspore.parallel._utils import _get_parallel_mode
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor

@@ -327,7 +326,7 @@ class _Executor:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))

-    def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
+    def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
         """
         Compiles graph.

@@ -337,6 +336,7 @@ class _Executor:
             phase (str): The name of compile phase. Default: 'predict'.
             params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
             do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
+            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

         Return:
             Str, the full phase of the cell.

@@ -370,8 +370,9 @@ class _Executor:
             logger.error("%r graph compile failed.", phase)
         if not do_convert:
             return phase, True

         if not enable_debug_runtime or enable_ge:
-            if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
+            if auto_parallel_mode:
                 obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
                 obj.load_parameter_slice(params)
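With the executor no longer consulting the global parallel mode, callers state the mode explicitly; Cell.compile forwards its own _auto_parallel_mode attribute. A hedged sketch of the new entry point, reusing net and x from the sketch above:

    # _Executor.compile returns the phase string plus a bool that Cell uses
    # as its compile_flag (see "return phase, True" in the hunk above).
    phase, compile_flag = _executor.compile(net, x, phase='train',
                                            auto_parallel_mode=net._auto_parallel_mode)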
@@ -25,7 +25,6 @@ from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
 from ..parallel._tensor import _load_tensor_by_layout
-from ..parallel._utils import _get_parallel_mode
 from ..common.tensor import Tensor


@@ -71,8 +70,7 @@ class Cell:
         gc.collect()
         self._construct_inputs_num = 0
         self._construct_inputs_names = []
-        if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
-            self._get_construct_inputs_number_and_name()
+        self._auto_parallel_mode = False
         self._parallel_inputs_run = None
         if flags:
             self.add_flags(**flags)

@@ -298,9 +296,10 @@ class Cell:
         Returns:
             Object, the result of executing.
         """
-        _, compile_flag = _executor.compile(self, *inputs, phase=self.phase)
+        _, compile_flag = _executor.compile(self, *inputs, phase=self.phase,
+                                            auto_parallel_mode=self._auto_parallel_mode)

-        if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
+        if self._auto_parallel_mode:
             if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag):
                 parallel_inputs_run = self._parallel_inputs_run
             else:

@@ -665,3 +664,15 @@ class Cell:
         """
         self.add_flags_recursive(broadcast_flag=mode)
         return self
+
+    def set_auto_parallel(self):
+        """
+        Set the cell to auto parallel mode.
+
+        Note:
+            If a cell needs to use auto parallel or semi auto parallel mode for training, evaluation or prediction,
+            this interface needs to be called for the cell.
+        """
+        self._auto_parallel_mode = True
+        self.add_flags(auto_parallel=True)
+        self._get_construct_inputs_number_and_name()
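The new method does three things, all visible in its body above: it flips the Python-side attribute that Cell.compile forwards, attaches the graph flag the C++ passes check, and caches construct() input metadata for parallel input handling. A small check, continuing the snippet from the C++ section above:

    net = nn.Dense(32, 64)
    net.set_auto_parallel()
    # 1. drives the auto_parallel_mode kwarg in Cell.compile;
    # 2. add_flags(auto_parallel=True) is what root->has_flag(AUTO_PARALLEL)
    #    sees after compilation;
    # 3. _get_construct_inputs_number_and_name() precomputes input names/count.
    assert net._auto_parallel_mode is True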
@@ -16,8 +16,7 @@

 from mindspore._c_expression import reset_op_id
 from mindspore.communication.management import get_group_size, get_rank
-from mindspore.parallel._auto_parallel_context import auto_parallel_context, _set_auto_parallel_context,\
-    _reset_auto_parallel_context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context


 def _get_parallel_mode():

@@ -108,102 +107,6 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
                          .format(parallel_mode, parameter_broadcast))


-_parallel_mode = None
-_device_num = None
-_global_rank = None
-_parameter_broadcast = None
-_mirror_mean = None
-_cast_before_mirror = None
-_loss_repeated_mean = None
-_communication_backend = None
-_has_checkpointed = False
-_enable_all_reduce_fusion = None
-
-
-def _checkpoint_auto_parallel_context():
-    """checkpoint auto parallel context"""
-    global _has_checkpointed
-    if _has_checkpointed is True:
-        return
-
-    global _parallel_mode
-    global _device_num
-    global _global_rank
-    global _parameter_broadcast
-    global _mirror_mean
-    global _cast_before_mirror
-    global _loss_repeated_mean
-    global _communication_backend
-    global _enable_all_reduce_fusion
-    _parallel_mode = auto_parallel_context().get_parallel_mode()
-    _device_num = _get_device_num()
-    _global_rank = _get_global_rank()
-    _parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
-    _mirror_mean = auto_parallel_context().get_mirror_mean()
-    _cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    _loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
-    _communication_backend = auto_parallel_context().get_communication_backend()
-    _enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()
-    _has_checkpointed = True
-
-
-def _restore_auto_parallel_context():
-    """restore auto parallel context"""
-    global _parallel_mode
-    global _device_num
-    global _global_rank
-    global _parameter_broadcast
-    global _mirror_mean
-    global _cast_before_mirror
-    global _loss_repeated_mean
-    global _communication_backend
-    global _enable_all_reduce_fusion
-    _set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
-                               parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
-                               cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
-    auto_parallel_context().set_communication_backend(_communication_backend)
-    auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)
-
-
-def _reset_checkpoint_auto_parallel_context():
-    """reset the _has_checkpointed"""
-    global _has_checkpointed
-    _has_checkpointed = False
-
-
-def _callback_wrapper(list_callback, run_context, callback_type):
-    """
-    reset the context for callback of model train
-
-    Raises:
-        ValueError: If the type keyword is not recognized
-    """
-    _callback_func_map = {
-        "begin": list_callback.begin,
-        "epoch_begin": list_callback.epoch_begin,
-        "step_begin": list_callback.step_begin,
-        "step_end": list_callback.step_end,
-        "epoch_end": list_callback.epoch_end,
-        "end": list_callback.end}
-
-    if callback_type not in _callback_func_map:
-        raise ValueError("Get type keyword %s is not recognized!" % callback_type)
-    func = _callback_func_map[callback_type]
-
-    if callback_type == "begin":
-        _reset_checkpoint_auto_parallel_context()
-
-    _checkpoint_auto_parallel_context()
-    global _parallel_mode
-    if _parallel_mode == "stand_alone":
-        func(run_context)
-        return
-
-    _reset_auto_parallel_context()
-    func(run_context)
-    _restore_auto_parallel_context()
-
-
 PARAMETER_CLONED_INDEX = 0
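The wrapper deleted here used to snapshot the global auto-parallel context, reset it for the user callback, then restore it; the Model hunks below now invoke the callback methods directly. A runnable toy contrast of the two dispatch styles (the stub stands in for the real callback list; this is an illustration, not MindSpore API):

    class _StubCallbackList:
        def epoch_end(self, run_context):
            print("epoch_end fired for", run_context)

    list_callback, run_context = _StubCallbackList(), "ctx"

    # Old path (removed): _callback_wrapper(list_callback, run_context,
    # "epoch_end") looked the method up in a string-keyed map and wrapped the
    # call in checkpoint/reset/restore of the auto-parallel context.
    # New path: call the method directly.
    list_callback.epoch_end(run_context)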
@@ -22,7 +22,7 @@ from .._checkparam import check_input_data, check_output_data, check_int_positiv
 from .callback import _InternalCallbackParam, RunContext, _build_callbacks
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
-    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
+    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell

@@ -144,6 +144,9 @@ class Model:
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
+
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            network.set_auto_parallel()
         return network

     def _build_eval_network(self, metrics, eval_network, eval_indexes):

@@ -165,11 +168,15 @@ class Model:
             self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
             self._eval_indexes = [0, 1, 2]

+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            self._eval_network.set_auto_parallel()
+
     def _build_predict_network(self):
         """Build the network for prediction."""
         self._predict_network = self._network
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             self._predict_network = _VirtualDatasetCell(self._network)
+            self._predict_network.set_auto_parallel()

     def _clear_metrics(self):
         """Clear metrics local values."""
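With these hooks, a Model user never calls set_auto_parallel() by hand; constructing the Model under an auto-parallel or semi-auto-parallel context is enough. A hedged end-to-end sketch (MyNet and train_dataset are hypothetical placeholders; import paths follow this era's layout):

    import mindspore.nn as nn
    from mindspore import context
    from mindspore.train.model import Model
    from mindspore.nn.optim import Momentum

    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    net = MyNet()                      # hypothetical user-defined Cell
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    # _build_train_network() runs inside Model and, per the hunk above,
    # calls set_auto_parallel() on the wrapped train network itself.
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(1, train_dataset)      # train_dataset: hypothetical Dataset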
@@ -287,28 +294,28 @@ class Model:
         cb_params.cur_step_num = 0
         loop_size = dataset_helper.loop_size()
         run_context = RunContext(cb_params)
-        _callback_wrapper(list_callback, run_context, "begin")
+        list_callback.begin(run_context)

         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
-            _callback_wrapper(list_callback, run_context, "epoch_begin")
+            list_callback.epoch_begin(run_context)

             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
                 cb_params.cur_step_num += loop_size
-                _callback_wrapper(list_callback, run_context, "step_begin")
+                list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
                 cb_params.net_outputs = outputs
-                _callback_wrapper(list_callback, run_context, "step_end")
+                list_callback.step_end(run_context)

-            _callback_wrapper(list_callback, run_context, "epoch_end")
+            list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
                 break

-        _callback_wrapper(list_callback, run_context, "end")
+        list_callback.end(run_context)

     def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
         """

@@ -327,14 +334,14 @@ class Model:
         dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
         cb_params.cur_step_num = 0
         run_context = RunContext(cb_params)
-        _callback_wrapper(list_callback, run_context, "begin")
+        list_callback.begin(run_context)
         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False

         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1

-            _callback_wrapper(list_callback, run_context, "epoch_begin")
+            list_callback.epoch_begin(run_context)

             for next_element in dataset_helper:
                 len_element = len(next_element)

@@ -342,7 +349,7 @@ class Model:
                     raise ValueError("when loss_fn is not None, train_dataset should"
                                      "return two elements, but got {}".format(len_element))
                 cb_params.cur_step_num += 1
-                _callback_wrapper(list_callback, run_context, "step_begin")
+                list_callback.step_begin(run_context)

                 overflow = False
                 if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():

@@ -356,19 +363,19 @@ class Model:
                     overflow = np.all(overflow.asnumpy())
                     self._loss_scale_manager.update_loss_scale(overflow)

-                _callback_wrapper(list_callback, run_context, "step_end")
+                list_callback.step_end(run_context)
                 should_stop = should_stop or run_context.get_stop_requested()
                 if should_stop:
                     break

             train_dataset.reset()

-            _callback_wrapper(list_callback, run_context, "epoch_end")
+            list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
                 break

-        _callback_wrapper(list_callback, run_context, "end")
+        list_callback.end(run_context)

     def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
         """
@@ -92,6 +92,7 @@ class AddReluFactory:
     def forward_mindspore_parallel_impl(self):
         net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np1)
         y = Tensor(self.input_np2, ms.float32)
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])

@@ -118,6 +119,7 @@ class AddReluFactory:
         net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
         x1 = Tensor(inputs_x[self.x_id])

@@ -249,6 +249,7 @@ class Conv2dFactory:
                      padding=self.padding, dilation=self.dilation,
                      group=self.group, has_bias=False, weight_init=weight, strategy=(self.strategy0[0], self.strategy0[1], self.strategy0[1]))
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -307,7 +308,8 @@ class Conv2dFactory:

         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-        grad_net.set_train()
+        grad_net.set_train()
+        grad_net.set_auto_parallel()
         out_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return out_grad

@@ -95,6 +95,7 @@ class DropoutFactory:
         x1 = Tensor(inputs_x[self.x_id])
         net = Net(0.4, 0, 0, strategy=self.strategy0)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
         return out.asnumpy()

@@ -118,6 +118,7 @@ class L2normalizeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -144,6 +145,7 @@ class L2normalizeFactory:
         net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad

@@ -140,6 +140,7 @@ class AddReluFactory:
         net_with_loss = NetWithLoss(net, strategy2=self.strategy2)
         grad_net = Grad(net_with_loss)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grads = []
         for i in range(0, 3):

@@ -229,6 +229,7 @@ class BatchmatmulFactory:
         y1 = Tensor(ys[self.y_id])  # needs to be derived from the device matrix
         z1 = Tensor(zs[self.x_id])
         matmul.set_train()
+        matmul.set_auto_parallel()
         out_me = matmul(x, y, z, parallel_inputs_compile=[x, y, z], parallel_inputs_run=[x1, y1, z1])
         return out_me.asnumpy()

@@ -267,6 +268,7 @@ class BatchmatmulFactory:
         out_grad1 = Tensor(out_grads[self.out_id])
         net_me = Grad(matmul)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net_me.set_auto_parallel()
         net_me.set_train()

         out_grad = net_me(x, y, z, out_grad_me, parallel_inputs_compile = [x, y, z, out_grad1], parallel_inputs_run = [x1, y1, z1, out_grad1])

@@ -119,6 +119,7 @@ class MaxFactory:
         y1 = Tensor(ys[self.y_id])
         net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -144,6 +145,7 @@ class MaxFactory:
         net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, out_grad, parallel_inputs_compile=[x, y, out_grad], parallel_inputs_run=[x1, y1, out_grad])
         return input_grad

@@ -93,6 +93,7 @@ class MulSoftmaxFactory:
     def forward_mindspore_parallel_impl(self):
         net = MulSoftmax(strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np1)
         y = Tensor(self.input_np2, ms.float32)
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])

@@ -120,6 +121,7 @@ class MulSoftmaxFactory:
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
         grad_net.set_train()
+        grad_net.set_auto_parallel()
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
         x1 = Tensor(inputs_x[self.x_id])
         y1 = Tensor(self.input_np2, ms.float32)

@@ -113,6 +113,7 @@ class OneHotFactory:
                   on_value=self.on_value,
                   off_value=self.off_value, strategy=self.strategy0)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
         return out.asnumpy()

@@ -86,6 +86,7 @@ class PReLUFactory:
     def forward_mindspore_parallel_impl(self):
         net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np)
         z = Tensor(np.zeros(self.input_np.shape), ms.float32)
         w = Tensor(self.weight)

@@ -122,6 +123,7 @@ class PReLUFactory:
         net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()

         grad_net.set_train()
         inputs = self.get_parallel_blocks(self.input_np, self.strategy[1])

@@ -176,6 +176,7 @@ class ReduceMeanFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -202,6 +203,7 @@ class ReduceMeanFactory:
         net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1],
                               parallel_inputs_run=[x1, y1, output_grad1])

@@ -121,6 +121,7 @@ class ReshapeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -147,6 +148,7 @@ class ReshapeFactory:
         net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad

@@ -148,6 +148,7 @@ class TransposeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()

@@ -174,6 +175,7 @@ class TransposeFactory:
         net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad
@@ -49,6 +49,12 @@ class Grad(nn.Cell):
     def construct(self, x, y):
         return C.grad_all(self.network)(x, y)


+def compile(net, x, y):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
+
+
 def test_add_relu_stride_slice():
     context.set_auto_parallel_context(device_num=8, global_rank=7)

@@ -59,7 +65,7 @@ def test_add_relu_stride_slice():

     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)

 def test_add_relu_all_gather():
     context.set_auto_parallel_context(device_num=8, global_rank=7)

@@ -71,4 +77,4 @@ def test_add_relu_all_gather():

     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)
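The same module-local helper recurs in nearly every parallel UT below, differing only in arity; note that it shadows Python's built-in compile() inside the test module, which is deliberate and harmless there. A generalized form of the pattern (a sketch; the actual tests pin the argument list per file):

    def compile(net, *inputs, phase='predict'):
        # Mark the cell before handing it to the executor, exactly what the
        # per-file helpers such as compile(net, x, y) above do.
        net.set_auto_parallel()
        _executor.compile(net, *inputs, phase=phase)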
@@ -42,6 +42,11 @@ class GradWrap(nn.Cell):
         return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
 def test_matmul_sub():
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):

@@ -64,7 +69,7 @@ def test_matmul_sub():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_add():

@@ -88,7 +93,7 @@ def test_matmul_add():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul():

@@ -112,7 +117,7 @@ def test_matmul_mul():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div():

@@ -136,7 +141,7 @@ def test_matmul_div():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_matmul_greater():
     class Net(nn.Cell):

@@ -159,7 +164,7 @@ def test_matmul_greater():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_matmul_add_broadcast():
     class Net(nn.Cell):

@@ -182,7 +187,7 @@ def test_matmul_add_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_add_broadcast2():

@@ -206,7 +211,7 @@ def test_matmul_add_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_sub_broadcast():

@@ -230,7 +235,7 @@ def test_matmul_sub_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_sub_broadcast2():

@@ -254,7 +259,7 @@ def test_matmul_sub_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul_broadcast():

@@ -278,7 +283,7 @@ def test_matmul_mul_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul_broadcast2():

@@ -302,7 +307,7 @@ def test_matmul_mul_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div_broadcast():

@@ -326,7 +331,7 @@ def test_matmul_div_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div_broadcast2():

@@ -350,7 +355,7 @@ def test_matmul_div_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_matmul_greater_broadcast():
     class Net(nn.Cell):

@@ -373,7 +378,7 @@ def test_matmul_greater_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_greater_broadcast2():

@@ -397,7 +402,7 @@ def test_matmul_greater_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_matmul_floordiv():
     class Net(nn.Cell):

@@ -420,7 +425,7 @@ def test_matmul_floordiv():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_floordiv_broadcast():

@@ -444,7 +449,7 @@ def test_matmul_floordiv_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_floordiv_broadcast2():

@@ -468,7 +473,7 @@ def test_matmul_floordiv_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_assign_sub():
     class Net(nn.Cell):

@@ -495,4 +500,4 @@ def test_assign_sub():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
     z = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, z)
+    compile(net, x, y, z)
@@ -66,4 +66,5 @@ def test_auto_parallel_bn_with_prelu():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x)
@@ -43,6 +43,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, b):
         return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b, phase):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b, phase=phase)
+
+
 def test_auto_parallel_arithmetic():
     class Net(nn.Cell):
         def __init__(self):

@@ -63,7 +69,7 @@ def test_auto_parallel_arithmetic():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 128]), dtype=ms.float32)
     b = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
                            'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}

@@ -89,7 +95,7 @@ def test_auto_parallel_arithmetic_broadcast_both():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
                            'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}

@@ -116,7 +122,7 @@ def test_auto_parallel_arithmetic_broadcast_right():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 32]), dtype=ms.float32)
     b = Tensor(np.ones([32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
                            'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}

@@ -143,7 +149,7 @@ def test_auto_parallel_arithmetic_broadcast_left():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 32]), dtype=ms.float32)
     b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase="train")
+    compile(net, x, y, b, phase="train")
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
                            'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
tests/ut/python/parallel/test_auto_parallel_assign_sub_with_ref_key.py (Executable file → Normal file)
@@ -52,6 +52,7 @@ def test_auto_parallel_assign_sub_with_ref_key():

     net = NetWithLoss(nn.PReLU(4))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, phase="train")

@@ -71,6 +71,7 @@ def test_double_star_graph():

     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, y, z, w, phase='train')

@@ -63,4 +63,5 @@ def test_common_parameter():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y, z)

@@ -74,4 +74,5 @@ def test_double_star_graph():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y, z, w, a, b, c)

@@ -88,6 +88,7 @@ def test_double_subgraphs():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = TrainStepWarp(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
     reset_op_id()

@@ -61,4 +61,5 @@ def test_two_matmul():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y, b)

@@ -40,6 +40,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, z, w, b):
         return C.grad_all(self.network)(x, y, z, w, b)


+def compile(net, x, y, z, w, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, z, w, b)
+
+
 # model_parallel test
 def test_four_matmul_linear():
     class Net(nn.Cell):

@@ -67,7 +73,7 @@ def test_four_matmul_linear():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, z, w, b)
+    compile(net, x, y, z, w, b)


 def test_four_matmul1():

@@ -93,7 +99,7 @@ def test_four_matmul1():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, z, w, b)
+    compile(net, x, y, z, w, b)


 def test_four_matmul2():

@@ -120,4 +126,4 @@ def test_four_matmul2():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, z, w, b)
+    compile(net, x, y, z, w, b)

@@ -63,6 +63,7 @@ def test_auto_parallel_l2normalize():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)

@@ -61,6 +61,7 @@ def test_two_matmul_dropout():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -63,6 +63,7 @@ def test_matmul_prelu():

     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, y, b, phase='train')

@@ -89,6 +89,7 @@ def test_auto_parallel_arithmetic():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -76,6 +76,7 @@ def test_common_parameter():

     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, y, z, w, phase='train')

@@ -68,5 +68,5 @@ def test_four_matmul_linear():

     net = GradWrap(NetWithLoss(Net(strategy1)))
-    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y, z, w, b)
@@ -42,6 +42,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, b):
         return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
 # model_parallel test
 def test_sum_mul():
     class Net(nn.Cell):

@@ -64,7 +70,7 @@ def test_sum_mul():
     x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_sum_mul2():
     class Net(nn.Cell):

@@ -87,7 +93,7 @@ def test_sum_mul2():
     x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
     y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_sum_mul3():
     class Net(nn.Cell):

@@ -110,4 +116,4 @@ def test_sum_mul3():
     x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -62,6 +62,7 @@ def test_reshape_matmul():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x)


@@ -40,6 +40,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, b):
         return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
 def test_rhombus1():
     class Net(nn.Cell):
         def __init__(self):

@@ -63,7 +69,7 @@ def test_rhombus1():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_rhombus2():
     class Net(nn.Cell):

@@ -93,7 +99,7 @@ def test_rhombus2():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

 def test_rhombus3():
     class Net(nn.Cell):

@@ -123,4 +129,4 @@ def test_rhombus3():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net, x, y, z)
+    compile(net, x, y, z)

@@ -57,6 +57,7 @@ def test_softmax_cross_entropy_loss_auto_parallel():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 32]), dtype=ms.float32)

@@ -102,4 +102,5 @@ def test_dmnet_train_step():
     input = Tensor(np.ones([4096, 4096]).astype(np.float32) * 0.01)
     net = GradWrap(NetWithLoss(MultiTransformer()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, input)

@@ -67,6 +67,7 @@ def test_two_matmul_transpose():

     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, y, b, phase='train')

@@ -69,6 +69,7 @@ def test_virtual_dataset_3_input():
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     context.set_auto_parallel_context(device_num=8, global_rank=0)
+    net.set_auto_parallel()
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 2048]), dtype=ms.float32)

@@ -54,6 +54,7 @@ def test_two_bn():
     context.set_context(save_graphs=True)
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     set_algo_parameters(elementwise_op_strategy_follow=True)
     reset_op_id()


@@ -124,6 +124,7 @@ def test_two_matmul():

     net = NetWithLoss(Net())
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     reset_op_id()

     _executor.compile(net, x, y, b, phase='train')

@@ -62,4 +62,5 @@ def test_four_matmul_linear():

     net = GradWrap(NetWithLoss(Net(strategy1)))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y)

@@ -68,4 +68,5 @@ def test_zig_zag_graph():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y, z, w, a)

@@ -85,4 +85,5 @@ def test_marin_loss():

     net = GradWrap(NetWithLoss(MarginCE()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x, y)

@@ -43,6 +43,7 @@ _b = Tensor(np.ones([128, 64, 16]), dtype=ms.float32)
 def compile(net):
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
     _executor.compile(train_net, _x, _b)
     context.reset_auto_parallel_context()


@@ -100,6 +100,7 @@ def test_batch():

     net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([128, 16, 34, 34]), dtype=ms.float32)
     w1 = Tensor(np.ones([128, 8, 32, 32]), dtype=ms.float32)

@@ -61,6 +61,7 @@ def test_batch_parallel_dropout():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -58,6 +58,7 @@ def test_matmul_add():

     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -42,6 +42,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, b):
         return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
 def test_matmul_equal():
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):

@@ -62,7 +68,7 @@ def test_matmul_equal():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_not_equal():

@@ -85,7 +91,7 @@ def test_matmul_not_equal():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_not_equal_repeated_calculation():

@@ -108,7 +114,7 @@ def test_matmul_not_equal_repeated_calculation():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_maximum():

@@ -131,7 +137,7 @@ def test_matmul_maximum():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_maximum_broadcast():

@@ -154,7 +160,7 @@ def test_matmul_maximum_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_maximum_broadcast2():

@@ -177,7 +183,7 @@ def test_matmul_maximum_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_minimum():

@@ -200,7 +206,7 @@ def test_matmul_minimum():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_minimum_broadcast():

@@ -223,7 +229,7 @@ def test_matmul_minimum_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_minimum_broadcast2():

@@ -246,7 +252,7 @@ def test_matmul_minimum_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_minimum_auto_parallel():

@@ -267,4 +273,4 @@ def test_matmul_minimum_auto_parallel():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)
@ -31,6 +31,11 @@ class GradWrap(nn.Cell):
|
|||
return C.grad_all(self.network)(x, y, bias)
|
||||
|
||||
|
||||
def compile(net, x, y, bias):
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y, bias)
|
||||
|
||||
|
||||
def test_sum_as_loss_float16():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy0, strategy1):
|
||||
|
@ -52,7 +57,7 @@ def test_sum_as_loss_float16():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float16)
|
||||
y = Tensor(np.ones([64, 32]), dtype=ms.float16)
|
||||
bias = Tensor(np.ones([64]), dtype=ms.float16)
|
||||
_executor.compile(net, x, y, bias)
|
||||
compile(net, x, y, bias)
|
||||
|
||||
|
||||
def test_sum_as_loss_float32():
|
||||
|
@ -76,7 +81,7 @@ def test_sum_as_loss_float32():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
bias = Tensor(np.ones([64]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y, bias)
|
||||
compile(net, x, y, bias)
|
||||
|
||||
|
||||
def test_sum_as_loss_int32():
|
||||
|
@ -100,4 +105,4 @@ def test_sum_as_loss_int32():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.int32)
|
||||
y = Tensor(np.ones([64, 32]), dtype=ms.int32)
|
||||
bias = Tensor(np.ones([64]), dtype=ms.int32)
|
||||
_executor.compile(net, x, y, bias)
|
||||
compile(net, x, y, bias)
|
||||
|
|
|
@ -52,6 +52,7 @@ _b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
|||
def compile(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
|
|
@@ -43,6 +43,11 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
def test_matmul_pow():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):

@@ -66,7 +71,7 @@ def test_matmul_pow():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_exp():

@@ -92,7 +97,7 @@ def test_matmul_exp():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_log():

@@ -118,7 +123,7 @@ def test_matmul_log():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_logical_not():

@@ -145,7 +150,7 @@ def test_matmul_logical_not():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

def test_matmul_cast():
    class Net(nn.Cell):

@@ -171,7 +176,7 @@ def test_matmul_cast():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.int32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cast_before_mirror():

@@ -195,7 +200,7 @@ def test_cast_before_mirror():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float16)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cast_before_mirror1():

@@ -219,7 +224,7 @@ def test_cast_before_mirror1():
    x = Tensor(np.ones([128, 32]), dtype=ms.float16)
    y = Tensor(np.ones([32, 64]), dtype=ms.float16)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cast_before_mirror2():

@@ -243,7 +248,7 @@ def test_cast_before_mirror2():
    x = Tensor(np.ones([128, 32]), dtype=ms.float16)
    y = Tensor(np.ones([32, 64]), dtype=ms.float16)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cast_before_mirror3():

@@ -267,7 +272,7 @@ def test_cast_before_mirror3():
    x = Tensor(np.ones([128, 32]), dtype=ms.float16)
    y = Tensor(np.ones([32, 64]), dtype=ms.float16)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_mul_two_cast():

@@ -296,4 +301,4 @@ def test_mul_two_cast():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -56,6 +56,7 @@ _b = Tensor(np.ones([128, 64, 32, 1]), dtype=ms.float32)
def compile(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

@@ -39,6 +39,7 @@ _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile(net):
+    net.set_auto_parallel()
    _executor.compile(net, _x, _b)
    context.reset_auto_parallel_context()

@@ -52,6 +52,12 @@ class GradWrap(nn.Cell):
    def construct(self):
        return C.grad_by_list(self.network, self.weights)()


+def compile(net):
+    net.set_auto_parallel()
+    _executor.compile(net)
+
+
def test_get_next_single():
    class Net(nn.Cell):
        def __init__(self, channel=1, w=0.25):

@@ -87,7 +93,7 @@ def test_get_next_semi_auto_parallel():
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32],[[32,64], [32]], 2, strategy3=strategy3, strategy4=strategy4)
    net = GradWrap(net_with_loss)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-    _executor.compile(net)
+    compile(net)

def test_get_next_semi_auto_parallel1():
    class Net(nn.Cell):

@@ -109,7 +115,7 @@ def test_get_next_semi_auto_parallel1():
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32],[[32,64], [32]], 2, strategy3=strategy3, strategy4=strategy4)
    net = GradWrap(net_with_loss)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-    _executor.compile(net)
+    compile(net)

def test_get_next_auto_parallel():
    class Net(nn.Cell):

@@ -129,7 +135,7 @@ def test_get_next_auto_parallel():
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32],[[32,64], [32]], 2)
    net = GradWrap(net_with_loss)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-    _executor.compile(net)
+    compile(net)


def test_only_one_get_next():

@@ -145,4 +151,4 @@ def test_only_one_get_next():
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    net = Net()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-    _executor.compile(net)
+    compile(net)

@@ -45,13 +45,14 @@ def test_get_parameter_layout():
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)
+    net.set_auto_parallel()
    exe = me._executor
-    exe.compile(net, x)
+    exe.compile(net, x, auto_parallel_mode=True)
    x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1]
    weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1]
    expect_dict = {'x': x_layout, 'w1': weight_layout}
    # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
-    assert (net._parameter_layout_dict == expect_dict)
+    assert (net.parameter_layout_dict == expect_dict)


if __name__ == '__main__':

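Note the layout test now both flags the cell and passes auto_parallel_mode=True to the executor before reading layouts back. A hedged sketch of the resulting call sequence; me, x, and expect_dict are as in the hunk above, nothing else is implied:

    net.set_auto_parallel()
    exe = me._executor
    exe.compile(net, x, auto_parallel_mode=True)     # parallel passes run during compile
    assert net.parameter_layout_dict == expect_dict  # per-parameter layouts queryable afterwards
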
@@ -43,6 +43,10 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
def test_matmul_tanh():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):

@@ -66,7 +70,7 @@ def test_matmul_tanh():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_activation():

@@ -92,7 +96,7 @@ def test_matmul_activation():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_softmax():

@@ -118,7 +122,7 @@ def test_matmul_softmax():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_matmul_logsoftmax():

@@ -144,7 +148,7 @@ def test_matmul_logsoftmax():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_activations():

@@ -173,7 +177,7 @@ def test_activations():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

def test_activations_repeated_calculation():
    class Net(nn.Cell):

@@ -204,7 +208,7 @@ def test_activations_repeated_calculation():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_activations_axis_tuple():

@@ -236,4 +240,4 @@ def test_activations_axis_tuple():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -65,6 +65,7 @@ def test_l2normalize_matmul():
    strategy3 = ((1, 1, 8), (1, 1, 8))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)

@@ -50,6 +50,7 @@ _b = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
def compile(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

@@ -62,6 +62,7 @@ def test_linear():
    strategy3 = ((16, 1), (16, 1))
    net = GradWrap(NetWithLoss(Net(strategy0, strategy1, strategy2), strategy3))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)

@@ -90,6 +90,7 @@ def test_two_matmul():
    print(strategy1, strategy2)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
    _executor.compile(net, x, y, b)
    count = count + 1

@@ -35,6 +35,11 @@ class NetWithLoss(nn.Cell):
        return self.loss(predict, b)[0]


+def compile(net, x, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, b)
+
+
def test_momentum():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):

@@ -66,7 +71,7 @@ def test_momentum():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(train_net, x, b)
+    compile(train_net, x, b)


def test_momentum_with_loss_scale():

@@ -100,7 +105,7 @@ def test_momentum_with_loss_scale():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(train_net, x, b)
+    compile(train_net, x, b)


def test_momentum_with_dynamic_lr():

@@ -135,7 +140,7 @@ def test_momentum_with_dynamic_lr():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(train_net, x, b)
+    compile(train_net, x, b)


def test_momentum_with_loss_scale_and_dynamic_lr():

@@ -171,7 +176,7 @@ def test_momentum_with_loss_scale_and_dynamic_lr():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(train_net, x, b)
+    compile(train_net, x, b)

def test_lars():
    class Net(nn.Cell):

@@ -205,4 +210,4 @@ def test_lars():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(train_net, x, b)
+    compile(train_net, x, b)

@@ -66,7 +66,7 @@ def test_two_matmul_dropout():
    strategy3 = ((1, 8), (8, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
+    net.set_auto_parallel()

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -45,6 +45,11 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y)


+def compile(net, x, y):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
+
+
# model_parallel test
def test_two_matmul():
    class Net(nn.Cell):

@@ -73,7 +78,7 @@ def test_two_matmul():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)

-    _executor.compile(net, x, y)
+    compile(net, x, y)


def test_matmul_mul_broadcast2():

@@ -97,8 +102,8 @@ def test_matmul_mul_broadcast2():

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)

-    _executor.compile(net, x, y)
+    compile(net, x, y)


def test_two_matmul1():
    class Net(nn.Cell):

@@ -127,7 +132,8 @@ def test_two_matmul1():
    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)

-    _executor.compile(net, x, y)
+    compile(net, x, y)


def test_matmul_add_tensor():
    class Net(nn.Cell):

@@ -151,4 +157,4 @@ def test_matmul_add_tensor():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

-    _executor.compile(net, x, y)
+    compile(net, x, y)

@@ -76,6 +76,7 @@ def test_two_matmul():
    strategy4 = ((2, 4), (4, 1))
    net = GradWrap(NetWithLoss(Net2(strategy1, strategy2, strategy3, strategy4).add_flags_recursive(fp16=True)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)

@@ -1,135 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mindspore.train import Model, ParallelMode
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.nn.optim.momentum import Momentum
-from mindspore import Tensor
-import mindspore as ms
-import numpy as np
-from mindspore.ops import operations as P
-import mindspore.nn as nn
-from mindspore.common.parameter import Parameter
-from tests.dataset_mock import MindData
-from mindspore import context
-from mindspore.parallel._utils import _reset_op_id
-from mindspore.train.callback import Callback
-
-
-context.set_context(mode=context.GRAPH_MODE)
-
-
-class Dataset(MindData):
-    def __init__(self, predict, label, length=3):
-        super(Dataset, self).__init__(size=length)
-        self.predict = predict
-        self.label = label
-        self.index = 0
-        self.length = length
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.index >= self.length:
-            raise StopIteration
-        self.index += 1
-        return self.predict, self.label
-
-    def reset(self):
-        self.index = 0
-
-
-class AllToAllNet(nn.Cell):
-    def __init__(self, strategy1):
-        super(AllToAllNet, self).__init__()
-        self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
-        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
-        self.transpose1 = P.Transpose().set_strategy(strategy1)
-
-    def construct(self, x):
-        x = self.matmul(x, self.matmul_weight)
-        x = self.transpose1(x, (1, 0))
-        return x
-
-
-def all_to_all_net(strategy1):
-    return AllToAllNet(strategy1=strategy1)
-
-
-class ContextCallback(Callback):
-    def begin(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-    def epoch_begin(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-    def epoch_end(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-    def step_begin(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-    def step_end(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-    def end(self, run_context):
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        assert parallel_mode == ParallelMode.STAND_ALONE
-
-
-def all_to_all_common(strategy1):
-    learning_rate = 0.1
-    momentum = 0.9
-    epoch_size = 2
-
-    context.reset_auto_parallel_context()
-    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
-    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
-    label = Tensor(np.ones([32]), dtype=ms.int32)
-    dataset = Dataset(predict, label, 2)
-    net = all_to_all_net(strategy1)
-
-    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
-    opt = Momentum(net.trainable_params(), learning_rate, momentum)
-    model = Model(net, loss, opt)
-
-    context_callback = ContextCallback()
-
-    model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[context_callback])
-
-    parallel_mode = context.get_auto_parallel_context("parallel_mode")
-    assert parallel_mode == ParallelMode.SEMI_AUTO_PARALLEL
-
-    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8)
-    model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[context_callback])
-    parallel_mode = context.get_auto_parallel_context("parallel_mode")
-    assert parallel_mode == ParallelMode.AUTO_PARALLEL
-
-    context.reset_auto_parallel_context()
-
-
-def test_model_callback():
-    strategy1 = ((8, 1), )
-    _reset_op_id()
-    all_to_all_common(strategy1)

@@ -41,6 +41,7 @@ _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

@@ -271,6 +271,7 @@ def test_bn_reshape_dense_bn_train_loss():

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    _executor.compile(net, input, label)

@@ -284,6 +285,7 @@ def test_semi_one_hot_net_batch():
    net = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
    net = GradWrap(NetWithLoss(net))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    _executor.compile(net, input, label)

@@ -63,11 +63,11 @@ def test_one_weight_parameter():
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, weight)
-    print ("======================================dict", net.__dict__)

    net_with_loss = NetWithLoss(net, strategy3)

    train_net = OneStepCell(net_with_loss)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    train_net.set_auto_parallel()

    _executor.compile(train_net, x, b)

@@ -64,6 +64,7 @@ class Net(nn.Cell):

def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onthot_axis=-1):
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onthot_axis))
+    net.set_auto_parallel()
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:

@@ -58,6 +58,6 @@ def test_dense_gen_graph():

    predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([64, 32]).astype(np.float32))
-
+    network.set_auto_parallel()
    _executor.compile(network, predict, label)

@@ -34,6 +34,11 @@ class NetWithLoss(nn.Cell):
        return self.loss(predict, b)[0]


+def compile(net, x, b):
+    net.set_auto_parallel()
+    _Executor().compile(net, x, b)
+
+
def test_optimizer_clone_weight():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):

@@ -66,7 +71,7 @@ def test_optimizer_clone_weight():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _Executor().compile(train_net, x, b)
+    compile(train_net, x, b)


def test_optimizer_clone_weight2():

@@ -101,4 +106,4 @@ def test_optimizer_clone_weight2():
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _Executor().compile(train_net, x, b)
+    compile(train_net, x, b)

@@ -43,6 +43,11 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y)


+def compile(net, x, y):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
+
+
def test_prelu_single_success1():
    class Net(nn.Cell):
        def __init__(self):

@@ -57,7 +62,8 @@ def test_prelu_single_success1():
    net = GradWrap(NetWithLoss(Net()))
    x = Tensor(np.random.rand(1, 33, 4, 4), ms.float32)
    w = Tensor(np.random.rand(33), ms.float32)
-    _executor.compile(net, x, w)
+    compile(net, x, w)


def test_prelu_single_success2():
    class Net(nn.Cell):

@@ -73,7 +79,8 @@ def test_prelu_single_success2():
    net = GradWrap(NetWithLoss(Net()))
    x = Tensor(np.random.rand(1, 33, 4, 4), ms.float32)
    w = Tensor([0.1], ms.float32)
-    _executor.compile(net, x, w)
+    compile(net, x, w)


def test_prelu_parallel_success1():
    class Net(nn.Cell):

@@ -90,7 +97,8 @@ def test_prelu_parallel_success1():
    x = Tensor(np.random.rand(4, 4, 32, 64),dtype=ms.float32)
    w = Tensor(np.random.rand(4),dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy)))
-    _executor.compile(net, x, w)
+    compile(net, x, w)


def test_prelu_parallel_success2():
    class Net(nn.Cell):

@@ -107,7 +115,8 @@ def test_prelu_parallel_success2():
    x = Tensor(np.random.rand(4, 4, 32, 64),dtype=ms.float32)
    w = Tensor(np.random.rand(4),dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy)))
-    _executor.compile(net, x, w)
+    compile(net, x, w)


def test_prelu_parallel_success3():
    class NetWithLoss(nn.Cell):

@@ -148,8 +157,10 @@ def test_prelu_parallel_success3():
    y = Tensor(np.random.rand(64, 16),dtype=ms.float32)
    w = Tensor(np.random.rand(16),dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    net.set_auto_parallel()
    _executor.compile(net, x, y, w)


def test_prelu_parallel_success4():
    class Net(nn.Cell):
        def __init__(self, strategy):

@@ -165,7 +176,8 @@ def test_prelu_parallel_success4():
    x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32)
    w = Tensor(np.random.rand(16),dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy)))
-    _executor.compile(net, x, w)
+    compile(net, x, w)


def test_prelu_parallel_success5():
    class Net(nn.Cell):

@@ -182,5 +194,4 @@ def test_prelu_parallel_success5():
    x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32)
    w = Tensor(np.random.rand(1),dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy)))
-    _executor.compile(net, x, w)
-
+    compile(net, x, w)

@@ -42,6 +42,12 @@ class GradWrap(nn.Cell):
    def construct(self, x, y, b):
        return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
# model_parallel test
def test_sum_mul():
    class Net(nn.Cell):

@@ -67,7 +73,8 @@ def test_sum_mul():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul2():
    class Net(nn.Cell):

@@ -93,7 +100,8 @@ def test_sum_mul2():
    x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul3():
    class Net(nn.Cell):

@@ -119,7 +127,8 @@ def test_sum_mul3():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul4():
    class Net(nn.Cell):

@@ -145,7 +154,7 @@ def test_sum_mul4():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul5():

@@ -169,7 +178,7 @@ def test_sum_mul5():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([1, 32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul6():

@@ -193,7 +202,7 @@ def test_sum_mul6():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_sum_mul7():

@@ -217,7 +226,7 @@ def test_sum_mul7():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_max_mul():

@@ -244,7 +253,7 @@ def test_max_mul():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_min_mul():

@@ -271,7 +280,7 @@ def test_min_mul():
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_reduce_mean_mul_float32():

@@ -299,7 +308,7 @@ def test_reduce_mean_mul_float32():
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)

-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


class ArgMaxWithValueNet(nn.Cell):

@@ -334,7 +343,7 @@ def gen_inputs_and_compile(net):
    x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():

@@ -467,7 +476,7 @@ def test_cross_batch():
    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cross_batch2():

@@ -495,7 +504,7 @@ def test_cross_batch2():
    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_cross_batch_auto():

@@ -515,12 +524,11 @@ def test_cross_batch_auto():
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_max_empty_tuple():

@@ -548,4 +556,4 @@ def test_max_empty_tuple():
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)

-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -303,6 +303,11 @@ class ReshapeNet6(nn.Cell):
        return matmul2_o


+def compile(net, input):
+    net.set_auto_parallel()
+    _executor.compile(net, input)
+
+
def reshape_net2(backbone):
    batch_size = 16
    device_num = 16

@@ -312,7 +317,7 @@ def reshape_net2(backbone):
    net = GradWrap(NetWithLoss(backbone))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(net, input)
+    compile(net, input)


def test_reshape_net1_1():

@@ -475,7 +480,7 @@ def test_batchnorm_reshape_train():

    net = GradWrap(NetWithLoss(BatchNormReshapeNet()))

-    _executor.compile(net, input)
+    compile(net, input)


def bn_with_initialize(out_channels):

@@ -513,7 +518,7 @@ def test_bn_reshape_dense_bn_train():
    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

-    _executor.compile(net, input)
+    compile(net, input)


class ParallelReduceMeanNet(nn.Cell):

@@ -57,13 +57,18 @@ class Net(nn.Cell):
        return out


+def compile(net, x, y):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
+
+
def test_reshape_parameter_data_parallel():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy = ((8, 1, 1), (8, 1, 1))
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)


def test_reshape_parameter_model_parallel():

@@ -72,4 +77,4 @@ def test_reshape_parameter_model_parallel():
    net = GradWrap(NetWithLoss(Net(strategy)))
    x = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)

@@ -51,6 +51,7 @@ def test_sum_as_loss():
    strategy1 = ((4, 1), )
    net = GradWrap(Net(strategy0, strategy1))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)

@@ -105,4 +105,5 @@ def test_two_subgraphs():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = TrainStepWrap(NetWithLoss(Net()))
    input_x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
+    net.set_auto_parallel()
    _executor.compile(net, input_x)

@@ -41,6 +41,7 @@ _b = Tensor(np.ones([128, 64]), dtype=ms.float32)
def compile(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

@@ -25,5 +25,5 @@ def test_SoftmaxCrossEntropy():
    logit = Tensor(np.ones([64, 512]), dtype=mstype.float32)
    label = Tensor(np.ones([64]), dtype=mstype.int32)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
-
+    net.set_auto_parallel()
    _executor.compile(net, logit, label)

@@ -42,6 +42,11 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
def test_softmax_cross_entropy_loss():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):

@@ -64,7 +69,7 @@ def test_softmax_cross_entropy_loss():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_softmax_cross_entropy_loss_repeated_calculation():

@@ -89,7 +94,7 @@ def test_softmax_cross_entropy_loss_repeated_calculation():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_softmax_cross_entropy_loss_auto_batch_parallel():

@@ -111,4 +116,4 @@ def test_softmax_cross_entropy_loss_auto_batch_parallel():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -53,6 +53,11 @@ class GradWrap3(nn.Cell):
        return C.grad_all(self.network)(x, y, bias)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
def test_no_grad():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):

@@ -75,7 +80,7 @@ def test_no_grad():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_grad_sens_parameter_type():

@@ -103,6 +108,7 @@ def test_grad_sens_parameter_type():

    sens = Tensor(np.ones([128, 64]), dtype=ms.float32)
    # net(x, y, b, sens)
+    net.set_auto_parallel()
    _executor.compile(net, x, y, b, sens)


@@ -128,7 +134,7 @@ def test_grad_sens_tensor_type():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_grad_sens_scalar_broadcast():

@@ -152,4 +158,4 @@ def test_grad_sens_scalar_broadcast():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    bias = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, bias)
+    compile(net, x, y, bias)

@@ -43,6 +43,7 @@ _b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

@@ -37,6 +37,7 @@ _b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile(net):
+    net.set_auto_parallel()
    _executor.compile(net, _x, _b)
    context.reset_auto_parallel_context()

@@ -71,5 +71,5 @@ def test_two_matmul():
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    b = Tensor(np.ones([128, 128]), dtype=ms.float32)
    a = Tensor(np.ones([128, 128]), dtype=ms.float32)
-
+    net.set_auto_parallel()
    _executor.compile(net, x, y, b, a)

@@ -81,7 +81,7 @@ def test_six_matmul_save():
    strategy6 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
+    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    _executor.compile(net, x1, x6)

@@ -142,7 +142,7 @@ def test_six_matmul_load():
    strategy7 = ((8, 1), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
+    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)

@@ -199,7 +199,7 @@ def test_six_matmul_save_auto():
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-
+    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    _executor.compile(net, x1, x6)

@@ -258,7 +258,7 @@ def test_six_matmul_load_auto():
    strategy5 = ((2, 2), (2, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
-
+    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)

@@ -31,6 +31,12 @@ class GradWrap(nn.Cell):
    def construct(self, x, y, bias):
        return C.grad_all(self.network)(x, y, bias)


+def compile(net, x, y, bias):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, bias)
+
+
def test_sum_as_loss():
    class Net(nn.Cell):
        def __init__(self, strategy0, strategy1):

@@ -53,7 +59,7 @@ def test_sum_as_loss():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    bias = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, bias)
+    compile(net, x, y, bias)


def test_sum_as_loss2():

@@ -78,4 +84,4 @@ def test_sum_as_loss2():
    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    bias = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, bias)
+    compile(net, x, y, bias)

@@ -43,6 +43,11 @@ class GradWrap(nn.Cell):
        return C.grad_all(self.network)(x, y, b)


+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
# model_parallel test
def test_two_matmul():
    class Net(nn.Cell):

@@ -66,7 +71,8 @@ def test_two_matmul():
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_two_matmul_repeated_calculation1():
    class Net(nn.Cell):

@@ -89,7 +95,7 @@ def test_two_matmul_repeated_calculation1():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


def test_two_matmul_repeated_calculation2():

@@ -113,4 +119,4 @@ def test_two_matmul_repeated_calculation2():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)

@@ -74,5 +74,5 @@ def test_two_weights_parameter():

    train_net = OneStepCell(net_with_loss)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
+    train_net.set_auto_parallel()
    _executor.compile(train_net, x, b)

@@ -70,6 +70,7 @@ def test_virtual_dataset_3_input():
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
+    net.set_auto_parallel()
    _executor.compile(net, x, y, b)

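Taken together, a converted test ends up with roughly this shape; every name below is a placeholder standing in for the per-file nets and strategies, not code from the commit:

    def test_example_two_matmul():
        context.set_auto_parallel_context(device_num=8, global_rank=0)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
        x = Tensor(np.ones([128, 32]), dtype=ms.float32)
        y = Tensor(np.ones([32, 64]), dtype=ms.float32)
        b = Tensor(np.ones([64, 64]), dtype=ms.float32)
        compile(net, x, y, b)   # the helper calls net.set_auto_parallel() before compiling
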