forked from mindspore-Ecosystem/mindspore

Support `for` to `while` loop conversion

This commit is contained in:
parent 5e09a8cbb2
commit 56cd94cf60
@@ -88,6 +88,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
 FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
 
 Parser::Parser(const std::shared_ptr<ParseAst> &ast) : ast_(ast) {
+  max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
   errcode_ = PARSE_SUCCESS;
   BuildMethodMap();
 }
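Caching the environment variable in the constructor (instead of the function-local static removed in the next hunk) means every `Parser` instance re-reads it, so tests can flip the switch between runs. A minimal sketch of driving it from Python, mirroring the new test at the bottom of this commit; the threshold semantics are inferred from the name `GetForTransToWhileLoop`:

```python
import os

# ENV_FOR_TO_WHILE_LOOP feeds Parser::GetForTransToWhileLoop(); an empty or
# non-numeric value falls back to the built-in MAX_FOR_LOOP_COUNT default.
os.environ['ENV_FOR_TO_WHILE_LOOP'] = '1'
# ... build and run graph-mode networks here; qualifying for loops are
# lowered to while loops based on this threshold ...
os.environ['ENV_FOR_TO_WHILE_LOOP'] = ''  # restore the default behavior
```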
@@ -1170,18 +1171,18 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
   return body_block;
 }
 
-int64_t GetForTransToWhileLoop() {
-  static const auto loop_str = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
+int64_t Parser::GetForTransToWhileLoop() {
   // int64 holds at most 63 bits for a positive number.
-  if (loop_str.size() > 63 || loop_str.empty()) {
+  if (max_for_loop_count_str_.size() > 63 || max_for_loop_count_str_.empty()) {
     return MAX_FOR_LOOP_COUNT;
   }
-  if (std::any_of(loop_str.begin(), loop_str.end(), [](char c) { return c < '0' || c > '9'; })) {
+  if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
+                  [](char c) { return c < '0' || c > '9'; })) {
     return MAX_FOR_LOOP_COUNT;
   }
   int64_t loop_count;
   std::stringstream ss;
-  ss << loop_str;
+  ss << max_for_loop_count_str_;
   ss >> loop_count;
   return loop_count;
 }
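The validation above is equivalent to this Python restatement (a sketch; `MAX_FOR_LOOP_COUNT` stands in for the C++ constant, whose value is not shown in this diff):

```python
MAX_FOR_LOOP_COUNT = 600  # placeholder: the real default lives in the C++ sources

def get_for_trans_to_while_loop(max_for_loop_count_str: str) -> int:
    # Empty, over-long, or non-numeric strings all fall back to the default.
    if not max_for_loop_count_str or len(max_for_loop_count_str) > 63:
        return MAX_FOR_LOOP_COUNT
    if any(c < '0' or c > '9' for c in max_for_loop_count_str):
        return MAX_FOR_LOOP_COUNT
    return int(max_for_loop_count_str)

assert get_for_trans_to_while_loop('') == MAX_FOR_LOOP_COUNT     # empty -> default
assert get_for_trans_to_while_loop('1x2') == MAX_FOR_LOOP_COUNT  # non-numeric -> default
assert get_for_trans_to_while_loop('25') == 25                   # plain digits parse
```

Note that the `> 63` guard counts characters rather than bits; int64 tops out at 19 decimal digits, so overflow handling is effectively left to the stringstream extraction.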
@@ -1357,11 +1358,16 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
   MS_EXCEPTION_IF_NULL(body_block);
   body_block->AddPrevBlock(header_block);
   // Create 'x = xs[i]'
-  CNodePtr target_var = body_block->func_graph()->NewCNodeInOrder({op_getitem, iter_node, loop_var});
+  auto body_func_graph = body_block->func_graph();
+  CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
   WriteAssignVars(body_block, target_node, target_var);
   // Create 'i = i + 1'
-  CNodePtr loop_var_inc = body_block->func_graph()->NewCNodeInOrder(
-    {NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(static_cast<int64_t>(1))});
+  auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
+  auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
+  auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
+  auto add_tensor_node =
+    body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
+  CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
   body_block->WriteVariable(loop_var->name(), loop_var_inc);
 
   // Link the variable name with the target
@@ -1377,7 +1383,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
   MS_EXCEPTION_IF_NULL(after_block);
   after_block->AddPrevBlock(header_block);
 
-  block->Jump(header_block, NewValueNode(static_cast<int64_t>(0)));
+  CNodePtr zero_tensor =
+    block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
+  block->Jump(header_block, zero_tensor);
   body_block->Mature();
 
   header_block->ConditionalJump(cond_node, body_block, after_block, false);
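Taken together, the two ParseForLoop hunks change the generated loop skeleton: the index is now initialized and incremented through `scalar_to_tensor` and the tensor `Add` op rather than a scalar zero and `kPrimScalarAdd`. In Python terms, the lowering emitted for `for x in xs: <body>` looks roughly like this (a conceptual sketch, not code from the commit):

```python
def lowered_for(xs, body):
    i = 0                # before: scalar 0; now: a zero tensor via scalar_to_tensor
    while i < len(xs):   # header block: ConditionalJump on cond_node
        x = xs[i]        # body block: x = xs[i] (getitem)
        body(x)
        i = i + 1        # body block: was prim::kPrimScalarAdd; now ops Add with a
                         # 1-tensor, so i stays tensor-valued throughout the graph
    # control falls through to the after block
```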
@@ -246,6 +246,7 @@ class Parser {
   }
   // Return a MakeTuple node for the list of input elements.
   AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
+  int64_t GetForTransToWhileLoop();
 
   // The shared_ptr will be held by GraphManager, so just hold a weak ref here.
   static FuncGraphWeakPtr top_func_graph_;
@@ -267,6 +268,7 @@ class Parser {
   std::map<std::string, pExprFunc> expr_method_map_;
   // Save current loops to support 'continue' and 'break' statements.
   std::stack<Loop> loops_;
+  string max_for_loop_count_str_;
 };
 
 // AST node type: maps source code to the AST.
@@ -215,7 +215,10 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                          MS_EXCEPTION_IF_NULL(arg);
-                         if (arg->GetValueTrack() != kAnyValue) {
+                         if (arg->isa<AbstractScalar>()) {
+                           auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
+                           return arg->Broaden(config);
+                         } else if (arg->GetValueTrack() != kAnyValue) {
                            return arg->Broaden();
                          }
                          return arg;
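With this change, NormalizeArgs broadens scalar arguments unconditionally with the new mask, while other arguments keep the old known-value rule. The branch logic, restated as a self-contained Python sketch (the `Arg` stand-in and names are illustrative, not the real API):

```python
ANY_VALUE = object()                 # stand-in for kAnyValue
K_BROADEN_SCALAR_PARAMETER_ONLY = 4  # mirrors kBroadenScalarParameterOnly

class Arg:
    # Minimal stand-in for AbstractBasePtr: a kind flag plus a tracked value.
    def __init__(self, is_scalar, value=ANY_VALUE):
        self.is_scalar = is_scalar
        self.value = value
    def broaden(self, config=0):
        # Broadening drops the concrete value to "any value". (Simplified: the
        # real AbstractScalar::Broaden also honors MS_CTX_GRAD_FOR_SCALAR.)
        return Arg(self.is_scalar, ANY_VALUE)

def normalize_arg(arg):
    # Mirrors the new transform lambda: scalars are always broadened with the
    # scalar-parameter mask; other args broaden only if their value is known.
    if arg.is_scalar:
        return arg.broaden(K_BROADEN_SCALAR_PARAMETER_ONLY)
    if arg.value is not ANY_VALUE:
        return arg.broaden()
    return arg
```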
@@ -90,7 +90,7 @@ std::string AbstractBase::ToString() const {
 }
 
 AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
-  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
     return AbstractBase::Broaden(config);
   } else {
     return Clone();
@@ -75,9 +75,12 @@ class AbstractBase : public Base {
   // Mask bits for the Broaden config.
   inline static const uint8_t kBroadenTensorOnly = 1;
   inline static const uint8_t kBroadenParameterOnly = 2;
+  // A scalar used as a Parameter should be broadened.
+  inline static const uint8_t kBroadenScalarParameterOnly = 4;
   // Each bit turns on one config:
   // 00000001 -> 1: only broaden tensors
   // 00000010 -> 2: only broaden parameters
+  // 00000100 -> 4: only broaden scalar parameters
   virtual AbstractBasePtr Broaden(uint8_t config = 0) const;
   virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); }
 
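Note that the config is used as a one-hot selector rather than a combinable flag set: `AbstractScalar::Broaden` in the previous hunk compares with `==`, not `&`. A small sketch of the dispatch (mask values copied from the header; the function restates the C++ condition):

```python
# Mask values copied from the header above.
K_BROADEN_TENSOR_ONLY = 1            # 0b001: only broaden tensors
K_BROADEN_PARAMETER_ONLY = 2         # 0b010: only broaden parameters
K_BROADEN_SCALAR_PARAMETER_ONLY = 4  # 0b100: only broaden scalar parameters

def scalar_broaden_takes_effect(grad_for_scalar: bool, config: int) -> bool:
    # Restates AbstractScalar::Broaden: broaden when the MS_CTX_GRAD_FOR_SCALAR
    # context flag is on, or when the caller passes the scalar-parameter mask.
    # Note the exact equality: masks are selected one at a time, not OR-ed.
    return grad_for_scalar or config == K_BROADEN_SCALAR_PARAMETER_ONLY

assert scalar_broaden_takes_effect(False, K_BROADEN_SCALAR_PARAMETER_ONLY)
assert not scalar_broaden_takes_effect(False, K_BROADEN_TENSOR_ONLY)
assert scalar_broaden_takes_effect(True, 0)
```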
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import pytest
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_single_for_01():
+    class SingleForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.Add()
+            self.mul = P.Mul()
+
+        def construct(self, x, y, z):
+            x = self.add(x, y)
+            for _ in range(0, 3):
+                z = self.add(z, x)
+                y = self.mul(z, y)
+            return y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor([2], mstype.int32)
+    y = Tensor([5], mstype.int32)
+    z = Tensor([4], mstype.int32)
+
+    os.environ['ENV_FOR_TO_WHILE_LOOP'] = '1'
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_net = SingleForNet()
+    net = GradNet(for_net)
+    graph_forward_res = for_net(x, y, z)
+    graph_backward_res = net(x, y, z)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_net = SingleForNet()
+    net = GradNet(for_net)
+    pynative_forward_res = for_net(x, y, z)
+    pynative_backward_res = net(x, y, z)
+    os.environ['ENV_FOR_TO_WHILE_LOOP'] = ''
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res