!31603 support str.format

Merge pull request !31603 from lianliguang/supported-str-format
This commit is contained in:
i-robot 2022-03-28 01:33:58 +00:00 committed by Gitee
commit a15f682803
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 296 additions and 8 deletions

View File

@ -31,9 +31,8 @@ namespace pipeline {
BuiltInTypeMap &GetMethodMap() { BuiltInTypeMap &GetMethodMap() {
static BuiltInTypeMap method_map = {{kObjectTypeString, static BuiltInTypeMap method_map = {{kObjectTypeString,
{ {{"__bool__", std::string("str_bool")}, // C.str_bool
{"__bool__", std::string("str_bool")} // C.str_bool {"format", std::string("_format")}}},
}},
{kMetaTypeNone, {kMetaTypeNone,
{ {
{"__bool__", std::string("none_bool")} // C.none_bool {"__bool__", std::string("none_bool")} // C.none_bool

View File

@ -1187,17 +1187,14 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) { REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
MS_EXCEPTION_IF_NULL(old_conf); MS_EXCEPTION_IF_NULL(old_conf);
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abstract);
MS_EXCEPTION_IF_NULL(abs_func);
// Create new cnode // Create new cnode
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)}; std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func); auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abstract);
if (func_graph_func != nullptr) { if (func_graph_func != nullptr) {
FuncGraphPtr fg = func_graph_func->func_graph(); FuncGraphPtr fg = func_graph_func->func_graph();
input.push_back(NewValueNode(fg)); input.push_back(NewValueNode(fg));
} else { } else {
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func); auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abstract);
MS_EXCEPTION_IF_NULL(prim_func); MS_EXCEPTION_IF_NULL(prim_func);
PrimitivePtr prim = prim_func->prim(); PrimitivePtr prim = prim_func->prim();
input.push_back(NewValueNode(prim)); input.push_back(NewValueNode(prim));

View File

@ -30,6 +30,7 @@ from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add,
from ...ops.composite.base import _append, _insert from ...ops.composite.base import _append, _insert
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations._inner_ops import Format
from ...ops.primitive import constexpr from ...ops.primitive import constexpr
@ -41,6 +42,7 @@ abs_ = P.Abs()
ndim_ = P.Rank() ndim_ = P.Rank()
cumsum_ = P.CumSum() cumsum_ = P.CumSum()
size_op_ = P.Size() size_op_ = P.Size()
_format = Format()
_reduce_sum_default = P.ReduceSum() _reduce_sum_default = P.ReduceSum()
_reduce_sum_keepdims = P.ReduceSum(True) _reduce_sum_keepdims = P.ReduceSum(True)
_mean_keepdims = P.ReduceMean(True) _mean_keepdims = P.ReduceMean(True)

View File

@ -1872,3 +1872,42 @@ class CellBackwardHook(PrimitiveWithInfer):
None. None.
""" """
self.remove_backward_hook_fn(key) self.remove_backward_hook_fn(key)
class Format(PrimitiveWithInfer):
r"""
This operator is used to format a string.
Note:
Current not supported to using by customer.
Only support convert str.format() in user code and it will be converted to be Format
operation by ME-Compiler automatically.
Inputs:
- **input** -
string : the string to be formatted.
args : the format args.
Outputs:
- **output** - Returns formatted string.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
def __infer__(self, str_, *var):
str_value = str_["value"]
var_value = list()
if str_value is None and str_["dtype"] is not None:
raise ValueError("str.format not support to input a variable.")
for item in var:
if item["value"] is None and item["dtype"] is not None:
raise ValueError("str.format not support to input a variable.")
var_value.append(item["value"])
value = str_value.format(*var_value)
return {'dtype': mstype.string, 'shape': [], 'value': value}

View File

@ -0,0 +1,251 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
from mindspore import ms_function, Tensor
def test_str_format_single_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "string is {}".format("1")
return ms_str
assert foo() == "string is 1"
def test_str_format_mutiple_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "{} is {}".format("string", "1")
return ms_str
assert foo() == "string is 1"
def test_str_format_constant_tensor_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
a = Tensor([1])
ms_str = "{} is {}".format("string", a)
return ms_str
assert foo() == "string is Tensor(shape=[1], dtype=Int64, value=[1])"
def test_str_format_variable_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo(b, a):
a = a + b
ms_str = "{} is {}".format("String", a)
return ms_str
with pytest.raises(ValueError) as ex:
foo(Tensor([1]), Tensor([1]))
assert "str.format not support to input a variable." in str(ex.value)
def test_fallback_str_format_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
a = Tensor([1])
ms_str = format(a)
ms_str2 = format(ms_str, "4")
return ms_str, ms_str2
ms_str, ms_str2 = foo()
assert ms_str == "[1]"
assert ms_str2 == "[1] "
def test_format_with_number_placeholder_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "{1} {0} {1}"
ms_format_str = ms_str.format("hello", "world")
return ms_format_str
ms_str = foo()
assert ms_str == "world hello world"
def test_format_with_key_input():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "hello {name2},It's me,{name1}"
ms_format_str = ms_str.format(name2="Mind", name1="Spore")
return ms_format_str
with pytest.raises(TypeError) as ex:
foo()
assert "Unsupported parameter type for python primitive," \
" the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value)
def test_format_with_list_index():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "hello {0[1]},It's me {0[0]}"
names = ["Mind", "Spore"]
ms_format_str = ms_str.format(names)
return ms_format_str
result_st = foo()
assert result_st == "hello Spore,It's me Mind"
@pytest.mark.skip("Need to support kwargs input of primitive "
"operations same as test_format_with_key_input")
def test_format_with_map():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
ms_str = "hello {0[name2]},It's me {0[name1]}"
names = {"name1": "Mind", "name2": "Spore"}
ms_format_str = ms_str.format(names)
return ms_format_str
result_st = foo()
assert result_st == "hello Spore,It's me Mind"
def test_format_as_function():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
@ms_function
def foo():
func = "hello {0[1]},It's me {0[0]}".format
names = ["Mind", "Spore"]
ms_format_str = func(names)
return ms_format_str
result_st = foo()
assert result_st == "hello Spore,It's me Mind"
def test_format_number():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
@ms_function
def foo():
num1 = 3.1415926
str1_format = "{:.2f}".format(num1)
str2_format = "{:.0f}".format(num1)
num2 = 1000000
str3_format = "{:,}".format(num2)
num3 = 0.25
str4_format = "{:.2%}".format(num3)
num4 = 1000000000
str5_format = "{:.2e}".format(num4)
num5 = 25
str6_format = "{0:b}".format(num5)
str7_format = "{0:d}".format(num5)
str8_format = "{0:o}".format(num5)
str9_format = "{0:x}".format(num5)
result = (str1_format, str2_format, str3_format, str4_format, str5_format,
str6_format, str7_format, str8_format, str9_format)
return result
correct_str = ("3.14", "3", "1,000,000", "25.00%", "1.00e+09", "11001", "25", "31", "19")
result_str = foo()
assert result_str == correct_str
def test_format_padding():
"""
Feature: JIT Fallback
Description: Test str.format() in graph mode.
Expectation: No exception.git
"""
@ms_function
def foo():
num1 = 5
str1_format = "{:0>2}".format(num1)
str2_format = "{:x<4}".format(num1)
num2 = 10
str3_format = "{:x^4}".format(num2)
num3 = 13
str4_format = "{:10}".format(num3)
str5_format = "{:<10}".format(num3)
str6_format = "{:^10}".format(num3)
result = (str1_format, str2_format, str3_format, str4_format, str5_format, str6_format)
return result
correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ")
result_str = foo()
assert result_str == correct_str