!49507 Fix compile cache abstract changed by other module

Merge pull request !49507 from chenfei_mindspore/master
This commit is contained in:
i-robot 2023-03-02 13:09:10 +00:00 committed by Gitee
commit ab8eefa123
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 79 additions and 2 deletions

View File

@ -165,7 +165,7 @@ void CheckAndConvertToVariableLenSequence(const py::object &obj, AbstractBasePtr
return;
}
if (!abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION(TypeError) << "For mutable, when the variable_len the True, the first input should be"
MS_EXCEPTION(TypeError) << "For mutable, when the dynamic_len the True, the first input should be"
<< " list or tuple, but got: " << abs->ToString();
}
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();

View File

@ -810,7 +810,9 @@ void DeleteDynamicLen(AnfNode *node) {
const auto &tuple_abs = node->abstract()->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(tuple_abs);
if (tuple_abs->dynamic_len()) {
tuple_abs->set_dynamic_len(false);
auto cloned_abstract = tuple_abs->Clone()->cast<abstract::AbstractSequencePtr>();
cloned_abstract->set_dynamic_len(false);
node->set_abstract(cloned_abstract);
}
}
}

View File

@ -477,6 +477,7 @@ std::string AbstractSequence::ToString() const {
}
ss << "}";
}
ss << ", dynamic_len:" << dynamic_len_;
ss << "}";
return ss.str();
}

View File

@ -0,0 +1,23 @@
import mindspore as ms
from mindspore import mutable
from mindspore import context
@ms.jit
def func(input_x, input_y, t):
output = input_x
for _ in range(2):
output = input_x + input_x * input_y + output
return output, t
context.set_context(precompile_only=True)
x = ms.Tensor([1], ms.dtype.float32)
y = ms.Tensor([2], ms.dtype.float32)
t1 = mutable((1,), dynamic_len=True)
t2 = mutable((1, 2,), dynamic_len=True)
out1 = func(x, y, t1)
out2 = func(x, y, t2)
print("out1:", out1)
print("out2:", out2)
context.set_context(precompile_only=False)

View File

@ -0,0 +1,51 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import subprocess
import pytest
def run_same_network_twice_in_one_process(file_name, log_file_name):
# Clear compile cache folder and log files
if os.path.exists(log_file_name):
os.remove(log_file_name)
assert not os.path.exists(log_file_name)
# First run without compile cache
cmd_first = f"GLOG_v=1 python " + file_name + " > " + log_file_name + " 2>&1"
subprocess.check_output(cmd_first, shell=True)
assert os.path.exists(log_file_name)
with open(log_file_name, "r") as f_first:
data_first = f_first.read()
assert "Generate a new compile key for new args, key: 0" in data_first
assert "Generate a new compile key for new args, key: 1" not in data_first
# Clean files
os.remove(log_file_name)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mutable_compile_repeat():
"""
Feature: Repeating compile .
Description: If arg set as mutable(dynamic_len=True) , the different length list args should not cause repeating
compile.
Expectation: Network only compile once.
"""
run_same_network_twice_in_one_process("repeat_compile_mutable_script.py", "repeat_compile_mutable.log")