forked from mindspore-Ecosystem/mindspore
!40107 Support parameter as input for Pynative ms_function
Merge pull request !40107 from JoyLvliang/refactor_pynative_ms_function
commit 98700af68c
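In user-facing terms, this change lets a Parameter be passed directly as an argument to an @ms_function-decorated method while the surrounding code runs in PyNative mode, with gradients flowing to both the inputs and the weights. A minimal usage sketch, distilled from the tests added at the bottom of this diff (the Net class and variable names are illustrative):

import mindspore.ops as ops
from mindspore import context, ms_function, nn, Tensor, Parameter, ParameterTuple
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.w = Parameter(Tensor([6], mstype.float32), name="w")

    @ms_function  # compiled as a graph even though the outer mode is PyNative
    def construct(self, x, y):
        return (x + y) * self.w


net = Net()
x = Tensor([1], mstype.float32)
p = Parameter(Tensor([2], mstype.float32), name="p")  # a Parameter used as a plain input
out = net(x, p)  # forward: (1 + 2) * 6 = 18
# Gradients w.r.t. the inputs (get_all) and the weight (get_by_list) in one call.
grad_op = ops.GradOperation(get_all=True, get_by_list=True)
grads = grad_op(net, ParameterTuple(net.trainable_params()))(x, p)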
@@ -163,46 +163,52 @@ void MsFunction::UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info,
   }
 }
 
-void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
-                                        ValuePtrList *input_values, CNodePtr *ms_function_cnode) const {
-  // Get input node info of ms_function
-  MS_EXCEPTION_IF_NULL(ms_func_graph);
-  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
-  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
-  for (size_t i = 0; i < args.size(); ++i) {
-    const auto &inp_i_value = PyNativeAlgo::DataConvert::PyObjToValue(args[i]);
-    const auto &input_i_node = grad_executor->GetInput(inp_i_value);
-    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << inp_i_value->ToString();
-    (void)(*input_values).emplace_back(inp_i_value);
-    (void)input_nodes.emplace_back(input_i_node);
-  }
-  // Get dfbuilder and graph info map
-  const auto &top_cell = grad_executor->top_cell();
+void MsFunction::GetInputArgsNode(const py::args &args, AnfNodePtrList *input_nodes, ValuePtrList *input_values) const {
+  MS_EXCEPTION_IF_NULL(input_nodes);
+  MS_EXCEPTION_IF_NULL(input_values);
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  size_t input_args_size = args.size();
+  for (size_t i = 0; i < input_args_size; ++i) {
+    const auto &input_i_value = PyNativeAlgo::DataConvert::PyObjToValue(args[i]);
+    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << input_i_value->ToString();
+    (void)input_values->emplace_back(input_i_value);
+    const auto &input_i_node = grad_exec->GetInput(input_i_value);
+    MS_EXCEPTION_IF_NULL(input_i_node);
+    MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
+    (void)input_nodes->emplace_back(input_i_node);
+  }
+}
+
+void MsFunction::GetWeightsNode(const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes,
+                                ValuePtrList *input_values, const size_t input_args_index) const {
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  // Get graph info map.
+  const auto &top_cell = grad_exec->top_cell();
   auto df_builder = top_cell->df_builder();
   MS_EXCEPTION_IF_NULL(df_builder);
   const auto &graph_info = top_cell->graph_info_map().at(df_builder);
   MS_EXCEPTION_IF_NULL(graph_info);
   // Get weights info of ms_function
-  std::vector<AnfNodePtr> new_params;
   auto manage = Manage(ms_func_graph, false);
-  for (const auto &anf_node : ms_func_graph->parameters()) {
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto param = anf_node->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param);
-    if (!param->has_default()) {
-      (void)new_params.emplace_back(param);
-      continue;
-    }
+  const auto &original_params = ms_func_graph->parameters();
+  size_t params_size = original_params.size();
+  std::vector<AnfNodePtr> new_params;
+  for (size_t i = 0; i < params_size; ++i) {
+    if (i < input_args_index) {  // non-weights node.
+      (void)new_params.emplace_back(original_params[i]);
+      continue;
+    }
+    const auto &anf_node = original_params[i];
+    auto param = anf_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
     auto param_info = param->param_info();
     MS_EXCEPTION_IF_NULL(param_info);
     auto param_name = param_info->name();
     if (graph_info->params.count(param_name) != 0) {
       // Share same weight parameter in different ms_function call.
-      auto same_param = graph_info->params.at(param_name);
+      const auto &same_param = graph_info->params.at(param_name);
       manage->Replace(anf_node, same_param);
       param = same_param;
     } else {
@@ -210,19 +216,30 @@ void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
       param->debug_info()->set_name(param_name);
     }
     (void)new_params.emplace_back(param);
-    (void)input_nodes.emplace_back(param);
-    (void)(*input_values).emplace_back(param->default_param());
+    (void)input_nodes->emplace_back(param);
+    const auto &default_param = param->default_param();
+    MS_EXCEPTION_IF_NULL(default_param);
+    (void)input_values->emplace_back(default_param);
     top_cell->SetParamNodeMapInGraphInfoMap(df_builder, param_name, param);
     MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
-                  << param->default_param()->ToString() << ". Its name is: " << param_name;
+                  << default_param->ToString() << ". Its name is: " << param_name;
   }
   ms_func_graph->set_parameters(new_params);
   manage->Clear();
 }
 
+void MsFunction::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
+                                        ValuePtrList *input_values, CNodePtr *ms_function_cnode) const {
+  // Get input node info of ms_function
+  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
+  GetInputArgsNode(args, &input_nodes, input_values);
+  // Get weights node info of ms_function.
+  GetWeightsNode(ms_func_graph, &input_nodes, input_values, args.size());
   // Make a CNode which includes ms_function fprop graph and inputs node
   MS_EXCEPTION_IF_NULL(ms_function_cnode);
-  *ms_function_cnode = top_cell->fg()->NewCNode(input_nodes);
-  MS_LOG(DEBUG) << "Make ms function forward cnode: " << (*ms_function_cnode)->DebugString();
+  const auto &grad_exec = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
+  *ms_function_cnode = grad_exec->top_cell()->fg()->NewCNode(input_nodes);
+  MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
 }
 
 // Make adjoint for ms_function fprop graph and connect it with previous op
@@ -241,7 +258,6 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph,
   // Connect grad graph of ms_function to context.
   auto k_pynative_cell_ptr = top_cell->k_pynative_cell_ptr();
   MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
-  MS_EXCEPTION_IF_NULL(grad_graph);
   if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, actual_out_v, grad_graph)) {
     MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                       << ms_function_cnode->DebugString();
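Stripped of the ANF plumbing, the weight handling in GetWeightsNode above amounts to: pass the first input_args_index parameters (the positional inputs) through untouched, and deduplicate the remaining weight parameters by name against a per-top-cell registry, so repeated ms_function calls share a single node per weight. A minimal Python sketch of that pattern, with plain dicts standing in for parameter nodes and the graph-info map (names are illustrative, not MindSpore API):

def rebuild_params(params, input_args_index, registry):
    """Pass positional inputs through; resolve each weight name to one shared node."""
    new_params = []
    for i, param in enumerate(params):
        if i < input_args_index:  # non-weight node: keep as-is
            new_params.append(param)
            continue
        # Reuse the node registered under this weight's name, or register this one.
        new_params.append(registry.setdefault(param["name"], param))
    return new_params


# Two "calls" sharing one top-cell registry: w1 resolves to the same node twice.
registry = {}
first = rebuild_params([{"name": "x"}, {"name": "w1"}], 1, registry)
second = rebuild_params([{"name": "y"}, {"name": "w1"}], 1, registry)
assert second[1] is first[1]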
@@ -41,22 +41,26 @@ class MsFunction {
 
  private:
   const std::string &graph_phase() const { return graph_phase_; }
-  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
-                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
   // Update device address of value node in grad graph by forward tensors.
   void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                   const FuncGraphPtr &grad_graph) const;
   void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const FrontendOpRunInfoPtr &op_run_info,
                                     const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
                                     const FuncGraphPtr &grad_graph) const;
   void UpdateMsFunctionForwardTensors(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &new_forward_value) const;
+  // Make CNode for ms_function forward graph.
+  void GetInputArgsNode(const py::args &args, AnfNodePtrList *input_nodes, ValuePtrList *input_values) const;
+  void GetWeightsNode(const FuncGraphPtr &ms_func_graph, AnfNodePtrList *input_nodes, ValuePtrList *input_values,
+                      const size_t input_args_index) const;
   void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values,
                               CNodePtr *ms_function_cnode) const;
   // Make adjoint for ms_function fprop graph and connect it with previous op
   CNodePtr MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                     const py::object &actual_out, const py::args &args,
                                     const ValuePtr &actual_out_v) const;
+  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
+                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
 
  private:
   // The graph phase is used to obtain backend graph that is compiled by ms_function
   std::string graph_phase_;
   // Stores parameter in ms_function
@@ -0,0 +1,85 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import platform
+import numpy as np
+import pytest
+
+from mindspore import nn
+from mindspore import Tensor, Parameter, ParameterTuple
+from mindspore import ms_function, context
+import mindspore.ops as ops
+
+
+class PyNet(nn.Cell):
+    def __init__(self):
+        super(PyNet, self).__init__()
+        self.w1 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w1")
+
+    @ms_function
+    def construct(self, param_a, list_a, tuple_a, tensor_a, dict_a, param_b, tensor_b):
+        output = param_a + list_a[0] + tuple_a[1] - tensor_a - dict_a["x"] - param_b + tensor_b
+        output = output * self.w1
+        return output
+
+
+class GraphNet(nn.Cell):
+    def __init__(self):
+        super(GraphNet, self).__init__()
+        self.w2 = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w2")
+
+    def construct(self, param_x, list_x, tuple_x, tensor_x, dict_x, param_y, tensor_y):
+        output = param_x + list_x[0] + tuple_x[1] - tensor_x - dict_x["x"] - param_y + tensor_y
+        output = output * self.w2
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_support_parameter_as_input():
+    """
+    Feature: PyNative ms_function support parameter as input.
+    Description: PyNative ms_function support parameter as input.
+    Expectation: The calculation result is correct.
+    """
+    if platform.system() == 'Windows':
+        return
+
+    tensor_a = Tensor(np.ones((2, 2), np.float32))
+    tensor_b = Tensor(np.ones((2, 2), np.float32) * 2)
+    tuple_a = (Tensor(np.ones((2, 2), np.float32) * 3), Tensor(np.ones((2, 2), np.float32) * 4))
+    list_a = [Tensor(np.ones((2, 2), np.float32) * 5), Tensor(np.ones((2, 2), np.float32) * 6)]
+    dict_a = {"x": Tensor(np.ones((2, 2), np.float32) * 7), "y": Tensor(np.ones((2, 2), np.float32) * 8)}
+    param_a = Parameter(Tensor(np.ones((2, 2), np.float32)), name="param1")
+    param_b = Parameter(Tensor(np.ones((2, 2), np.float32) * 2), name="param2")
+
+    grad_op = ops.GradOperation(get_all=True, get_by_list=True)
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net1 = PyNet()
+    output1 = grad_op(net1, ParameterTuple(net1.trainable_params()))(param_a, list_a, tuple_a, tensor_a, dict_a,
+                                                                     param_b, tensor_b)
+    context.set_context(mode=context.GRAPH_MODE)
+    net2 = GraphNet()
+    output2 = grad_op(net2, ParameterTuple(net2.trainable_params()))(param_a, list_a, tuple_a, tensor_a, dict_a,
+                                                                     param_b, tensor_b)
+    assert np.allclose(output1[0][0].asnumpy(), output2[0][0].asnumpy(), 0.000001, 0.000001)
+    assert np.allclose(output1[0][1].asnumpy(), output2[0][1].asnumpy(), 0.000001, 0.000001)
+    assert np.allclose(output1[0][2].asnumpy(), output2[0][2].asnumpy(), 0.000001, 0.000001)
+    assert np.allclose(output1[0][3].asnumpy(), output2[0][3].asnumpy(), 0.000001, 0.000001)
+    assert np.allclose(output1[1][0].asnumpy(), output2[1][0].asnumpy(), 0.000001, 0.000001)
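A note on the indexing in the asserts: with both get_all=True and get_by_list=True, GradOperation returns a pair whose element 0 holds the gradients with respect to the positional inputs and whose element 1 holds the gradients with respect to the parameters in the ParameterTuple, which is what output1[0][i] and output1[1][0] select.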
@@ -0,0 +1,51 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore.common import dtype as mstype
+from mindspore import context, Tensor, Parameter, ParameterTuple, ms_function
+import mindspore.ops as ops
+
+
+class NetWithParamInput(nn.Cell):
+    def __init__(self):
+        super(NetWithParamInput, self).__init__()
+        self.w = Parameter(Tensor([6], mstype.float32))
+
+    @ms_function
+    def construct(self, x, y):
+        return (x + y) * self.w
+
+
+def test_ms_func_parameter_input():
+    """
+    Feature: ms_function support parameter as input in PyNative Mode.
+    Description: Using parameter as input for ms_function.
+    Expectation: Calculation result is correct.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE)
+    input_x = Tensor([1], mstype.float32)
+    input_param = Parameter(Tensor([2], mstype.float32), name="param")
+    net = NetWithParamInput()
+    # check forward run
+    out = net(input_x, input_param)
+    assert np.allclose(out.asnumpy(), 18, 0.000001, 0.000001)
+    # check grad
+    grad_op = ops.GradOperation(get_all=True, get_by_list=True)
+    gradient = grad_op(net, ParameterTuple(net.trainable_params()))(input_x, input_param)
+    assert np.allclose(gradient[0][0].asnumpy(), 6, 0.000001, 0.000001)
+    assert np.allclose(gradient[0][1].asnumpy(), 6, 0.000001, 0.000001)
+    assert np.allclose(gradient[1][0].asnumpy(), 3, 0.000001, 0.000001)
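The asserted values follow directly from construct: the forward output is (1 + 2) * 6 = 18, the gradient with respect to each of the two inputs is w = 6, and the gradient with respect to w is x + y = 1 + 2 = 3.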