Support for to while

This commit is contained in:
Margaret_wangrui 2021-05-13 15:49:19 +08:00
parent 5e09a8cbb2
commit 56cd94cf60
6 changed files with 100 additions and 11 deletions

View File

@ -88,6 +88,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
errcode_ = PARSE_SUCCESS;
BuildMethodMap();
}
@ -1170,18 +1171,18 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
return body_block;
}
int64_t GetForTransToWhileLoop() {
static const auto loop_str = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
int64_t Parser::GetForTransToWhileLoop() {
// int64 support 63bits positive num mostly.
if (loop_str.size() > 63 || loop_str.empty()) {
if (max_for_loop_count_str_.size() > 63 || max_for_loop_count_str_.empty()) {
return MAX_FOR_LOOP_COUNT;
}
if (std::any_of(loop_str.begin(), loop_str.end(), [](char c) { return c < '0' || c > '9'; })) {
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
[](char c) { return c < '0' || c > '9'; })) {
return MAX_FOR_LOOP_COUNT;
}
int64_t loop_count;
std::stringstream ss;
ss << loop_str;
ss << max_for_loop_count_str_;
ss >> loop_count;
return loop_count;
}
@ -1357,11 +1358,16 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
// Create 'x = xs[i]'
CNodePtr target_var = body_block->func_graph()->NewCNodeInOrder({op_getitem, iter_node, loop_var});
auto body_func_graph = body_block->func_graph();
CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
WriteAssignVars(body_block, target_node, target_var);
// Create 'i = i + 1'
CNodePtr loop_var_inc = body_block->func_graph()->NewCNodeInOrder(
{NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(static_cast<int64_t>(1))});
auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
auto add_tensor_node =
body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
body_block->WriteVariable(loop_var->name(), loop_var_inc);
// Link the variable name with the target
@ -1377,7 +1383,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
MS_EXCEPTION_IF_NULL(after_block);
after_block->AddPrevBlock(header_block);
block->Jump(header_block, NewValueNode(static_cast<int64_t>(0)));
CNodePtr zero_tensor =
block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
block->Jump(header_block, zero_tensor);
body_block->Mature();
header_block->ConditionalJump(cond_node, body_block, after_block, false);

View File

@ -246,6 +246,7 @@ class Parser {
}
// return a make tuple for input elements list
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
int64_t GetForTransToWhileLoop();
// shared_ptr will be hold by GraphManager, so just hold a weak ref here.
static FuncGraphWeakPtr top_func_graph_;
@ -267,6 +268,7 @@ class Parser {
std::map<std::string, pExprFunc> expr_method_map_;
// Save current loops to support 'continue', 'break' statement.
std::stack<Loop> loops_;
string max_for_loop_count_str_;
};
// AST node type define code to ast

View File

@ -215,7 +215,10 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
if (arg->GetValueTrack() != kAnyValue) {
if (arg->isa<AbstractScalar>()) {
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
return arg->Broaden(config);
} else if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;

View File

@ -90,7 +90,7 @@ std::string AbstractBase::ToString() const {
}
AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
return AbstractBase::Broaden(config);
} else {
return Clone();

View File

@ -75,9 +75,12 @@ class AbstractBase : public Base {
// mask for Broaden config
inline static const uint8_t kBroadenTensorOnly = 1;
inline static const uint8_t kBroadenParameterOnly = 2;
// Scalar as Parameter, should boarden
inline static const uint8_t kBroadenScalarParameterOnly = 4;
// Each bit for on config.
// 00000001 -> 1: only boarden tensor
// 00000010 -> 2: only boarden parameter
// 00000100 -> 4: only boarden scalar parameter
virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_01():
class SingleForNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.Add()
self.mul = P.Mul()
def construct(self, x, y, z):
x = self.add(x, y)
for _ in range(0, 3):
z = self.add(z, x)
y = self.mul(z, y)
return y
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return grad_all(self.net)(*inputs)
x = Tensor([2], mstype.int32)
y = Tensor([5], mstype.int32)
z = Tensor([4], mstype.int32)
os.environ['ENV_FOR_TO_WHILE_LOOP'] = '1'
# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_net = SingleForNet()
net = GradNet(for_net)
graph_forward_res = for_net(x, y, z)
graph_backward_res = net(x, y, z)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_net = SingleForNet()
net = GradNet(for_net)
pynative_forward_res = for_net(x, y, z)
pynative_backward_res = net(x, y, z)
os.environ['ENV_FOR_TO_WHILE_LOOP'] = ''
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res