forked from mindspore-Ecosystem/mindspore
!1261 Incremental subgraph initialization
Merge pull request !1261 from changzherui/add_tensor_init
This commit is contained in:
commit
56c3fed30e
|
@ -374,6 +374,10 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
|
|||
return data_type_;
|
||||
}
|
||||
|
||||
// Returns the current value of the tensor's init flag.
bool Tensor::is_init() {
  return init_flag_;
}
|
||||
|
||||
// Stores `flag` as the tensor's init flag.
void Tensor::set_init_flag(bool flag) {
  init_flag_ = flag;
}
|
||||
|
||||
bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
|
||||
const TypeId out_data_type) {
|
||||
if (out == nullptr) {
|
||||
|
@ -499,6 +503,24 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.size()
|
||||
6
|
||||
)mydelimiter")
|
||||
.def("is_init", &Tensor::is_init, R"mydelimiter(
|
||||
Get tensor init_flag.
|
||||
|
||||
Returns:
|
||||
bool, whether the tensor init.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 3)))
|
||||
>>> data.is_init()
|
||||
False
|
||||
)mydelimiter")
|
||||
.def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
|
||||
Set tensor init_flag.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 3)))
|
||||
>>> data.set_init_flag(True)
|
||||
)mydelimiter")
|
||||
.def("dim", &Tensor::DataDim, R"mydelimiter(
|
||||
Get tensor's data dimension.
|
||||
|
||||
|
|
|
@ -389,6 +389,8 @@ class Tensor : public MetaTensor {
|
|||
std::string ToStringRepr() const;
|
||||
py::array data_; // < Tensor's data value
|
||||
const bool parse_info_ = true;
|
||||
bool is_init();
|
||||
void set_init_flag(bool flag);
|
||||
|
||||
private:
|
||||
// brief init tensor
|
||||
|
@ -398,7 +400,7 @@ class Tensor : public MetaTensor {
|
|||
// return true if succeed, false if failed.
|
||||
void init(const py::array &input, const TypeId &data_type);
|
||||
void init(const py::array &input, const TypePtr &type_ptr);
|
||||
|
||||
bool init_flag_{false};
|
||||
// brief init tensor attribute
|
||||
//
|
||||
// param data_type [TypeId] Data type of the tensor.
|
||||
|
|
|
@ -649,7 +649,6 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
|
|||
if (adpt == nullptr) continue;
|
||||
auto param_op = adpt->generate(name + "_data");
|
||||
MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << ".";
|
||||
(void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
|
||||
|
||||
if (!training_) {
|
||||
auto adpt_const = FindAdapter(kNameConst, training_);
|
||||
|
@ -678,14 +677,17 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
|
|||
|
||||
// we need three variable ops for each graph with same name
|
||||
// build init subgraph
|
||||
auto init_var = std::make_shared<Variable>(name);
|
||||
auto assign_op = std::make_shared<Assign>("assign_" + name);
|
||||
(void)init_var->update_output_desc_y(*desc);
|
||||
(void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
|
||||
init_input.push_back(*init_var);
|
||||
init_ops_.push_back(param_op);
|
||||
init_ops_.push_back(assign_op);
|
||||
init_ops_.push_back(init_var);
|
||||
if (it.second->is_init() == 0) {
|
||||
(void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
|
||||
auto init_var = std::make_shared<Variable>(name);
|
||||
auto assign_op = std::make_shared<Assign>("assign_" + name);
|
||||
(void)init_var->update_output_desc_y(*desc);
|
||||
(void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
|
||||
init_input.push_back(*init_var);
|
||||
init_ops_.push_back(param_op);
|
||||
init_ops_.push_back(assign_op);
|
||||
init_ops_.push_back(init_var);
|
||||
}
|
||||
|
||||
auto variable = std::make_shared<Variable>(name);
|
||||
(void)variable->update_output_desc_y(*desc);
|
||||
|
|
|
@ -82,14 +82,15 @@ def _wrap_func(fn):
|
|||
# NOTE(review): this is a rendered diff hunk without +/- markers, so lines from the
# old and the new version of the function appear interleaved — confirm against the
# repository before reading it as one function body.
def _exec_init_graph(obj, init_phase):
|
||||
"""Execute the parameter initializer graph.

Parameters of `obj` whose `is_init` is false are flagged as initialized, and the
executor returned by `Executor_.get_instance()` is asked to run the init graph
for `init_phase`.
"""
|
||||
inst_executor = Executor_.get_instance()
|
||||
exec_init_graph = False
|
||||
for param in obj.get_parameters():
|
||||
param_dict = OrderedDict()
|
||||
for name, param in obj.parameters_dict().items():
|
||||
# Only parameters not yet flagged as initialized are handled below.
if not param.is_init:
|
||||
param_dict[name] = param
|
||||
param.is_init = True
|
||||
exec_init_graph = True
|
||||
# Set the flag on the parameter's data as well.
param.data.init_flag = True
|
||||
|
||||
# The init graph is run only when at least one parameter was picked up above.
if exec_init_graph:
|
||||
inst_executor.run_init_graph(obj.parameters_dict(), init_phase)
|
||||
if param_dict:
|
||||
inst_executor.run_init_graph(param_dict, init_phase)
|
||||
|
||||
|
||||
class _MindSporeFunction:
|
||||
|
|
|
@ -188,11 +188,14 @@ class Parameter:
|
|||
if isinstance(data, Tensor):
|
||||
# make a copy of Tensor to init the parameter
|
||||
data = Tensor(data.asnumpy().copy())
|
||||
data.init_flag = False
|
||||
elif isinstance(data, Initializer):
|
||||
self.init_mode = data
|
||||
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
|
||||
else:
|
||||
data = Tensor(data)
|
||||
data.init_flag = False
|
||||
|
||||
self.default_input = data
|
||||
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ class Tensor(Tensor_):
|
|||
else:
|
||||
super(Tensor, self).__init__(input_data, dtype)
|
||||
self._virtual_flag = False
|
||||
self._init_flag = False
|
||||
|
||||
def __repr__(self):
|
||||
# repr() yields the same text as __str__(), passed through str().
return str(self.__str__())
|
||||
|
@ -153,3 +154,16 @@ class Tensor(Tensor_):
|
|||
if not isinstance(value, bool):
|
||||
raise TypeError("virtual_flag must be bool.")
|
||||
self._virtual_flag = value
|
||||
|
||||
@property
|
||||
def init_flag(self):
|
||||
"""Whether the tensor is flagged as initialized (value of `_init_flag`)."""
|
||||
return self._init_flag
|
||||
|
||||
@init_flag.setter
|
||||
def init_flag(self, value):
|
||||
"""Set the tensor's init_flag.

Raises:
    TypeError: If `value` is not a bool.
"""
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("init_flag must be bool.")
|
||||
# Forward the flag through set_init_flag() (presumably the method bound from
# C++ — confirm), then keep a copy in _init_flag for the getter.
self.set_init_flag(value)
|
||||
self._init_flag = value
|
||||
|
|
Loading…
Reference in New Issue