forked from mindspore-Ecosystem/mindspore
!20 modify _set_dataset_mode_config api param
Merge pull request !20 from jinyaohui/master
commit 40d4a4baa3
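Editor's note: this commit renames the dataset-mode identifiers end to end, from "graph"/"feed" to "sink"/"normal", across the Python `_set_dataset_mode_config` API, the C++ `ConfigManager` and `DatasetMode` enum, the GE graph converter, and the tests. A minimal sketch of the mapping, taken directly from the hunks below:

```python
# Old -> new identifier mapping, as it appears in the hunks below.
DATASET_MODE_RENAMES = {
    "graph": "sink",                   # Python mode string for _set_dataset_mode_config
    "feed": "normal",                  # Python mode string for _set_dataset_mode_config
    "DS_GRAPH_MODE": "DS_SINK_MODE",   # C++ DatasetMode enum value
    "DS_FEED_MODE": "DS_NORMAL_MODE",  # C++ DatasetMode enum value
}
```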
@@ -67,7 +67,7 @@ if __name__ == '__main__':
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
-    parser.add_argument("--mode", type=str, default="graph", help="Run graph mode or feed mode, default is graph")
+    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or non-sink mode, default is sink")
     parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
@@ -150,8 +150,8 @@ if __name__ == '__main__':

     model = Model(net)
     dataset_sink_mode = False
-    if args_opt.mode == "graph":
-        print("In graph mode, one epoch return a loss.")
+    if args_opt.mode == "sink":
+        print("In sink mode, one epoch return a loss.")
         dataset_sink_mode = True
     print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
     model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
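To make the script-level change concrete, here is a self-contained sketch of the renamed `--mode` flag and the sink-mode branch, mirroring the two hunks above. The `add_argument` line is copied from the diff; the surrounding scaffolding is illustrative, not the actual train.py.

```python
# Sketch of the renamed CLI contract (not the full train.py).
import argparse

parser = argparse.ArgumentParser(description="YOLOv3 training, dataset-mode flag only.")
parser.add_argument("--mode", type=str, default="sink",
                    help="Run sink mode or non-sink mode, default is sink")
args_opt = parser.parse_args([])  # pass [] so the sketch runs without CLI args

dataset_sink_mode = False
if args_opt.mode == "sink":
    # In sink mode the dataset is sent down to the device queue, so each
    # epoch reports a single loss value.
    dataset_sink_mode = True
print("dataset_sink_mode =", dataset_sink_mode)
```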
@@ -116,7 +116,7 @@ bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batc
     return transform::TransformUtil::ConvertDataType(i->type_id());
   });

-  ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_GRAPH_MODE);
+  ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
   ConfigManager::GetInstance().set_iter_num(size);
   ConfigManager::GetInstance().set_dataset_phase(phase);

@@ -453,8 +453,8 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr>& info, const py::
   }

   // process the first args of tensor
-  // only in Dataset Feed Mode, fp_bp graph need input tensors
-  if (ConfigManager::GetInstance().dataset_mode() == DS_FEED_MODE) {
+  // only in Dataset non-sink Mode, fp_bp graph need input tensors
+  if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) {
     for (std::size_t i = 0; i < size; i++) {
       ValuePtr converted = nullptr;
       bool succ = parse::ConvertData(args[i], &converted);
@@ -442,10 +442,10 @@ void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {

   int64_t value = 0;
   auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     value = ConfigManager::GetInstance().iter_num();
   } else {
-    MS_LOG(INFO) << "Run with feed mode, the iterator number will always be 1";
+    MS_LOG(INFO) << "Run with non-sink mode, the iterator number will always be 1";
     value = 1;
     ConfigManager::GetInstance().set_iter_num(value);
   }
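The InitLoopVar hunk above encodes a simple rule: in sink mode the on-device loop runs `iter_num` iterations per epoch, while in non-sink mode it is pinned to 1 because the host feeds one step at a time. A hedged Python transcription (the function name is mine; the logic is from the diff):

```python
def iterations_per_loop(dataset_mode: str, iter_num: int) -> int:
    """Transcription of the DS_SINK_MODE branch in DfGraphConvertor::InitLoopVar."""
    if dataset_mode == "sink":
        return iter_num  # the whole-epoch loop runs on the device
    # non-sink mode: the iterator number is always 1
    return 1

assert iterations_per_loop("sink", 100) == 100
assert iterations_per_loop("normal", 100) == 1
```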
@@ -576,7 +576,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std

 void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
   MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     auto getnext_idx = static_cast<int64_t>(input_idx);
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) {
@@ -868,7 +868,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
   }

   // Create dataset iterator and iterator_getnext node
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     MS_LOG(INFO) << "Dataset param is " << param.ToString() << ".";
     // GetNext
@@ -977,7 +977,7 @@ void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) {
 }

 void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     size_t output_num = param.ge_types().size();
     MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << ".";
@@ -1036,7 +1036,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {

   // set graph input according to the order from anf graph
   std::vector<Operator> inputs;
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     inputs.push_back(*dataset_iter_getnext_);
   } else {
     auto params = anf_graph_->parameters();
@@ -28,7 +28,7 @@ ConfigManager& ConfigManager::GetInstance() noexcept {
 }

 void ConfigManager::SetDatasetModeConfig(const std::string& mode) {
-  static const std::map<std::string, DatasetMode> mode_map = {{"feed", DS_FEED_MODE}, {"graph", DS_GRAPH_MODE}};
+  static const std::map<std::string, DatasetMode> mode_map = {{"normal", DS_NORMAL_MODE}, {"sink", DS_SINK_MODE}};
   if (mode_map.find(mode) == mode_map.end()) {
     MS_LOG(ERROR) << "Invalid dataset mode:" << mode;
     return;
@@ -38,7 +38,7 @@ void ConfigManager::SetDatasetModeConfig(const std::string& mode) {

 void ConfigManager::ResetConfig() noexcept {
   parallel_strategy_ = ONE_DEVICE;
-  dataset_mode_ = DS_FEED_MODE;
+  dataset_mode_ = DS_NORMAL_MODE;
   dataset_param_ = DatasetGraphParam("", 0, 0, {}, {}, {});
   iter_num_ = 1;
 }
@@ -33,7 +33,7 @@ enum ParallelStrategy {
   DISTRIBUTION,
 };

-enum DatasetMode { DS_FEED_MODE = 0, DS_GRAPH_MODE };
+enum DatasetMode { DS_NORMAL_MODE = 0, DS_SINK_MODE };

 class DatasetGraphParam {
  public:
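Taken together, the ConfigManager hunks define a two-value enum plus a validated string lookup. Below is a Python rendering of that contract, assuming only what the hunks show: the enum values, the "normal"/"sink" keys, and the error-and-return behavior for unknown modes.

```python
from enum import Enum
from typing import Optional

class DatasetMode(Enum):
    DS_NORMAL_MODE = 0
    DS_SINK_MODE = 1

_MODE_MAP = {"normal": DatasetMode.DS_NORMAL_MODE, "sink": DatasetMode.DS_SINK_MODE}

def set_dataset_mode_config(mode: str) -> Optional[DatasetMode]:
    """Mirror of ConfigManager::SetDatasetModeConfig: reject unknown mode strings."""
    if mode not in _MODE_MAP:
        print(f"Invalid dataset mode: {mode}")  # MS_LOG(ERROR) in the C++ original
        return None
    return _MODE_MAP[mode]
```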
@@ -106,7 +106,7 @@ class ConfigManager {
   ~ConfigManager() = default;

   ParallelStrategy parallel_strategy_{ONE_DEVICE};
-  DatasetMode dataset_mode_{DS_FEED_MODE};
+  DatasetMode dataset_mode_{DS_NORMAL_MODE};
   DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
   int64_t iter_num_{1};
   std::string dataset_phase_{""};
@@ -378,9 +378,9 @@ class _Executor:
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
             if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
-                _set_dataset_mode_config('graph')
+                _set_dataset_mode_config('sink')
             else:
-                _set_dataset_mode_config('feed')
+                _set_dataset_mode_config('normal')

             self._build_data_graph(obj, params, phase)
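The _Executor hunk above is the single call site that picks between the two renamed mode strings: inputs marked virtual mean the data pipeline is sunk to the device. A minimal sketch of that decision, using the `virtual_flag` attribute named in the diff (the helper name is mine):

```python
def choose_dataset_mode(args_list) -> str:
    """Pick the dataset mode string the way _Executor does under enable_ge."""
    if args_list and getattr(args_list[0], "virtual_flag", False):
        return "sink"    # virtual inputs: data flows through the device queue
    return "normal"      # real host tensors are fed step by step
```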
@@ -51,7 +51,7 @@ class DynamicLossScaleUpdateCell(Cell):
     In every training step, the loss scaling value will be updated by loss scaling value/`scale_factor`
     when there is overflow. And it will be increased by loss scaling value * `scale_factor` if there is no
     overflow for a continuous `scale_window` steps. This cell is used for Graph mode training in which all
-    logic will be executed on device side(Another training mode is feed mode in which some logic will be
+    logic will be executed on device side(Another training mode is non-sink mode in which some logic will be
     executed on host).

     Args:
@@ -24,11 +24,12 @@ from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.nn.optim import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
-from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext,_checkpoint_cb_for_save_op,\
-    LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist,\
-    _build_callbacks, CheckpointConfig, _set_cur_net
+from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \
+    LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
+    _build_callbacks, CheckpointConfig, _set_cur_net
 from mindspore.common.api import ms_function
+


 class Net(nn.Cell):
     """Net definition."""
@@ -52,6 +53,7 @@ class Net(nn.Cell):

 class LossNet(nn.Cell):
     """ LossNet definition """
+
     def __init__(self):
         super(LossNet, self).__init__()
         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
@@ -110,8 +112,8 @@ def test_save_checkpoint():
     os.remove('./test_files/test_ckpt-model.pkl')


-def test_loss_monitor_graph_model():
-    """Test lossmonitor Graph model."""
+def test_loss_monitor_sink_model():
+    """Test loss monitor sink model."""
     cb_params = _InternalCallbackParam()
     cb_params.cur_epoch_num = 4
     cb_params.cur_step_num = 2
@@ -129,8 +131,8 @@ def test_loss_monitor_graph_model():
     callbacklist.end(run_context)


-def test_Loss_Monitor_feed_feed_model():
-    """Test Loss Monitor feed feed mode."""
+def test_loss_monitor_feed_model():
+    """Test loss monitor non-sink mode."""
     cb_params = _InternalCallbackParam()
     run_context = RunContext(cb_params)
     loss_cb = LossMonitor(1)