!17807 MD Profiling UT: Add MD Analyze UT for MinddataProfilingAnalyzer

Merge pull request !17807 from cathwong/ckw_mon_py_analyze_ut3
i-robot 2021-06-08 21:52:55 +08:00 committed by Gitee
commit 3a63a66d64
2 changed files with 438 additions and 97 deletions
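The new UT exercises MinddataProfilingAnalyzer end to end: it runs a small MindData pipeline with profiling enabled, then analyzes the generated profiling files and checks the summary that is returned as a dictionary and also written as JSON and CSV. A minimal sketch of that call pattern, assuming a working MindSpore install and write access to the current directory (the tiny pipeline below is illustrative, not the one used in the UT):

import os
import numpy as np
import mindspore.dataset as ds
from mindspore.profiler.parser.minddata_analyzer import MinddataProfilingAnalyzer

# Enable MindData profiling before the pipeline is built and iterated
os.environ['PROFILING_MODE'] = 'true'
os.environ['MINDDATA_PROFILING_DIR'] = '.'
os.environ['DEVICE_ID'] = '0'

# Build and fully iterate a small pipeline so the profiling files get written
data1 = ds.GeneratorDataset([(np.array([i]),) for i in range(64)], ["col1"]).batch(16)
for _ in data1.create_dict_iterator(num_epochs=1):
    pass

# Analyze the generated profiling files (positional arguments as used in the UT:
# profiling-file directory, device target, device id, output directory)
md_analyzer = MinddataProfilingAnalyzer("./", "CPU", 0, "./")
md_summary_dict = md_analyzer.analyze()

# The summary is returned as a dict and also saved as JSON and CSV files,
# e.g. ./minddata_pipeline_summary_0.json and ./minddata_pipeline_summary_0.csv
print(md_summary_dict["pipeline_ops"])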

View File

@@ -18,23 +18,50 @@ Testing profiling support in DE
 import json
 import os
 import numpy as np

+import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.vision.c_transforms as vision

 FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
 DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
 SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
 PIPELINE_FILE = "./pipeline_profiling_1.json"
+CPU_UTIL_FILE = "./minddata_cpu_utilization_1.json"
 DATASET_ITERATOR_FILE = "./dataset_iterator_profiling_1.txt"


+def set_profiling_env_var():
+    """
+    Set the MindData Profiling environment variables
+    """
+    os.environ['PROFILING_MODE'] = 'true'
+    os.environ['MINDDATA_PROFILING_DIR'] = '.'
+    os.environ['DEVICE_ID'] = '1'
+
+
+def delete_profiling_files():
+    """
+    Delete the MindData profiling files generated from the test.
+    Also disable the MindData Profiling environment variables.
+    """
+    # Delete MindData profiling files
+    os.remove(PIPELINE_FILE)
+    os.remove(CPU_UTIL_FILE)
+    os.remove(DATASET_ITERATOR_FILE)
+
+    # Disable MindData Profiling environment variables
+    del os.environ['PROFILING_MODE']
+    del os.environ['MINDDATA_PROFILING_DIR']
+    del os.environ['DEVICE_ID']
+
+
 def test_profiling_simple_pipeline():
     """
     Generator -> Shuffle -> Batch
     """
-    os.environ['PROFILING_MODE'] = 'true'
-    os.environ['MINDDATA_PROFILING_DIR'] = '.'
-    os.environ['DEVICE_ID'] = '1'
+    set_profiling_env_var()

     source = [(np.array([x]),) for x in range(1024)]
     data1 = ds.GeneratorDataset(source, ["data"])
@@ -44,18 +71,27 @@ def test_profiling_simple_pipeline():
     assert data1.output_shapes() == [[32, 1]]
     assert [str(tp) for tp in data1.output_types()] == ["int64"]
     assert data1.get_dataset_size() == 32

+    # Confirm profiling files do not (yet) exist
     assert os.path.exists(PIPELINE_FILE) is False
+    assert os.path.exists(CPU_UTIL_FILE) is False
     assert os.path.exists(DATASET_ITERATOR_FILE) is False

-    for _ in data1:
-        pass
-
-    assert os.path.exists(PIPELINE_FILE) is True
-    os.remove(PIPELINE_FILE)
-    assert os.path.exists(DATASET_ITERATOR_FILE) is True
-    os.remove(DATASET_ITERATOR_FILE)
-    del os.environ['PROFILING_MODE']
-    del os.environ['MINDDATA_PROFILING_DIR']
+    try:
+        for _ in data1:
+            pass
+
+        # Confirm profiling files now exist
+        assert os.path.exists(PIPELINE_FILE) is True
+        assert os.path.exists(CPU_UTIL_FILE) is True
+        assert os.path.exists(DATASET_ITERATOR_FILE) is True
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()


 def test_profiling_complex_pipeline():
@@ -64,9 +100,7 @@ def test_profiling_complex_pipeline():
              -> Zip
     TFReader -> Shuffle ->
     """
-    os.environ['PROFILING_MODE'] = 'true'
-    os.environ['MINDDATA_PROFILING_DIR'] = '.'
-    os.environ['DEVICE_ID'] = '1'
+    set_profiling_env_var()

     source = [(np.array([x]),) for x in range(1024)]
     data1 = ds.GeneratorDataset(source, ["gen"])
@@ -78,28 +112,29 @@ def test_profiling_complex_pipeline():
     data3 = ds.zip((data1, data2))

-    for _ in data3:
-        pass
-
-    with open(PIPELINE_FILE) as f:
-        data = json.load(f)
-        op_info = data["op_info"]
-        assert len(op_info) == 5
-        for i in range(5):
-            if op_info[i]["op_type"] != "ZipOp":
-                assert "size" in op_info[i]["metrics"]["output_queue"]
-                assert "length" in op_info[i]["metrics"]["output_queue"]
-                assert "throughput" in op_info[i]["metrics"]["output_queue"]
-            else:
-                # Note: Zip is an inline op and hence does not have metrics information
-                assert op_info[i]["metrics"] is None
-
-    assert os.path.exists(PIPELINE_FILE) is True
-    os.remove(PIPELINE_FILE)
-    assert os.path.exists(DATASET_ITERATOR_FILE) is True
-    os.remove(DATASET_ITERATOR_FILE)
-    del os.environ['PROFILING_MODE']
-    del os.environ['MINDDATA_PROFILING_DIR']
+    try:
+        for _ in data3:
+            pass
+
+        with open(PIPELINE_FILE) as f:
+            data = json.load(f)
+            op_info = data["op_info"]
+            assert len(op_info) == 5
+            for i in range(5):
+                if op_info[i]["op_type"] != "ZipOp":
+                    assert "size" in op_info[i]["metrics"]["output_queue"]
+                    assert "length" in op_info[i]["metrics"]["output_queue"]
+                    assert "throughput" in op_info[i]["metrics"]["output_queue"]
+                else:
+                    # Note: Zip is an inline op and hence does not have metrics information
+                    assert op_info[i]["metrics"] is None
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()


 def test_profiling_inline_ops_pipeline1():
@@ -109,9 +144,7 @@ def test_profiling_inline_ops_pipeline1():
              Concat -> EpochCtrl
     Generator ->
     """
-    os.environ['PROFILING_MODE'] = 'true'
-    os.environ['MINDDATA_PROFILING_DIR'] = '.'
-    os.environ['DEVICE_ID'] = '1'
+    set_profiling_env_var()

     # In source1 dataset: Number of rows is 3; its values are 0, 1, 2
     def source1():
@@ -127,33 +160,37 @@ def test_profiling_inline_ops_pipeline1():
     data2 = ds.GeneratorDataset(source2, ["col1"])
     data3 = data1.concat(data2)

-    # Here i refers to index, d refers to data element
-    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
-        t = d
-        assert i == t[0][0]
-
-    assert sum([1 for _ in data3]) == 10
-
-    with open(PIPELINE_FILE) as f:
-        data = json.load(f)
-        op_info = data["op_info"]
-        assert len(op_info) == 4
-        for i in range(4):
-            # Note: The following ops are inline ops: Concat, EpochCtrl
-            if op_info[i]["op_type"] in ("ConcatOp", "EpochCtrlOp"):
-                # Confirm these inline ops do not have metrics information
-                assert op_info[i]["metrics"] is None
-            else:
-                assert "size" in op_info[i]["metrics"]["output_queue"]
-                assert "length" in op_info[i]["metrics"]["output_queue"]
-                assert "throughput" in op_info[i]["metrics"]["output_queue"]
-
-    assert os.path.exists(PIPELINE_FILE) is True
-    os.remove(PIPELINE_FILE)
-    assert os.path.exists(DATASET_ITERATOR_FILE) is True
-    os.remove(DATASET_ITERATOR_FILE)
-    del os.environ['PROFILING_MODE']
-    del os.environ['MINDDATA_PROFILING_DIR']
+    try:
+        # Note: If create_tuple_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
+        num_iter = 0
+        # Here i refers to index, d refers to data element
+        for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True, num_epochs=2)):
+            num_iter = num_iter + 1
+            t = d
+            assert i == t[0][0]
+
+        assert num_iter == 10
+
+        with open(PIPELINE_FILE) as f:
+            data = json.load(f)
+            op_info = data["op_info"]
+            assert len(op_info) == 4
+            for i in range(4):
+                # Note: The following ops are inline ops: Concat, EpochCtrl
+                if op_info[i]["op_type"] in ("ConcatOp", "EpochCtrlOp"):
+                    # Confirm these inline ops do not have metrics information
+                    assert op_info[i]["metrics"] is None
+                else:
+                    assert "size" in op_info[i]["metrics"]["output_queue"]
+                    assert "length" in op_info[i]["metrics"]["output_queue"]
+                    assert "throughput" in op_info[i]["metrics"]["output_queue"]
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()


 def test_profiling_inline_ops_pipeline2():
@@ -161,9 +198,7 @@ def test_profiling_inline_ops_pipeline2():
     Test pipeline with many inline ops
     Generator -> Rename -> Skip -> Repeat -> Take
     """
-    os.environ['PROFILING_MODE'] = 'true'
-    os.environ['MINDDATA_PROFILING_DIR'] = '.'
-    os.environ['DEVICE_ID'] = '1'
+    set_profiling_env_var()

     # In source1 dataset: Number of rows is 10; its values are 0, 1, 2, 3, 4, 5 ... 9
     def source1():
@@ -176,38 +211,38 @@ def test_profiling_inline_ops_pipeline2():
     data1 = data1.repeat(2)
     data1 = data1.take(12)

-    for _ in data1:
-        pass
-
-    with open(PIPELINE_FILE) as f:
-        data = json.load(f)
-        op_info = data["op_info"]
-        assert len(op_info) == 5
-        for i in range(5):
-            # Check for these inline ops
-            if op_info[i]["op_type"] in ("RenameOp", "RepeatOp", "SkipOp", "TakeOp"):
-                # Confirm these inline ops do not have metrics information
-                assert op_info[i]["metrics"] is None
-            else:
-                assert "size" in op_info[i]["metrics"]["output_queue"]
-                assert "length" in op_info[i]["metrics"]["output_queue"]
-                assert "throughput" in op_info[i]["metrics"]["output_queue"]
-
-    assert os.path.exists(PIPELINE_FILE) is True
-    os.remove(PIPELINE_FILE)
-    assert os.path.exists(DATASET_ITERATOR_FILE) is True
-    os.remove(DATASET_ITERATOR_FILE)
-    del os.environ['PROFILING_MODE']
-    del os.environ['MINDDATA_PROFILING_DIR']
+    try:
+        for _ in data1:
+            pass
+
+        with open(PIPELINE_FILE) as f:
+            data = json.load(f)
+            op_info = data["op_info"]
+            assert len(op_info) == 5
+            for i in range(5):
+                # Check for these inline ops
+                if op_info[i]["op_type"] in ("RenameOp", "RepeatOp", "SkipOp", "TakeOp"):
+                    # Confirm these inline ops do not have metrics information
+                    assert op_info[i]["metrics"] is None
+                else:
+                    assert "size" in op_info[i]["metrics"]["output_queue"]
+                    assert "length" in op_info[i]["metrics"]["output_queue"]
+                    assert "throughput" in op_info[i]["metrics"]["output_queue"]
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()


 def test_profiling_sampling_interval():
     """
     Test non-default monitor sampling interval
     """
-    os.environ['PROFILING_MODE'] = 'true'
-    os.environ['MINDDATA_PROFILING_DIR'] = '.'
-    os.environ['DEVICE_ID'] = '1'
+    set_profiling_env_var()

     interval_origin = ds.config.get_monitor_sampling_interval()
     ds.config.set_monitor_sampling_interval(30)
@@ -219,17 +254,118 @@ def test_profiling_sampling_interval():
     data1 = data1.shuffle(64)
     data1 = data1.batch(32)

-    for _ in data1:
-        pass
-
-    assert os.path.exists(PIPELINE_FILE) is True
-    os.remove(PIPELINE_FILE)
-    assert os.path.exists(DATASET_ITERATOR_FILE) is True
-    os.remove(DATASET_ITERATOR_FILE)
-    ds.config.set_monitor_sampling_interval(interval_origin)
-    del os.environ['PROFILING_MODE']
-    del os.environ['MINDDATA_PROFILING_DIR']
+    try:
+        for _ in data1:
+            pass
+
+    except Exception as error:
+        ds.config.set_monitor_sampling_interval(interval_origin)
+        delete_profiling_files()
+        raise error
+
+    else:
+        ds.config.set_monitor_sampling_interval(interval_origin)
+        delete_profiling_files()
+
+
+def test_profiling_basic_pipeline():
+    """
+    Test with this basic pipeline
+    Generator -> Map -> Batch -> Repeat -> EpochCtrl
+    """
+    set_profiling_env_var()
+
+    def source1():
+        for i in range(8000):
+            yield (np.array([i]),)
+
+    # Create this basic and common pipeline
+    # Leaf/Source-Op -> Map -> Batch -> Repeat
+    data1 = ds.GeneratorDataset(source1, ["col1"])
+
+    type_cast_op = C.TypeCast(mstype.int32)
+    data1 = data1.map(operations=type_cast_op, input_columns="col1")
+    data1 = data1.batch(16)
+    data1 = data1.repeat(2)
+
+    try:
+        num_iter = 0
+        # Note: If create_tuple_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
+        for _ in data1.create_dict_iterator(num_epochs=2):
+            num_iter = num_iter + 1
+
+        assert num_iter == 1000
+
+        with open(PIPELINE_FILE) as f:
+            data = json.load(f)
+            op_info = data["op_info"]
+            assert len(op_info) == 5
+            for i in range(5):
+                # Check for inline ops
+                if op_info[i]["op_type"] in ("EpochCtrlOp", "RepeatOp"):
+                    # Confirm these inline ops do not have metrics information
+                    assert op_info[i]["metrics"] is None
+                else:
+                    assert "size" in op_info[i]["metrics"]["output_queue"]
+                    assert "length" in op_info[i]["metrics"]["output_queue"]
+                    assert "throughput" in op_info[i]["metrics"]["output_queue"]
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()
+
+
+def test_profiling_cifar10_pipeline():
+    """
+    Test with this common pipeline with Cifar10
+    Cifar10 -> Map -> Map -> Batch -> Repeat
+    """
+    set_profiling_env_var()
+
+    # Create this common pipeline
+    # Cifar10 -> Map -> Map -> Batch -> Repeat
+    DATA_DIR_10 = "../data/dataset/testCifar10Data"
+    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=8000)
+
+    type_cast_op = C.TypeCast(mstype.int32)
+    data1 = data1.map(operations=type_cast_op, input_columns="label")
+    random_horizontal_op = vision.RandomHorizontalFlip()
+    data1 = data1.map(operations=random_horizontal_op, input_columns="image")
+    data1 = data1.batch(32)
+    data1 = data1.repeat(3)
+
+    try:
+        num_iter = 0
+        # Note: If create_tuple_iterator() is called with num_epochs=1, then EpochCtrlOp is NOT added to the pipeline
+        for _ in data1.create_dict_iterator(num_epochs=1):
+            num_iter = num_iter + 1
+
+        assert num_iter == 750
+
+        with open(PIPELINE_FILE) as f:
+            data = json.load(f)
+            op_info = data["op_info"]
+            assert len(op_info) == 5
+            for i in range(5):
+                # Check for inline ops
+                if op_info[i]["op_type"] == "RepeatOp":
+                    # Confirm these inline ops do not have metrics information
+                    assert op_info[i]["metrics"] is None
+                else:
+                    assert "size" in op_info[i]["metrics"]["output_queue"]
+                    assert "length" in op_info[i]["metrics"]["output_queue"]
+                    assert "throughput" in op_info[i]["metrics"]["output_queue"]
+
+    except Exception as error:
+        delete_profiling_files()
+        raise error
+
+    else:
+        delete_profiling_files()


 if __name__ == "__main__":
@@ -238,3 +374,5 @@ if __name__ == "__main__":
     test_profiling_inline_ops_pipeline1()
     test_profiling_inline_ops_pipeline2()
     test_profiling_sampling_interval()
+    test_profiling_basic_pipeline()
+    test_profiling_cifar10_pipeline()
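Every refactored test above follows the same guard pattern: call set_profiling_env_var() before building the pipeline, run the iteration and assertions inside try, and call delete_profiling_files() from both the except and else branches so the profiling files and environment variables are cleaned up whether the test passes or fails. Distilled into a skeleton (the test name and pipeline are placeholders, not part of the commit):

def test_profiling_some_pipeline():
    set_profiling_env_var()

    # ... build the dataset pipeline under test ...

    try:
        # Iterate the pipeline and assert on the generated profiling output
        pass

    except Exception as error:
        # Clean up profiling files and env vars even when an assertion fails
        delete_profiling_files()
        raise error

    else:
        # Normal-path cleanup
        delete_profiling_files()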

View File

@@ -0,0 +1,203 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test MindData Profiling Analyzer Support
"""
import csv
import json
import os
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
from mindspore.profiler.parser.minddata_analyzer import MinddataProfilingAnalyzer

PIPELINE_FILE = "./pipeline_profiling_0.json"
CPU_UTIL_FILE = "./minddata_cpu_utilization_0.json"
DATASET_ITERATOR_FILE = "./dataset_iterator_profiling_0.txt"
SUMMARY_JSON_FILE = "./minddata_pipeline_summary_0.json"
SUMMARY_CSV_FILE = "./minddata_pipeline_summary_0.csv"
ANALYZE_FILE_PATH = "./"

# This is the minimum subset of expected keys (in alphabetical order) in the MindData Analyzer summary output
EXPECTED_SUMMARY_KEYS = ['avg_cpu_pct', 'children_ids', 'num_workers', 'op_ids', 'op_names', 'parent_id',
                         'per_batch_time', 'pipeline_ops', 'queue_average_size', 'queue_empty_freq_pct',
                         'queue_utilization_pct']


def get_csv_result(file_pathname):
    """
    Get result from the CSV file.

    Args:
        file_pathname (str): The CSV file pathname.

    Returns:
        list[list], the parsed CSV information.
    """
    result = []
    with open(file_pathname, 'r') as csvfile:
        csv_reader = csv.reader(csvfile)
        for row in csv_reader:
            result.append(row)
    return result


def delete_profiling_files():
    """
    Delete the MindData profiling files generated from the test.
    Also disable the MindData Profiling environment variables.
    """
    # Delete MindData profiling files
    os.remove(PIPELINE_FILE)
    os.remove(CPU_UTIL_FILE)
    os.remove(DATASET_ITERATOR_FILE)

    # Delete MindData profiling analyze summary files
    os.remove(SUMMARY_JSON_FILE)
    os.remove(SUMMARY_CSV_FILE)

    # Disable MindData Profiling environment variables
    del os.environ['PROFILING_MODE']
    del os.environ['MINDDATA_PROFILING_DIR']
    del os.environ['DEVICE_ID']


def test_analyze_basic():
    """
    Test MindData profiling analyze summary files exist with basic pipeline.
    Also test basic content (subset of keys and values) from the returned summary result.
    """
    # Confirm MindData Profiling files do not yet exist
    assert os.path.exists(PIPELINE_FILE) is False
    assert os.path.exists(CPU_UTIL_FILE) is False
    assert os.path.exists(DATASET_ITERATOR_FILE) is False
    # Confirm MindData Profiling analyze summary files do not yet exist
    assert os.path.exists(SUMMARY_JSON_FILE) is False
    assert os.path.exists(SUMMARY_CSV_FILE) is False

    # Enable MindData Profiling environment variables
    os.environ['PROFILING_MODE'] = 'true'
    os.environ['MINDDATA_PROFILING_DIR'] = '.'
    os.environ['DEVICE_ID'] = '0'

    def source1():
        for i in range(8000):
            yield (np.array([i]),)

    try:
        # Create this basic and common linear pipeline
        # Generator -> Map -> Batch -> Repeat -> EpochCtrl
        data1 = ds.GeneratorDataset(source1, ["col1"])

        type_cast_op = C.TypeCast(mstype.int32)
        data1 = data1.map(operations=type_cast_op, input_columns="col1")
        data1 = data1.batch(16)
        data1 = data1.repeat(2)

        num_iter = 0
        # Note: If create_tuple_iterator() is called with num_epochs>1, then EpochCtrlOp is added to the pipeline
        for _ in data1.create_dict_iterator(num_epochs=2):
            num_iter = num_iter + 1

        # Confirm number of rows returned
        assert num_iter == 1000

        # Confirm MindData Profiling files are created
        assert os.path.exists(PIPELINE_FILE) is True
        assert os.path.exists(CPU_UTIL_FILE) is True
        assert os.path.exists(DATASET_ITERATOR_FILE) is True

        # Call MindData Analyzer for generated MindData profiling files to generate MindData pipeline summary result
        # Note: MindData Analyzer returns the result in 3 formats:
        # 1. returned dictionary
        # 2. JSON file
        # 3. CSV file
        md_analyzer = MinddataProfilingAnalyzer(ANALYZE_FILE_PATH, "CPU", 0, ANALYZE_FILE_PATH)
        md_summary_dict = md_analyzer.analyze()

        # Confirm MindData Profiling analyze summary files are created
        assert os.path.exists(SUMMARY_JSON_FILE) is True
        assert os.path.exists(SUMMARY_CSV_FILE) is True

        # Build a list of the sorted returned keys
        summary_returned_keys = list(md_summary_dict.keys())
        summary_returned_keys.sort()

        # 1. Confirm expected keys are in returned keys
        for k in EXPECTED_SUMMARY_KEYS:
            assert k in summary_returned_keys

        # Read summary JSON file
        with open(SUMMARY_JSON_FILE) as f:
            summary_json_data = json.load(f)
        # Build a list of the sorted JSON keys
        summary_json_keys = list(summary_json_data.keys())
        summary_json_keys.sort()

        # 2a. Confirm expected keys are in JSON file keys
        for k in EXPECTED_SUMMARY_KEYS:
            assert k in summary_json_keys

        # 2b. Confirm returned dictionary keys are identical to JSON file keys
        np.testing.assert_array_equal(summary_returned_keys, summary_json_keys)

        # Read summary CSV file
        summary_csv_data = get_csv_result(SUMMARY_CSV_FILE)
        # Build a list of the sorted CSV keys from the first column in the CSV file
        summary_csv_keys = []
        for x in summary_csv_data:
            summary_csv_keys.append(x[0])
        summary_csv_keys.sort()

        # 3a. Confirm expected keys are in the first column of the CSV file
        for k in EXPECTED_SUMMARY_KEYS:
            assert k in summary_csv_keys

        # 3b. Confirm returned dictionary keys are identical to CSV file first column keys
        np.testing.assert_array_equal(summary_returned_keys, summary_csv_keys)

        # 4. Verify non-variant values or number of values in the tested pipeline for certain keys
        # of the returned dictionary
        # Note: Values of num_workers are not tested since default may change in the future
        # Note: Values related to queue metrics are not tested since they may vary on different execution environments
        assert md_summary_dict["pipeline_ops"] == ["EpochCtrl(id=0)", "Repeat(id=1)", "Batch(id=2)", "Map(id=3)",
                                                   "Generator(id=4)"]
        assert md_summary_dict["op_names"] == ["EpochCtrl", "Repeat", "Batch", "Map", "Generator"]
        assert md_summary_dict["op_ids"] == [0, 1, 2, 3, 4]
        assert len(md_summary_dict["num_workers"]) == 5
        assert len(md_summary_dict["queue_average_size"]) == 5
        assert len(md_summary_dict["queue_utilization_pct"]) == 5
        assert len(md_summary_dict["queue_empty_freq_pct"]) == 5
        assert md_summary_dict["children_ids"] == [[1], [2], [3], [4], []]
        assert md_summary_dict["parent_id"] == [-1, 0, 1, 2, 3]
        assert len(md_summary_dict["avg_cpu_pct"]) == 5

        # 5. Confirm exact list of keys
        # Note: This is a very strong comparison.
        # e.g. No bottleneck info is in the result.
        # e.g. No additional keys are in the returned summary result
        np.testing.assert_array_equal(summary_returned_keys, EXPECTED_SUMMARY_KEYS)

    except Exception as error:
        delete_profiling_files()
        raise error

    else:
        delete_profiling_files()


if __name__ == "__main__":
    test_analyze_basic()