!34142 update ms_hybrid cpu support

Merge pull request !34142 from zichun_ye/mshyrid_cpu_fix
This commit is contained in:
i-robot 2022-05-11 03:33:08 +00:00 committed by Gitee
commit 8f07bd3d43
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 26 additions and 22 deletions

View File

@ -702,9 +702,6 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
}
auto process_target = GetProcessorByTarget();
(*kernel_json)[kJsonKeyProcess] = process_target;
if (process_target == "cpu") {
(*kernel_json)[kJsonKeyTargetOption] = kCPUTargetOption;
}
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = op_name + "_";
(void)kernel_name_.append(std::to_string(hash_id));
@ -781,9 +778,6 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
auto process_target = GetProcessorByTarget();
(*kernel_json)[kJsonKeyProcess] = process_target;
if (process_target == "cpu") {
(*kernel_json)[kJsonKeyTargetOption] = kCPUTargetOption;
}
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();

View File

@ -180,7 +180,7 @@ INTRIN_GENERAL_UNARY_OP = {
'round': numpy.round,
}
INTRIN_CPU_NOT_SUPPORT = ["atan2", "expm1"]
INTRIN_CPU_NOT_SUPPORT = ["atan2", "expm1", "float16"]
INTRIN_GENERAL_BINARY_OP = {
'ceil_div': lambda a, b: (a + b - 1) // b,

View File

@ -174,6 +174,22 @@ def ms_hybrid_grid():
raise ValueError("Precision error, compare result: {}".format(compare_res))
def ms_hybrid_grid_cpu():
"""
test case Custom Op with functions written in Hybrid DSL about grid
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
test = TestMsHybridDSL(grid_example, "hybrid", lambda x, _: x, lambda x, _: x)
output = test(Tensor(input_x), Tensor(input_y))
expect = grid_example(input_x, input_y)
compare_res = np.allclose(expect, output.asnumpy(), 0.001, 0.001)
if not compare_res:
raise ValueError("Precision error, compare result: {}".format(compare_res))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -250,19 +266,16 @@ def test_ms_hybrid_gpu_pynative_mode():
def test_ms_hybrid_cpu_graph_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: gpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
Description: cpu test case, Python DSL with ms_hybrid decorator in GRAPH_MODE.
Expectation: the result match with numpy result
"""
sys = platform.system()
if sys == 'Windows':
# skip window, same for pynative below
if platform.system().lower() in {"windows", "darwin"}:
# skip window and mac, same for pynative below
pass
else:
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_hybrid_allocate_cpu()
ms_hybrid_grid_cpu()
@pytest.mark.level0
@ -271,15 +284,12 @@ def test_ms_hybrid_cpu_graph_mode():
def test_ms_hybrid_cpu_pynative_mode():
"""
Feature: test case for Custom op with func_type="ms_hybrid"
Description: gpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
Description: cpu test case, Python DSL with ms_hybrid decorator in PYNATIVE_MODE.
Expectation: the result match with numpy result
"""
sys = platform.system()
if sys == 'Windows':
if platform.system().lower() in {"windows", "darwin"}:
pass
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_allocate()
ms_hybrid_grid()
ms_hybrid_allocate_cpu()
ms_hybrid_grid_cpu()