Fix pijit one stage slice problem
This commit is contained in:
parent
c56614d008
commit
01897887e3
|
@ -37,7 +37,6 @@ bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
|
|||
kListInplaceClearOpName,
|
||||
kDictInplaceSetItemOpName,
|
||||
kRaiseOpName,
|
||||
kMakeSliceOpName,
|
||||
kJoinedStrOpName,
|
||||
kFormatOpName};
|
||||
return prims_should_fallback_in_runtime.find(prim->name()) != prims_should_fallback_in_runtime.end();
|
||||
|
@ -336,21 +335,6 @@ bool FuncGraphBuilder::AddOutput(const py::object &output_obj) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void FuncGraphBuilder::UpdatePyObject(const py::object &new_obj, const py::object &old_obj) {
|
||||
if (new_obj.ptr() == old_obj.ptr()) {
|
||||
return;
|
||||
}
|
||||
auto iter = py_obj_to_node_.find(old_obj.ptr());
|
||||
if (iter == py_obj_to_node_.end()) {
|
||||
return;
|
||||
}
|
||||
auto node = iter->second;
|
||||
py_obj_to_node_.erase(iter);
|
||||
(void)py_obj_to_node_.emplace(new_obj.ptr(), node);
|
||||
MS_LOG(DEBUG) << "Update python object " << old_obj.ptr() << " to " << new_obj.ptr() << ". Corresponding node is "
|
||||
<< node->DebugString();
|
||||
}
|
||||
|
||||
FuncGraphPtr FuncGraphBuilder::graph() {
|
||||
if (has_set_output_) {
|
||||
return graph_;
|
||||
|
|
|
@ -68,12 +68,6 @@ class FuncGraphBuilder {
|
|||
/// \return Return true if the output object can be used as the output of the graph.
|
||||
bool AddOutput(const py::object &output_obj);
|
||||
|
||||
/// \brief Update key value for converted_py_obj_ map.
|
||||
///
|
||||
/// \param[in] new_obj The new python object as key.
|
||||
/// \param[in] old_obj The old python object as key.
|
||||
void UpdatePyObject(const py::object &new_obj, const py::object &old_obj);
|
||||
|
||||
/// \brief Remove an output node of the graph.
|
||||
///
|
||||
/// \param[in] output_obj The output python object.
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "pipeline/jit/pi/graph_compiler/utils.h"
|
||||
#include "ops/sequence_ops.h"
|
||||
#include "ops/framework_ops.h"
|
||||
#include "ops/structure_ops.h"
|
||||
|
||||
#ifndef PY_MINOR_VERSION
|
||||
#define PY_MINOR_VERSION 3.7
|
||||
|
@ -2657,7 +2658,7 @@ AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<V
|
|||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
auto node = fg_builder_->AddMultiNode(op_name, input_obj);
|
||||
return AObject::Convert(node);
|
||||
return AbstractFuncGraphOut::MakeAObject(node);
|
||||
}
|
||||
|
||||
AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
|
||||
|
@ -2686,11 +2687,15 @@ AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<V
|
|||
return AObject::MakeAObject(AObject::kTypeAnyValue);
|
||||
}
|
||||
}
|
||||
if (primitive == prim::kPrimMakeSlice) {
|
||||
constexpr size_t slice_without_step_len = 2;
|
||||
if (input_obj.size() == slice_without_step_len) {
|
||||
// Handle slice without step input scene, such as 0:2. MakeSlice can only handle slice with full inputs.
|
||||
(void)input_obj.emplace_back(py::int_(1));
|
||||
}
|
||||
}
|
||||
auto node = fg_builder_->AddNode(primitive, input_obj);
|
||||
auto ret = AObject::Convert(node);
|
||||
// Container object, such as list/tuple/dict will copy after Convert.
|
||||
fg_builder_->UpdatePyObject(ret->GetPyObject(), node);
|
||||
return ret;
|
||||
return AbstractFuncGraphOut::MakeAObject(node);
|
||||
}
|
||||
|
||||
bool MindGraphBuilder::DoGetItem(const Instr &instr) {
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test basic operation with one stage"""
|
||||
import pytest
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import jit
|
||||
|
||||
cfg = {
|
||||
"replace_nncell_by_construct": True,
|
||||
"print_after_all": False,
|
||||
"trace_flag": True,
|
||||
"print_bb": False,
|
||||
"MAX_INLINE_DEPTH": 10,
|
||||
"allowed_inline_modules": ["mindspore"], # buildsubgraph
|
||||
}
|
||||
mindspore.JitConfig(trace_flag=True)
|
||||
context.set_context(device_target="CPU")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_make_tuple():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return (x, x+1, x+2)
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, tuple)
|
||||
assert len(ret) == 3
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
assert ret[2] == Tensor([3])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_make_list():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return [x, x+1, x+2]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, list)
|
||||
assert len(ret) == 3
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
assert ret[2] == Tensor([3])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="DDE eliminate tuple input")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_tuple_slice():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
m = (x, x+1, x+2)
|
||||
return m[0:2:1]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, tuple)
|
||||
assert len(ret) == 2
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_slice():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
m = [x, x+1, x+2]
|
||||
return m[0:2:1]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, list)
|
||||
assert len(ret) == 2
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_slice_with_default_parameter():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
m = [x, x+1, x+2]
|
||||
return m[0:2]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, list)
|
||||
assert len(ret) == 2
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_slice_with_default_parameter_2():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
m = [x, x+1, x+2]
|
||||
return m[::]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, list)
|
||||
assert len(ret) == 3
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
assert ret[2] == Tensor([3])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_slice_with_default_parameter_3():
|
||||
"""
|
||||
Feature: One stage basic operation.
|
||||
Description: Test one stage basic operation.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
m = [x, x+1, x+2]
|
||||
return m[:]
|
||||
|
||||
net = Net()
|
||||
a = Tensor([1])
|
||||
jit(net.construct, mode="PIJit", jit_config=cfg)
|
||||
ret = net(a)
|
||||
assert isinstance(ret, list)
|
||||
assert len(ret) == 3
|
||||
assert ret[0] == Tensor([1])
|
||||
assert ret[1] == Tensor([2])
|
||||
assert ret[2] == Tensor([3])
|
Loading…
Reference in New Issue