Fix PIJit one-stage slice problem

liangzhibo 2024-02-01 10:31:48 +08:00 committed by r1chardf1d0
parent c56614d008
commit 01897887e3
4 changed files with 213 additions and 27 deletions
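
A minimal repro sketch of the scenario this commit addresses, mirroring the new test cases added below (the SliceNet name is illustrative, and the jit_config options used by the tests are omitted for brevity): slicing a container without an explicit step inside a construct captured by PIJit one-stage mode.

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.common.api import jit

context.set_context(device_target="CPU")

class SliceNet(nn.Cell):
    def construct(self, x):
        m = [x, x + 1, x + 2]
        return m[0:2]  # slice without an explicit step

net = SliceNet()
jit(net.construct, mode="PIJit")  # capture construct with PIJit one-stage
out = net(Tensor([1]))
assert isinstance(out, list) and len(out) == 2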

View File

@@ -37,7 +37,6 @@ bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
kListInplaceClearOpName,
kDictInplaceSetItemOpName,
kRaiseOpName,
- kMakeSliceOpName,
kJoinedStrOpName,
kFormatOpName};
return prims_should_fallback_in_runtime.find(prim->name()) != prims_should_fallback_in_runtime.end();
@@ -336,21 +335,6 @@ bool FuncGraphBuilder::AddOutput(const py::object &output_obj) {
return true;
}
- void FuncGraphBuilder::UpdatePyObject(const py::object &new_obj, const py::object &old_obj) {
-   if (new_obj.ptr() == old_obj.ptr()) {
-     return;
-   }
-   auto iter = py_obj_to_node_.find(old_obj.ptr());
-   if (iter == py_obj_to_node_.end()) {
-     return;
-   }
-   auto node = iter->second;
-   py_obj_to_node_.erase(iter);
-   (void)py_obj_to_node_.emplace(new_obj.ptr(), node);
-   MS_LOG(DEBUG) << "Update python object " << old_obj.ptr() << " to " << new_obj.ptr() << ". Corresponding node is "
-                 << node->DebugString();
- }
FuncGraphPtr FuncGraphBuilder::graph() {
if (has_set_output_) {
return graph_;

View File

@@ -68,12 +68,6 @@ class FuncGraphBuilder {
/// \return Return true if the output object can be used as the output of the graph.
bool AddOutput(const py::object &output_obj);
- /// \brief Update key value for converted_py_obj_ map.
- ///
- /// \param[in] new_obj The new python object as key.
- /// \param[in] old_obj The old python object as key.
- void UpdatePyObject(const py::object &new_obj, const py::object &old_obj);
/// \brief Remove an output node of the graph.
///
/// \param[in] output_obj The output python object.

View File

@@ -33,6 +33,7 @@
#include "pipeline/jit/pi/graph_compiler/utils.h"
#include "ops/sequence_ops.h"
#include "ops/framework_ops.h"
#include "ops/structure_ops.h"
#ifndef PY_MINOR_VERSION
#define PY_MINOR_VERSION 3.7
@@ -2657,7 +2658,7 @@ AObject *MindGraphBuilder::HandleMultiOp(const Instr &instr, const std::vector<V
return AObject::MakeAObject(AObject::kTypeAnyValue);
}
auto node = fg_builder_->AddMultiNode(op_name, input_obj);
- return AObject::Convert(node);
+ return AbstractFuncGraphOut::MakeAObject(node);
}
AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p) {
@@ -2686,11 +2687,15 @@ AObject *MindGraphBuilder::HandleBuildOp(const Instr &instr, const std::vector<V
return AObject::MakeAObject(AObject::kTypeAnyValue);
}
}
+ if (primitive == prim::kPrimMakeSlice) {
+   constexpr size_t slice_without_step_len = 2;
+   if (input_obj.size() == slice_without_step_len) {
+     // Handle a slice without a step input, such as 0:2. MakeSlice can only handle a slice with all inputs.
+     (void)input_obj.emplace_back(py::int_(1));
+   }
+ }
auto node = fg_builder_->AddNode(primitive, input_obj);
- auto ret = AObject::Convert(node);
- // Container object, such as list/tuple/dict will copy after Convert.
- fg_builder_->UpdatePyObject(ret->GetPyObject(), node);
- return ret;
+ return AbstractFuncGraphOut::MakeAObject(node);
}
bool MindGraphBuilder::DoGetItem(const Instr &instr) {
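
For reference, a Python sketch of the default-step rule the new kPrimMakeSlice branch above implements (the helper name is hypothetical, not MindSpore API): MakeSlice expects start, stop, and step, so a two-input slice such as 0:2 gets a step of 1 appended before the node is built.

def fill_slice_inputs(inputs):
    # Mirrors the C++ logic: a slice built from only (start, stop) gets a default
    # step of 1 appended, matching the py::int_(1) added in HandleBuildOp.
    slice_without_step_len = 2
    if len(inputs) == slice_without_step_len:
        inputs = inputs + [1]
    return inputs

print(fill_slice_inputs([0, 2]))     # [0, 2, 1] -> equivalent to 0:2:1
print(fill_slice_inputs([0, 2, 1]))  # already complete, unchanged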

View File

@@ -0,0 +1,203 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test basic operation with one stage"""
import pytest
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.common.api import jit
cfg = {
"replace_nncell_by_construct": True,
"print_after_all": False,
"trace_flag": True,
"print_bb": False,
"MAX_INLINE_DEPTH": 10,
"allowed_inline_modules": ["mindspore"], # buildsubgraph
}
mindspore.JitConfig(trace_flag=True)
context.set_context(device_target="CPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_make_tuple():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
return (x, x+1, x+2)
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, tuple)
assert len(ret) == 3
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
assert ret[2] == Tensor([3])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_make_list():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
return [x, x+1, x+2]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, list)
assert len(ret) == 3
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
assert ret[2] == Tensor([3])
@pytest.mark.skip(reason="DDE eliminate tuple input")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tuple_slice():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
m = (x, x+1, x+2)
return m[0:2:1]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, tuple)
assert len(ret) == 2
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_list_slice():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
m = [x, x+1, x+2]
return m[0:2:1]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, list)
assert len(ret) == 2
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_list_slice_with_default_parameter():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
m = [x, x+1, x+2]
return m[0:2]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, list)
assert len(ret) == 2
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_list_slice_with_default_parameter_2():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
m = [x, x+1, x+2]
return m[::]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, list)
assert len(ret) == 3
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
assert ret[2] == Tensor([3])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_list_slice_with_default_parameter_3():
"""
Feature: One stage basic operation.
Description: Test one stage basic operation.
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self, x):
m = [x, x+1, x+2]
return m[:]
net = Net()
a = Tensor([1])
jit(net.construct, mode="PIJit", jit_config=cfg)
ret = net(a)
assert isinstance(ret, list)
assert len(ret) == 3
assert ret[0] == Tensor([1])
assert ret[1] == Tensor([2])
assert ret[2] == Tensor([3])