!40107 Support parameter as input for Pynative ms_function

Merge pull request !40107 from JoyLvliang/refactor_pynative_ms_function
commit 98700af68c by i-robot, 2022-08-11 12:15:57 +00:00 (committed by Gitee)
4 changed files with 187 additions and 31 deletions
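
In PyNative mode, an ms_function-compiled graph previously took only Tensor positional inputs; weights reached the graph solely as free parameters with default values. This change additionally accepts Parameter objects as positional inputs. A minimal sketch of the user-visible behavior (the function and values below are illustrative, not part of the patch):

import numpy as np
from mindspore import Tensor, Parameter, context, ms_function
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)

@ms_function
def add_and_scale(x, p):
    # p can now be a Parameter, not only a Tensor.
    return (x + p) * 2

x = Tensor([1.0], mstype.float32)
p = Parameter(Tensor([2.0], mstype.float32), name="p")
print(add_and_scale(x, p))  # expected: [6.]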


@@ -163,46 +163,52 @@ void MsFunction::UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info,
   }
 }

-void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
-                                        ValuePtrList *input_values, CNodePtr *ms_function_cnode) const {
-  // Get input node info of ms_function
-  MS_EXCEPTION_IF_NULL(ms_func_graph);
-  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
-  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
-  for (size_t i = 0; i < args.size(); ++i) {
-    const auto &inp_i_value = PyNativeAlgo::DataConvert::PyObjToValue(args[i]);
-    const auto &input_i_node = grad_executor->GetInput(inp_i_value);
+void MsFunction::GetInputArgsNode(const py::args &args, AnfNodePtrList *input_nodes, ValuePtrList *input_values) const {
+  MS_EXCEPTION_IF_NULL(input_nodes);
+  MS_EXCEPTION_IF_NULL(input_values);
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  size_t input_args_size = args.size();
+  for (size_t i = 0; i < input_args_size; ++i) {
+    const auto &input_i_value = PyNativeAlgo::DataConvert::PyObjToValue(args[i]);
+    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << input_i_value->ToString();
+    (void)input_values->emplace_back(input_i_value);
+    const auto &input_i_node = grad_exec->GetInput(input_i_value);
     MS_EXCEPTION_IF_NULL(input_i_node);
     MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
-    (void)input_nodes.emplace_back(input_i_node);
-    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << inp_i_value->ToString();
-    (void)(*input_values).emplace_back(inp_i_value);
+    (void)input_nodes->emplace_back(input_i_node);
   }
+}

-  // Get dfbuilder and graph info map
-  const auto &top_cell = grad_executor->top_cell();
+void MsFunction::GetWeightsNode(const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes,
+                                ValuePtrList *input_values, const size_t input_args_index) const {
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  // Get graph info map.
+  const auto &top_cell = grad_exec->top_cell();
   auto df_builder = top_cell->df_builder();
   MS_EXCEPTION_IF_NULL(df_builder);
   const auto &graph_info = top_cell->graph_info_map().at(df_builder);
   MS_EXCEPTION_IF_NULL(graph_info);
   // Get weights info of ms_function
-  std::vector<AnfNodePtr> new_params;
   auto manage = Manage(ms_func_graph, false);
-  for (const auto &anf_node : ms_func_graph->parameters()) {
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto param = anf_node->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param);
-    if (!param->has_default()) {
-      (void)new_params.emplace_back(param);
+  const auto &original_params = ms_func_graph->parameters();
+  size_t params_size = original_params.size();
+  std::vector<AnfNodePtr> new_params;
+  for (size_t i = 0; i < params_size; ++i) {
+    if (i < input_args_index) {  // non-weights node.
+      (void)new_params.emplace_back(original_params[i]);
       continue;
     }
+    const auto &anf_node = original_params[i];
+    auto param = anf_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
     auto param_info = param->param_info();
     MS_EXCEPTION_IF_NULL(param_info);
     auto param_name = param_info->name();
     if (graph_info->params.count(param_name) != 0) {
       // Share same weight parameter in different ms_function call.
-      auto same_param = graph_info->params.at(param_name);
+      const auto &same_param = graph_info->params.at(param_name);
       manage->Replace(anf_node, same_param);
       param = same_param;
     } else {
@@ -210,19 +216,30 @@ void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
       param->debug_info()->set_name(param_name);
     }
     (void)new_params.emplace_back(param);
-    (void)input_nodes.emplace_back(param);
-    (void)(*input_values).emplace_back(param->default_param());
+    (void)input_nodes->emplace_back(param);
+    const auto &default_param = param->default_param();
+    MS_EXCEPTION_IF_NULL(default_param);
+    (void)input_values->emplace_back(default_param);
     top_cell->SetParamNodeMapInGraphInfoMap(df_builder, param_name, param);
     MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
-                  << param->default_param()->ToString() << ". Its name is: " << param_name;
+                  << default_param->ToString() << ". Its name is: " << param_name;
   }
   ms_func_graph->set_parameters(new_params);
   manage->Clear();
+}
+
+void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
+                                        ValuePtrList *input_values, CNodePtr *ms_function_cnode) const {
+  // Get input node info of ms_function
+  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
+  GetInputArgsNode(args, &input_nodes, input_values);
+  // Get weights node info of ms_function.
+  GetWeightsNode(ms_func_graph, &input_nodes, input_values, args.size());
   // Make a CNode which includes ms_function fprop graph and inputs node
   MS_EXCEPTION_IF_NULL(ms_function_cnode);
-  *ms_function_cnode = top_cell->fg()->NewCNode(input_nodes);
-  MS_LOG(DEBUG) << "Make ms function forward cnode: " << (*ms_function_cnode)->DebugString();
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  *ms_function_cnode = grad_exec->top_cell()->fg()->NewCNode(input_nodes);
+  MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
 }

 // Make adjoint for ms_function fprop graph and connect it with previous op
@@ -241,7 +258,6 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
   // Connect grad graph of ms_function to context.
   auto k_pynative_cell_ptr = top_cell->k_pynative_cell_ptr();
   MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
-  MS_EXCEPTION_IF_NULL(grad_graph);
   if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, actual_out_v, grad_graph)) {
     MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                       << ms_function_cnode->DebugString();
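
The refactor above splits CNode construction into two steps: GetInputArgsNode collects the positional arguments (which may now include parameters), and GetWeightsNode appends the graph's weight parameters after them, using args.size() as the boundary index and reusing an already-recorded weight node when another ms_function graph captures the same parameter name. At the Python level, that split matches what GradOperation(get_all=True, get_by_list=True) returns: gradients of positional inputs first, then gradients of captured weights. A rough sketch (the cell and values are illustrative, not from the patch):

import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple, context, ms_function
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)


class ScaleNet(nn.Cell):
    def __init__(self):
        super(ScaleNet, self).__init__()
        self.w = Parameter(Tensor([2.0], mstype.float32), name="w")

    @ms_function
    def construct(self, x):
        # x is a positional input (GetInputArgsNode); self.w is a weight
        # appended after the positional inputs (GetWeightsNode).
        return x * self.w


net = ScaleNet()
grad_op = ops.GradOperation(get_all=True, get_by_list=True)
grads = grad_op(net, ParameterTuple(net.trainable_params()))(Tensor([3.0], mstype.float32))
# grads[0]: gradients of positional inputs -> d(x*w)/dx = w = 2
# grads[1]: gradients of weights           -> d(x*w)/dw = x = 3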


@@ -41,22 +41,26 @@ class MsFunction {
  private:
   const std::string &graph_phase() const { return graph_phase_; }
-  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
-                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
   // Update device address of value node in grad graph by forward tensors.
   void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                   const FuncGraphPtr &grad_graph) const;
   void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const FrontendOpRunInfoPtr &op_run_info,
                                     const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
                                     const FuncGraphPtr &grad_graph) const;
   void UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &new_forward_value) const;
+  // Make CNode for ms_function forward graph.
+  void GetInputArgsNode(const py::args &args, AnfNodePtrList *input_nodes, ValuePtrList *input_values) const;
+  void GetWeightsNode(const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes, ValuePtrList *input_values,
+                      const size_t input_args_index) const;
   void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values,
                               CNodePtr *ms_function_cnode) const;
   // Make adjoint for ms_function fprop graph and connect it with previous op
   CNodePtr MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                     const py::object &actual_out, const py::args &args,
                                     const ValuePtr &actual_out_v) const;
+  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
+                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;

  private:
   // The graph phase is used to obtain backend graph that is complied by ms_function
   std::string graph_phase_;
   // Stores parameter in ms_function


@@ -0,0 +1,85 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import platform
import numpy as np
import pytest
from mindspore import nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import ms_function, context
import mindspore.ops as ops


class PyNet(nn.Cell):
    def __init__(self):
        super(PyNet, self).__init__()
        self.w1 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w1")

    @ms_function
    def construct(self, param_a, list_a, tuple_a, tensor_a, dict_a, param_b, tensor_b):
        output = param_a + list_a[0] + tuple_a[1] - tensor_a - dict_a["x"] - param_b + tensor_b
        output = output * self.w1
        return output


class GraphNet(nn.Cell):
    def __init__(self):
        super(GraphNet, self).__init__()
        self.w2 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w2")

    def construct(self, param_x, list_x, tuple_x, tensor_x, dict_x, param_y, tensor_y):
        output = param_x + list_x[0] + tuple_x[1] - tensor_x - dict_x["x"] - param_y + tensor_y
        output = output * self.w2
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function_support_parameter_as_input():
    """
    Feature: PyNative ms_function supports parameter as input.
    Description: Run the same computation with parameters, tensors and nested containers
        as inputs in PyNative mode (via ms_function) and in graph mode, then compare gradients.
    Expectation: The calculation result is correct.
    """
    if platform.system() == 'Windows':
        return
    tensor_a = Tensor(np.ones((2, 2), np.float32))
    tensor_b = Tensor(np.ones((2, 2), np.float32) * 2)
    tuple_a = (Tensor(np.ones((2, 2), np.float32) * 3), Tensor(np.ones((2, 2), np.float32) * 4))
    list_a = [Tensor(np.ones((2, 2), np.float32) * 5), Tensor(np.ones((2, 2), np.float32) * 6)]
    dict_a = {"x": Tensor(np.ones((2, 2), np.float32) * 7), "y": Tensor(np.ones((2, 2), np.float32) * 8)}
    param_a = Parameter(Tensor(np.ones((2, 2), np.float32)), name="param1")
    param_b = Parameter(Tensor(np.ones((2, 2), np.float32) * 2), name="param2")
    grad_op = ops.GradOperation(get_all=True, get_by_list=True)
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = PyNet()
    output1 = grad_op(net1, ParameterTuple(net1.trainable_params()))(param_a, list_a, tuple_a, tensor_a, dict_a,
                                                                     param_b, tensor_b)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphNet()
    output2 = grad_op(net2, ParameterTuple(net2.trainable_params()))(param_a, list_a, tuple_a, tensor_a, dict_a,
                                                                     param_b, tensor_b)
    assert np.allclose(output1[0][0].asnumpy(), output2[0][0].asnumpy(), 0.000001, 0.000001)
    assert np.allclose(output1[0][1].asnumpy(), output2[0][1].asnumpy(), 0.000001, 0.000001)
    assert np.allclose(output1[0][2].asnumpy(), output2[0][2].asnumpy(), 0.000001, 0.000001)
    assert np.allclose(output1[0][3].asnumpy(), output2[0][3].asnumpy(), 0.000001, 0.000001)
    assert np.allclose(output1[1][0].asnumpy(), output2[1][0].asnumpy(), 0.000001, 0.000001)
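
Because every operation in these networks is elementwise and w1 = w2 = ones, the values both modes should agree on can also be worked out by hand: the pre-scale sum is 1 + 5 + 4 - 1 - 7 - 2 + 2 = 2, so the weight gradient should be 2 * ones((2, 2)) and the gradient of each directly-added input should be +/- w = +/- ones((2, 2)). A rough numpy restatement of that arithmetic (not part of the test):

import numpy as np

ones = np.ones((2, 2), np.float32)
pre_scale = (1 + 5 + 4 - 1 - 7 - 2 + 2) * ones  # the sum fed into `* w1`
grad_param_a = ones        # d(out)/d(param_a) = w1 = ones
grad_w1 = pre_scale        # d(out)/d(w1) = pre-scale sum = 2 * ones
assert grad_param_a[0][0] == 1.0 and grad_w1[0][0] == 2.0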


@@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, Parameter, ParameterTuple, ms_function
import mindspore.ops as ops


class NetWithParamInput(nn.Cell):
    def __init__(self):
        super(NetWithParamInput, self).__init__()
        self.w = Parameter(Tensor([6], mstype.float32))

    @ms_function
    def construct(self, x, y):
        return (x + y) * self.w


def test_ms_func_parameter_input():
    """
    Feature: ms_function supports parameter as input in PyNative mode.
    Description: Use a parameter as input for an ms_function-decorated construct, then
        check the forward result and the gradients.
    Expectation: Calculation result is correct.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor([1], mstype.float32)
    input_param = Parameter(Tensor([2], mstype.float32), name="param")
    net = NetWithParamInput()
    # check forward run
    out = net(input_x, input_param)
    assert np.allclose(out.asnumpy(), 18, 0.000001, 0.000001)
    # check grad
    grad_op = ops.GradOperation(get_all=True, get_by_list=True)
    gradient = grad_op(net, ParameterTuple(net.trainable_params()))(input_x, input_param)
    assert np.allclose(gradient[0][0].asnumpy(), 6, 0.000001, 0.000001)
    assert np.allclose(gradient[0][1].asnumpy(), 6, 0.000001, 0.000001)
    assert np.allclose(gradient[1][0].asnumpy(), 3, 0.000001, 0.000001)
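
The expected values follow directly from out = (x + y) * w with x = 1, y = 2, w = 6: the forward result is (1 + 2) * 6 = 18, d(out)/dx = d(out)/dy = w = 6, and d(out)/dw = x + y = 3. A plain-numpy restatement of that arithmetic (illustrative only):

import numpy as np

x, y, w = 1.0, 2.0, 6.0
out = (x + y) * w   # forward: 18
dx = w              # d[(x + y) * w]/dx = w = 6
dy = w              # d[(x + y) * w]/dy = w = 6
dw = x + y          # d[(x + y) * w]/dw = x + y = 3
assert np.allclose([out, dx, dy, dw], [18.0, 6.0, 6.0, 3.0])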