forked from mindspore-Ecosystem/mindspore
support str.format
This commit is contained in:
parent
32fb08ce7f
commit
6304c2ef3d
|
@ -31,9 +31,8 @@ namespace pipeline {
|
|||
|
||||
BuiltInTypeMap &GetMethodMap() {
|
||||
static BuiltInTypeMap method_map = {{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
}},
|
||||
{{"__bool__", std::string("str_bool")}, // C.str_bool
|
||||
{"format", std::string("_format")}}},
|
||||
{kMetaTypeNone,
|
||||
{
|
||||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
|
|
|
@ -1175,17 +1175,14 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
|
||||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
|
||||
AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abstract);
|
||||
MS_EXCEPTION_IF_NULL(abs_func);
|
||||
|
||||
// Create new cnode
|
||||
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
|
||||
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
|
||||
auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abstract);
|
||||
if (func_graph_func != nullptr) {
|
||||
FuncGraphPtr fg = func_graph_func->func_graph();
|
||||
input.push_back(NewValueNode(fg));
|
||||
} else {
|
||||
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
|
||||
auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abstract);
|
||||
MS_EXCEPTION_IF_NULL(prim_func);
|
||||
PrimitivePtr prim = prim_func->prim();
|
||||
input.push_back(NewValueNode(prim));
|
||||
|
|
|
@ -30,6 +30,7 @@ from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add,
|
|||
from ...ops.composite.base import _append, _insert
|
||||
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from ...ops.operations._inner_ops import Format
|
||||
from ...ops.primitive import constexpr
|
||||
|
||||
|
||||
|
@ -41,6 +42,7 @@ abs_ = P.Abs()
|
|||
ndim_ = P.Rank()
|
||||
cumsum_ = P.CumSum()
|
||||
size_op_ = P.Size()
|
||||
_format = Format()
|
||||
_reduce_sum_default = P.ReduceSum()
|
||||
_reduce_sum_keepdims = P.ReduceSum(True)
|
||||
_mean_keepdims = P.ReduceMean(True)
|
||||
|
|
|
@ -1819,3 +1819,42 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|||
None.
|
||||
"""
|
||||
self.remove_backward_hook_fn(key)
|
||||
|
||||
|
||||
class Format(PrimitiveWithInfer):
|
||||
r"""
|
||||
This operator is used to format a string.
|
||||
|
||||
Note:
|
||||
Current not supported to using by customer.
|
||||
Only support convert str.format() in user code and it will be converted to be Format
|
||||
operation by ME-Compiler automatically.
|
||||
|
||||
|
||||
Inputs:
|
||||
- **input** -
|
||||
string : the string to be formatted.
|
||||
args : the format args.
|
||||
|
||||
Outputs:
|
||||
- **output** - Returns formatted string.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
|
||||
|
||||
def __infer__(self, str_, *var):
|
||||
str_value = str_["value"]
|
||||
var_value = list()
|
||||
if str_value is None and str_["dtype"] is not None:
|
||||
raise ValueError("str.format not support to input a variable.")
|
||||
for item in var:
|
||||
if item["value"] is None and item["dtype"] is not None:
|
||||
raise ValueError("str.format not support to input a variable.")
|
||||
var_value.append(item["value"])
|
||||
value = str_value.format(*var_value)
|
||||
return {'dtype': mstype.string, 'shape': [], 'value': value}
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import pytest
|
||||
from mindspore import ms_function, Tensor
|
||||
|
||||
|
||||
def test_str_format_single_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "string is {}".format("1")
|
||||
return ms_str
|
||||
|
||||
assert foo() == "string is 1"
|
||||
|
||||
|
||||
def test_str_format_mutiple_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "{} is {}".format("string", "1")
|
||||
return ms_str
|
||||
|
||||
assert foo() == "string is 1"
|
||||
|
||||
|
||||
def test_str_format_constant_tensor_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = Tensor([1])
|
||||
ms_str = "{} is {}".format("string", a)
|
||||
return ms_str
|
||||
|
||||
assert foo() == "string is Tensor(shape=[1], dtype=Int64, value=[1])"
|
||||
|
||||
|
||||
def test_str_format_variable_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(b, a):
|
||||
a = a + b
|
||||
ms_str = "{} is {}".format("String", a)
|
||||
return ms_str
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
foo(Tensor([1]), Tensor([1]))
|
||||
assert "str.format not support to input a variable." in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_str_format_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = Tensor([1])
|
||||
ms_str = format(a)
|
||||
ms_str2 = format(ms_str, "4")
|
||||
return ms_str, ms_str2
|
||||
|
||||
ms_str, ms_str2 = foo()
|
||||
assert ms_str == "[1]"
|
||||
assert ms_str2 == "[1] "
|
||||
|
||||
|
||||
def test_format_with_number_placeholder_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "{1} {0} {1}"
|
||||
ms_format_str = ms_str.format("hello", "world")
|
||||
return ms_format_str
|
||||
|
||||
ms_str = foo()
|
||||
assert ms_str == "world hello world"
|
||||
|
||||
|
||||
def test_format_with_key_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "hello {name2},It's me,{name1}"
|
||||
ms_format_str = ms_str.format(name2="Mind", name1="Spore")
|
||||
return ms_format_str
|
||||
|
||||
with pytest.raises(TypeError) as ex:
|
||||
foo()
|
||||
assert "Unsupported parameter type for python primitive," \
|
||||
" the parameter value is KeywordArg[key : name2, value : Mind]" in str(ex.value)
|
||||
|
||||
|
||||
def test_format_with_list_index():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "hello {0[1]},It's me {0[0]}"
|
||||
names = ["Mind", "Spore"]
|
||||
ms_format_str = ms_str.format(names)
|
||||
return ms_format_str
|
||||
|
||||
result_st = foo()
|
||||
assert result_st == "hello Spore,It's me Mind"
|
||||
|
||||
|
||||
@pytest.mark.skip("Need to support kwargs input of primitive "
|
||||
"operations same as test_format_with_key_input")
|
||||
def test_format_with_map():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ms_str = "hello {0[name2]},It's me {0[name1]}"
|
||||
names = {"name1": "Mind", "name2": "Spore"}
|
||||
ms_format_str = ms_str.format(names)
|
||||
return ms_format_str
|
||||
|
||||
result_st = foo()
|
||||
assert result_st == "hello Spore,It's me Mind"
|
||||
|
||||
|
||||
def test_format_as_function():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
func = "hello {0[1]},It's me {0[0]}".format
|
||||
names = ["Mind", "Spore"]
|
||||
ms_format_str = func(names)
|
||||
return ms_format_str
|
||||
|
||||
result_st = foo()
|
||||
assert result_st == "hello Spore,It's me Mind"
|
||||
|
||||
|
||||
def test_format_number():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
num1 = 3.1415926
|
||||
str1_format = "{:.2f}".format(num1)
|
||||
str2_format = "{:.0f}".format(num1)
|
||||
num2 = 1000000
|
||||
str3_format = "{:,}".format(num2)
|
||||
num3 = 0.25
|
||||
str4_format = "{:.2%}".format(num3)
|
||||
num4 = 1000000000
|
||||
str5_format = "{:.2e}".format(num4)
|
||||
num5 = 25
|
||||
str6_format = "{0:b}".format(num5)
|
||||
str7_format = "{0:d}".format(num5)
|
||||
str8_format = "{0:o}".format(num5)
|
||||
str9_format = "{0:x}".format(num5)
|
||||
result = (str1_format, str2_format, str3_format, str4_format, str5_format,
|
||||
str6_format, str7_format, str8_format, str9_format)
|
||||
return result
|
||||
|
||||
correct_str = ("3.14", "3", "1,000,000", "25.00%", "1.00e+09", "11001", "25", "31", "19")
|
||||
result_str = foo()
|
||||
assert result_str == correct_str
|
||||
|
||||
|
||||
def test_format_padding():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test str.format() in graph mode.
|
||||
Expectation: No exception.git
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
num1 = 5
|
||||
str1_format = "{:0>2}".format(num1)
|
||||
str2_format = "{:x<4}".format(num1)
|
||||
num2 = 10
|
||||
str3_format = "{:x^4}".format(num2)
|
||||
num3 = 13
|
||||
str4_format = "{:10}".format(num3)
|
||||
str5_format = "{:<10}".format(num3)
|
||||
str6_format = "{:^10}".format(num3)
|
||||
|
||||
result = (str1_format, str2_format, str3_format, str4_format, str5_format, str6_format)
|
||||
return result
|
||||
|
||||
correct_str = ("05", "5xxx", "x10x", " 13", "13 ", " 13 ")
|
||||
result_str = foo()
|
||||
assert result_str == correct_str
|
Loading…
Reference in New Issue