From 369eedfeb19429b47ad92e618b65557fa2a52fd4 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Wed, 16 Nov 2022 11:21:59 +0800 Subject: [PATCH] Add st for dynamic embedding --- .../run_test_dynamic_embedding_standalone.sh | 29 ++++ .../st/map_parameter/network/src/__init__.py | 0 tests/st/map_parameter/network/src/dataset.py | 71 +++++++++ tests/st/map_parameter/network/src/model.py | 136 ++++++++++++++++++ .../test_dynamic_embedding_standalone.py | 35 +++++ ..._entry_dynamic_embedding_standalone_gpu.py | 35 +++++ 6 files changed, 306 insertions(+) create mode 100644 tests/st/map_parameter/network/run_test_dynamic_embedding_standalone.sh create mode 100644 tests/st/map_parameter/network/src/__init__.py create mode 100644 tests/st/map_parameter/network/src/dataset.py create mode 100644 tests/st/map_parameter/network/src/model.py create mode 100644 tests/st/map_parameter/network/test_dynamic_embedding_standalone.py create mode 100644 tests/st/map_parameter/network/test_entry_dynamic_embedding_standalone_gpu.py diff --git a/tests/st/map_parameter/network/run_test_dynamic_embedding_standalone.sh b/tests/st/map_parameter/network/run_test_dynamic_embedding_standalone.sh new file mode 100644 index 00000000000..7a044dfc93b --- /dev/null +++ b/tests/st/map_parameter/network/run_test_dynamic_embedding_standalone.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +DEVICE_TARGET=$1 + +python ${self_path}/test_dynamic_embedding_standalone.py --device_target=$DEVICE_TARGET &>dynamic_embedding.log 2>&1 & +pid=`echo $!` + +wait ${pid} +status=`echo $?` +if [ "${status}" != "0" ]; then + echo "[ERROR] test dynamic embedding standalone failed, status: ${status}" + exit 1 +fi diff --git a/tests/st/map_parameter/network/src/__init__.py b/tests/st/map_parameter/network/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/map_parameter/network/src/dataset.py b/tests/st/map_parameter/network/src/dataset.py new file mode 100644 index 00000000000..b34c05e0a3f --- /dev/null +++ b/tests/st/map_parameter/network/src/dataset.py @@ -0,0 +1,71 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import mindspore.dataset.vision.c_transforms as c_version +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset as ds +import mindspore.common.dtype as mstype + + +DATASET_PATH = "/home/workspace/mindspore_dataset/animal/mini_animal_12" +_R_MEAN = 123.68 +_G_MEAN = 116.78 +_B_MEAN = 103.94 + +_R_STD = 1 +_G_STD = 1 +_B_STD = 1 + + +def create_dataset(epoch_size=1, batch_size=32, step_size=1, resize_height=224, + resize_width=224, full_batch=False, scale=1.0, rank_size=1): + try: + os.environ['DEVICE_ID'] + except KeyError: + device_id = 0 + os.environ['DEVICE_ID'] = str(device_id) + + if full_batch: + batch_size = batch_size * rank_size + + num_shards = 1 + shard_id = 0 + data_url = DATASET_PATH + dataset = ds.ImageFolderDataset(data_url, num_parallel_workers=1, num_shards=num_shards, + shard_id=shard_id, shuffle=False) + + # define map operations + decode_op = c_version.Decode() + c_version.Normalize(mean=[_R_MEAN, _G_MEAN, _B_MEAN], std=[_R_STD, _G_STD, _B_STD]) + random_resize_op = c_version.Resize((resize_height, resize_width)) + channelswap_op = c_version.HWC2CHW() + rescale = scale / 255.0 + shift = 0.0 + rescale_op = c_version.Rescale(rescale, shift) + type_cast_label = C.TypeCast(mstype.float32) + type_cast_image = C.TypeCast(mstype.int32) + + dataset = dataset.map(input_columns="label", operations=C.OneHot(dataset.num_classes())) + dataset = dataset.map(input_columns="label", operations=type_cast_label, num_parallel_workers=1) + + dataset = dataset.map(input_columns="image", operations=decode_op, num_parallel_workers=1) + dataset = dataset.map(input_columns="image", operations=random_resize_op, num_parallel_workers=1) + dataset = dataset.map(input_columns="image", operations=rescale_op, num_parallel_workers=1) + dataset = dataset.map(input_columns="image", operations=channelswap_op, num_parallel_workers=1) + dataset = dataset.map(input_columns="image", operations=type_cast_image, num_parallel_workers=1) + + dataset = dataset.batch(batch_size, drop_remainder=True) + return dataset diff --git a/tests/st/map_parameter/network/src/model.py b/tests/st/map_parameter/network/src/model.py new file mode 100644 index 00000000000..d786fe5a26d --- /dev/null +++ b/tests/st/map_parameter/network/src/model.py @@ -0,0 +1,136 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import copy + +import mindspore as ms +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore._checkparam import Validator as validator +from mindspore.ops.primitive import constexpr +from mindspore.nn.layer.basic import ClipByNorm +from mindspore.experimental import MapParameter + +from mindspore.nn import Cell, Flatten, Dense +from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore.nn import Adam +from mindspore.train import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore.train.metrics import Accuracy +from mindspore.common import set_seed + + +@constexpr +def _make_axis_range(start, end): + axis = tuple(range(start, end)) + return axis + + +class HashEmbeddingLookup(Cell): + def __init__(self, embedding_size, key_dtype=ms.int32, param_init='normal', max_norm=None, sparse=True): + """Initialize HashEmbeddingLookup.""" + super(HashEmbeddingLookup, self).__init__() + validator.check_value_type('sparse', sparse, [bool], self.cls_name) + + self.forward_unique = sparse + self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name) + self.embedding_table = MapParameter(key_dtype=key_dtype, value_dtype=ms.float32, value_shape=(embedding_size,), + default_value=param_init, name='embedding_table') + + # Ops for sparse mode. + self.gather_revert = P.Gather() + self.reshape_first = P.Reshape() + self.reshape = P.Reshape() + self.unique = P.Unique() + self.shape = P.Shape() + + self.embedding_table.unique = self.forward_unique + + self.max_norm = max_norm + if self.max_norm is not None: + self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) + self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) + + + def construct(self, indices): + if self.forward_unique: + shp = self.shape(indices) + (self.embedding_size,) + indices_flatten = self.reshape_first(indices, (-1,)) + unique_id, unique_idx = self.unique(indices_flatten) + weight_unique = self.embedding_table.get(unique_id) + weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) + out = self.reshape(weight_flatten, shp) + else: + out = self.embedding_table.get(indices) + + if self.max_norm is not None: + axis = _make_axis_range(F.rank(indices), F.rank(out)) + clip_by_norm = ClipByNorm(axis) + out = clip_by_norm(out, self.max_norm) + return out + + +class Net(Cell): + def __init__(self, in_channels, out_channels, embedding_size, sparse): + super().__init__() + set_seed(5) + self.embedding_lookup1 = HashEmbeddingLookup(embedding_size=embedding_size, param_init='normal', + sparse=sparse) + self.flatten = Flatten() + self.dense = Dense(in_channels=in_channels, out_channels=out_channels, weight_init='normal', + has_bias=False) + self.type = ms.int32 + self.cast = P.Cast() + + def construct(self, x): + x = self.flatten(x) + x = self.cast(x, self.type) + x = self.embedding_lookup1(x) + x = self.flatten(x) + x = self.dense(x) + x = self.flatten(x) + return x + + +class ModelExecutor: + def __init__(self, dataset, input_shape, in_channels=320, out_channels=3, + embedding_size=10, epoch_size=2, sparse=True, save_ckpt=False): + self.in_channels = in_channels + self.out_channels = out_channels + self.embedding_size = embedding_size + self.train_dataset = dataset + self.eval_dataset = copy.deepcopy(dataset) + self.epoch_size = epoch_size + self.sparse = sparse + self.save_ckpt = save_ckpt + + def run_dynamic_embedding(self): + net = Net(self.in_channels, self.out_channels, self.embedding_size, self.sparse) + net.set_train() + loss = SoftmaxCrossEntropyWithLogits(reduction='mean') + opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()), use_lazy=True) + + model = Model(net, loss, opt, metrics={"Accuracy": Accuracy()}) + callback_list = [] + if self.save_ckpt: + config = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="ckpt_dynamic_embedding", directory='./ckpt', + config=config) + callback_list.append(ckpoint_cb) + model.train(self.epoch_size, self.train_dataset, callbacks=callback_list, dataset_sink_mode=True) + acc = model.eval(self.eval_dataset, dataset_sink_mode=True) + return acc['Accuracy'] diff --git a/tests/st/map_parameter/network/test_dynamic_embedding_standalone.py b/tests/st/map_parameter/network/test_dynamic_embedding_standalone.py new file mode 100644 index 00000000000..e5c3cf8dd8e --- /dev/null +++ b/tests/st/map_parameter/network/test_dynamic_embedding_standalone.py @@ -0,0 +1,35 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import argparse +import mindspore.context as context +from src.dataset import create_dataset +from src.model import ModelExecutor + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="test_dynamic_embedding_standalone") + parser.add_argument("--device_target", type=str, default="GPU") + args, _ = parser.parse_known_args() + device_target = args.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + + # Only run this test case in cuda11 environment. + if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']: + dataset = create_dataset(resize_height=32, resize_width=32, scale=30.0) + executor = ModelExecutor(dataset=dataset, sparse=True, in_channels=30720, + out_channels=12, input_shape=[32, 3, 32, 32]) + executor.run_dynamic_embedding() diff --git a/tests/st/map_parameter/network/test_entry_dynamic_embedding_standalone_gpu.py b/tests/st/map_parameter/network/test_entry_dynamic_embedding_standalone_gpu.py new file mode 100644 index 00000000000..9548e2c61b7 --- /dev/null +++ b/tests/st/map_parameter/network/test_entry_dynamic_embedding_standalone_gpu.py @@ -0,0 +1,35 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_embedding_gpu(): + """ + Feature: Test dynamic embedding feature on gpu. + Description: A small network contain dynamic embedding(MapParameter). + Expectation: All process execute and exit normal. + """ + + self_path = os.path.split(os.path.realpath(__file__))[0] + return_code = os.system(f"bash {self_path}/run_test_dynamic_embedding_standalone.sh GPU") + if return_code != 0: + os.system(f"echo '\n**************** Log ****************'") + os.system(f"grep -E 'ERROR|Error|error' {self_path}/dynamic_embedding.log") + assert return_code == 0