forked from mindspore-Ecosystem/mindspore
!31603 support str.format
Merge pull request !31603 from lianliguang/supported-str-format
This commit is contained in:
commit
a15f682803
|
@ -31,9 +31,8 @@ namespace pipeline {
|
||||||
|
|
||||||
BuiltInTypeMap &GetMethodMap() {
|
BuiltInTypeMap &GetMethodMap() {
|
||||||
static BuiltInTypeMap method_map = {{kObjectTypeString,
|
static BuiltInTypeMap method_map = {{kObjectTypeString,
|
||||||
{
|
{{"__bool__", std::string("str_bool")}, // C.str_bool
|
||||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
{"format", std::string("_format")}}},
|
||||||
}},
|
|
||||||
{kMetaTypeNone,
|
{kMetaTypeNone,
|
||||||
{
|
{
|
||||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||||
|
|
|
@ -1187,17 +1187,14 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
||||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
|
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
|
||||||
MS_EXCEPTION_IF_NULL(old_conf);
|
MS_EXCEPTION_IF_NULL(old_conf);
|
||||||
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
|
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
|
||||||
AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abstract);
|
|
||||||
MS_EXCEPTION_IF_NULL(abs_func);
|
|
||||||
|
|
||||||
// Create new cnode
|
// Create new cnode
|
||||||
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
|
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
|
||||||
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
|
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abstract);
|
||||||
if (func_graph_func != nullptr) {
|
if (func_graph_func != nullptr) {
|
||||||
FuncGraphPtr fg = func_graph_func->func_graph();
|
FuncGraphPtr fg = func_graph_func->func_graph();
|
||||||
input.push_back(NewValueNode(fg));
|
input.push_back(NewValueNode(fg));
|
||||||
} else {
|
} else {
|
||||||
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
|
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abstract);
|
||||||
MS_EXCEPTION_IF_NULL(prim_func);
|
MS_EXCEPTION_IF_NULL(prim_func);
|
||||||
PrimitivePtr prim = prim_func->prim();
|
PrimitivePtr prim = prim_func->prim();
|
||||||
input.push_back(NewValueNode(prim));
|
input.push_back(NewValueNode(prim));
|
||||||
|
|
|
@ -30,6 +30,7 @@ from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add,
|
||||||
from ...ops.composite.base import _append, _insert
|
from ...ops.composite.base import _append, _insert
|
||||||
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||||
|
from ...ops.operations._inner_ops import Format
|
||||||
from ...ops.primitive import constexpr
|
from ...ops.primitive import constexpr
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,6 +42,7 @@ abs_ = P.Abs()
|
||||||
ndim_ = P.Rank()
|
ndim_ = P.Rank()
|
||||||
cumsum_ = P.CumSum()
|
cumsum_ = P.CumSum()
|
||||||
size_op_ = P.Size()
|
size_op_ = P.Size()
|
||||||
|
_format = Format()
|
||||||
_reduce_sum_default = P.ReduceSum()
|
_reduce_sum_default = P.ReduceSum()
|
||||||
_reduce_sum_keepdims = P.ReduceSum(True)
|
_reduce_sum_keepdims = P.ReduceSum(True)
|
||||||
_mean_keepdims = P.ReduceMean(True)
|
_mean_keepdims = P.ReduceMean(True)
|
||||||
|
|
|
@ -1872,3 +1872,42 @@ class CellBackwardHook(PrimitiveWithInfer):
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
self.remove_backward_hook_fn(key)
|
self.remove_backward_hook_fn(key)
|
||||||
|
|
||||||
|
|
||||||
|
class Format(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
This operator is used to format a string.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Current not supported to using by customer.
|
||||||
|
Only support convert str.format() in user code and it will be converted to be Format
|
||||||
|
operation by ME-Compiler automatically.
|
||||||
|
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input** -
|
||||||
|
string : the string to be formatted.
|
||||||
|
args : the format args.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **output** - Returns formatted string.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
|
||||||
|
|
||||||
|
def __infer__(self, str_, *var):
|
||||||
|
str_value = str_["value"]
|
||||||
|
var_value = list()
|
||||||
|
if str_value is None and str_["dtype"] is not None:
|
||||||
|
raise ValueError("str.format not support to input a variable.")
|
||||||
|
for item in var:
|
||||||
|
if item["value"] is None and item["dtype"] is not None:
|
||||||
|
raise ValueError("str.format not support to input a variable.")
|
||||||
|
var_value.append(item["value"])
|
||||||
|
value = str_value.format(*var_value)
|
||||||
|
return {'dtype': mstype.string, 'shape': [], 'value': value}
|
||||||
|
|
|
@ -0,0 +1,251 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test graph fallback """
|
||||||
|
import pytest
|
||||||
|
from mindspore import ms_function, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_format_single_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "string is {}".format("1")
|
||||||
|
return ms_str
|
||||||
|
|
||||||
|
assert foo() == "string is 1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_format_mutiple_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "{} is {}".format("string", "1")
|
||||||
|
return ms_str
|
||||||
|
|
||||||
|
assert foo() == "string is 1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_format_constant_tensor_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
a = Tensor([1])
|
||||||
|
ms_str = "{} is {}".format("string", a)
|
||||||
|
return ms_str
|
||||||
|
|
||||||
|
assert foo() == "string is Tensor(shape=[1], dtype=Int64, value=[1])"
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_format_variable_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo(b, a):
|
||||||
|
a = a + b
|
||||||
|
ms_str = "{} is {}".format("String", a)
|
||||||
|
return ms_str
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as ex:
|
||||||
|
foo(Tensor([1]), Tensor([1]))
|
||||||
|
assert "str.format not support to input a variable." in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_str_format_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
a = Tensor([1])
|
||||||
|
ms_str = format(a)
|
||||||
|
ms_str2 = format(ms_str, "4")
|
||||||
|
return ms_str, ms_str2
|
||||||
|
|
||||||
|
ms_str, ms_str2 = foo()
|
||||||
|
assert ms_str == "[1]"
|
||||||
|
assert ms_str2 == "[1] "
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_with_number_placeholder_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "{1} {0} {1}"
|
||||||
|
ms_format_str = ms_str.format("hello", "world")
|
||||||
|
return ms_format_str
|
||||||
|
|
||||||
|
ms_str = foo()
|
||||||
|
assert ms_str == "world hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_with_key_input():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "hello {name2},It's me,{name1}"
|
||||||
|
ms_format_str = ms_str.format(name2="Mind", name1="Spore")
|
||||||
|
return ms_format_str
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as ex:
|
||||||
|
foo()
|
||||||
|
assert "Unsupported parameter type for python primitive," \
|
||||||
|
" the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_with_list_index():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "hello {0[1]},It's me {0[0]}"
|
||||||
|
names = ["Mind", "Spore"]
|
||||||
|
ms_format_str = ms_str.format(names)
|
||||||
|
return ms_format_str
|
||||||
|
|
||||||
|
result_st = foo()
|
||||||
|
assert result_st == "hello Spore,It's me Mind"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("Need to support kwargs input of primitive "
|
||||||
|
"operations same as test_format_with_key_input")
|
||||||
|
def test_format_with_map():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
ms_str = "hello {0[name2]},It's me {0[name1]}"
|
||||||
|
names = {"name1": "Mind", "name2": "Spore"}
|
||||||
|
ms_format_str = ms_str.format(names)
|
||||||
|
return ms_format_str
|
||||||
|
|
||||||
|
result_st = foo()
|
||||||
|
assert result_st == "hello Spore,It's me Mind"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_as_function():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.git
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
func = "hello {0[1]},It's me {0[0]}".format
|
||||||
|
names = ["Mind", "Spore"]
|
||||||
|
ms_format_str = func(names)
|
||||||
|
return ms_format_str
|
||||||
|
|
||||||
|
result_st = foo()
|
||||||
|
assert result_st == "hello Spore,It's me Mind"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_number():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.git
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
num1 = 3.1415926
|
||||||
|
str1_format = "{:.2f}".format(num1)
|
||||||
|
str2_format = "{:.0f}".format(num1)
|
||||||
|
num2 = 1000000
|
||||||
|
str3_format = "{:,}".format(num2)
|
||||||
|
num3 = 0.25
|
||||||
|
str4_format = "{:.2%}".format(num3)
|
||||||
|
num4 = 1000000000
|
||||||
|
str5_format = "{:.2e}".format(num4)
|
||||||
|
num5 = 25
|
||||||
|
str6_format = "{0:b}".format(num5)
|
||||||
|
str7_format = "{0:d}".format(num5)
|
||||||
|
str8_format = "{0:o}".format(num5)
|
||||||
|
str9_format = "{0:x}".format(num5)
|
||||||
|
result = (str1_format, str2_format, str3_format, str4_format, str5_format,
|
||||||
|
str6_format, str7_format, str8_format, str9_format)
|
||||||
|
return result
|
||||||
|
|
||||||
|
correct_str = ("3.14", "3", "1,000,000", "25.00%", "1.00e+09", "11001", "25", "31", "19")
|
||||||
|
result_str = foo()
|
||||||
|
assert result_str == correct_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_padding():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test str.format() in graph mode.
|
||||||
|
Expectation: No exception.git
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def foo():
|
||||||
|
num1 = 5
|
||||||
|
str1_format = "{:0>2}".format(num1)
|
||||||
|
str2_format = "{:x<4}".format(num1)
|
||||||
|
num2 = 10
|
||||||
|
str3_format = "{:x^4}".format(num2)
|
||||||
|
num3 = 13
|
||||||
|
str4_format = "{:10}".format(num3)
|
||||||
|
str5_format = "{:<10}".format(num3)
|
||||||
|
str6_format = "{:^10}".format(num3)
|
||||||
|
|
||||||
|
result = (str1_format, str2_format, str3_format, str4_format, str5_format, str6_format)
|
||||||
|
return result
|
||||||
|
|
||||||
|
correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ")
|
||||||
|
result_str = foo()
|
||||||
|
assert result_str == correct_str
|
Loading…
Reference in New Issue