!45615 Add network st for dynamic embedding
Merge pull request !45615 from zyli2020/dynamic_embedding_compile
This commit is contained in:
commit
1a03780771
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Launch the standalone dynamic-embedding test in the background, wait for it,
# and exit non-zero (with an error message) if the python process failed.
# Usage: bash run_test_dynamic_embedding_standalone.sh <DEVICE_TARGET>

script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DEVICE_TARGET=$1

# NOTE: '&>' already redirects both stdout and stderr to the log file; the
# original trailing '2>&1' was redundant and has been removed.
python ${self_path}/test_dynamic_embedding_standalone.py --device_target=$DEVICE_TARGET &>dynamic_embedding.log &
# $! is the PID of the last background job; no subshell echo needed.
pid=$!

wait ${pid}
# $? is the exit status of the waited process; assign it directly.
status=$?
if [ "${status}" != "0" ]; then
  echo "[ERROR] test dynamic embedding standalone failed, status: ${status}"
  exit 1
fi
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import mindspore.dataset.vision.c_transforms as c_version
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
DATASET_PATH = "/home/workspace/mindspore_dataset/animal/mini_animal_12"
|
||||
_R_MEAN = 123.68
|
||||
_G_MEAN = 116.78
|
||||
_B_MEAN = 103.94
|
||||
|
||||
_R_STD = 1
|
||||
_G_STD = 1
|
||||
_B_STD = 1
|
||||
|
||||
|
||||
def create_dataset(epoch_size=1, batch_size=32, step_size=1, resize_height=224,
                   resize_width=224, full_batch=False, scale=1.0, rank_size=1):
    """Build the mini-animal image dataset pipeline for the dynamic-embedding test.

    Labels are one-hot encoded then cast to float32; images are decoded,
    resized, rescaled by ``scale / 255``, transposed to CHW, and cast to int32
    (they serve as embedding-lookup indices downstream).

    Args:
        epoch_size (int): Unused; kept for interface compatibility.
        batch_size (int): Per-shard batch size; multiplied by ``rank_size``
            when ``full_batch`` is True.
        step_size (int): Unused; kept for interface compatibility.
        resize_height (int): Target image height.
        resize_width (int): Target image width.
        full_batch (bool): If True, scale the batch size by ``rank_size``.
        scale (float): Numerator of the pixel rescale factor (``scale / 255``).
        rank_size (int): Number of ranks, used only with ``full_batch``.

    Returns:
        A batched ``mindspore.dataset`` pipeline with dropped remainder.
    """
    # Ensure DEVICE_ID is set so downstream device code can read it (EAFP).
    try:
        os.environ['DEVICE_ID']
    except KeyError:
        device_id = 0
        os.environ['DEVICE_ID'] = str(device_id)

    if full_batch:
        batch_size = batch_size * rank_size

    # Standalone run: a single shard.
    num_shards = 1
    shard_id = 0
    data_url = DATASET_PATH
    dataset = ds.ImageFolderDataset(data_url, num_parallel_workers=1, num_shards=num_shards,
                                    shard_id=shard_id, shuffle=False)

    # Define map operations.
    # NOTE: the original code constructed a c_version.Normalize op here and
    # discarded it without applying it to the pipeline; that dead statement
    # has been removed (no behavior change).
    decode_op = c_version.Decode()
    random_resize_op = c_version.Resize((resize_height, resize_width))
    channelswap_op = c_version.HWC2CHW()
    rescale = scale / 255.0
    shift = 0.0
    rescale_op = c_version.Rescale(rescale, shift)
    type_cast_label = C.TypeCast(mstype.float32)
    type_cast_image = C.TypeCast(mstype.int32)

    dataset = dataset.map(input_columns="label", operations=C.OneHot(dataset.num_classes()))
    dataset = dataset.map(input_columns="label", operations=type_cast_label, num_parallel_workers=1)

    dataset = dataset.map(input_columns="image", operations=decode_op, num_parallel_workers=1)
    dataset = dataset.map(input_columns="image", operations=random_resize_op, num_parallel_workers=1)
    dataset = dataset.map(input_columns="image", operations=rescale_op, num_parallel_workers=1)
    dataset = dataset.map(input_columns="image", operations=channelswap_op, num_parallel_workers=1)
    dataset = dataset.map(input_columns="image", operations=type_cast_image, num_parallel_workers=1)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import copy
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.nn.layer.basic import ClipByNorm
|
||||
from mindspore.experimental import MapParameter
|
||||
|
||||
from mindspore.nn import Cell, Flatten, Dense
|
||||
from mindspore.nn import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.nn import Adam
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.train.metrics import Accuracy
|
||||
from mindspore.common import set_seed
|
||||
|
||||
|
||||
@constexpr
def _make_axis_range(start, end):
    """Return the tuple of consecutive axis indices in the half-open range [start, end)."""
    return tuple(range(start, end))
|
||||
|
||||
|
||||
class HashEmbeddingLookup(Cell):
    """Embedding lookup backed by a dynamic hash table (``MapParameter``).

    Unlike a dense embedding table, rows are created on demand for each key,
    so the vocabulary does not need to be known in advance. In sparse mode the
    lookup deduplicates keys before fetching from the table and then scatters
    the fetched rows back to the original positions.
    """

    def __init__(self, embedding_size, key_dtype=ms.int32, param_init='normal', max_norm=None, sparse=True):
        """Initialize HashEmbeddingLookup."""
        super(HashEmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)

        # In sparse mode, duplicate indices are collapsed with Unique before
        # the table lookup, then expanded again with Gather.
        self.forward_unique = sparse
        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
        # Dynamic table: one float32 row of length `embedding_size` per key,
        # initialized with `param_init` when a key is first seen.
        self.embedding_table = MapParameter(key_dtype=key_dtype, value_dtype=ms.float32, value_shape=(embedding_size,),
                                            default_value=param_init, name='embedding_table')

        # Ops for sparse mode.
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()

        # Tell the table whether lookups arrive pre-uniqued.
        self.embedding_table.unique = self.forward_unique

        # Optional per-row norm clipping, applied over the embedding axes.
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)


    def construct(self, indices):
        # Returns embeddings with shape = shape(indices) + (embedding_size,).
        if self.forward_unique:
            shp = self.shape(indices) + (self.embedding_size,)
            # Flatten -> dedupe -> look up unique rows -> scatter back -> reshape.
            indices_flatten = self.reshape_first(indices, (-1,))
            unique_id, unique_idx = self.unique(indices_flatten)
            weight_unique = self.embedding_table.get(unique_id)
            weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
            out = self.reshape(weight_flatten, shp)
        else:
            out = self.embedding_table.get(indices)

        if self.max_norm is not None:
            # Clip only over the axes appended by the lookup (rank(indices)..rank(out)).
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out
|
||||
|
||||
|
||||
class Net(Cell):
    """Tiny classifier: cast inputs to int32 ids, look them up in a hash
    embedding, then project through a bias-free dense layer."""

    def __init__(self, in_channels, out_channels, embedding_size, sparse):
        super().__init__()
        set_seed(5)
        self.embedding_lookup1 = HashEmbeddingLookup(embedding_size=embedding_size, param_init='normal',
                                                     sparse=sparse)
        self.flatten = Flatten()
        self.dense = Dense(in_channels=in_channels, out_channels=out_channels, weight_init='normal',
                           has_bias=False)
        self.type = ms.int32
        self.cast = P.Cast()

    def construct(self, x):
        # Flatten, then reinterpret pixel values as integer lookup keys.
        ids = self.cast(self.flatten(x), self.type)
        embedded = self.embedding_lookup1(ids)
        logits = self.dense(self.flatten(embedded))
        return self.flatten(logits)
|
||||
|
||||
|
||||
class ModelExecutor:
    """Drive one train/eval cycle of the dynamic-embedding network."""

    def __init__(self, dataset, input_shape, in_channels=320, out_channels=3,
                 embedding_size=10, epoch_size=2, sparse=True, save_ckpt=False):
        # NOTE(review): `input_shape` is accepted but never used; kept for
        # caller compatibility.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embedding_size = embedding_size
        self.train_dataset = dataset
        # Evaluate on an independent copy so iteration state is not shared
        # with the training pipeline.
        self.eval_dataset = copy.deepcopy(dataset)
        self.epoch_size = epoch_size
        self.sparse = sparse
        self.save_ckpt = save_ckpt

    def run_dynamic_embedding(self):
        """Train the network, evaluate it, and return the accuracy metric."""
        network = Net(self.in_channels, self.out_channels, self.embedding_size, self.sparse)
        network.set_train()
        criterion = SoftmaxCrossEntropyWithLogits(reduction='mean')
        optimizer = Adam(params=filter(lambda p: p.requires_grad, network.get_parameters()), use_lazy=True)

        trainer = Model(network, criterion, optimizer, metrics={"Accuracy": Accuracy()})
        callbacks = []
        if self.save_ckpt:
            ckpt_config = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1)
            callbacks.append(ModelCheckpoint(prefix="ckpt_dynamic_embedding", directory='./ckpt',
                                             config=ckpt_config))
        trainer.train(self.epoch_size, self.train_dataset, callbacks=callbacks, dataset_sink_mode=True)
        metrics = trainer.eval(self.eval_dataset, dataset_sink_mode=True)
        return metrics['Accuracy']
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import mindspore.context as context
|
||||
from src.dataset import create_dataset
|
||||
from src.model import ModelExecutor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="test_dynamic_embedding_standalone")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
||||
# Only run this test case in cuda11 environment.
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
dataset = create_dataset(resize_height=32, resize_width=32, scale=30.0)
|
||||
executor = ModelExecutor(dataset=dataset, sparse=True, in_channels=30720,
|
||||
out_channels=12, input_shape=[32, 3, 32, 32])
|
||||
executor.run_dynamic_embedding()
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_embedding_gpu():
    """
    Feature: Test dynamic embedding feature on gpu.
    Description: A small network containing dynamic embedding (MapParameter).
    Expectation: All processes execute and exit normally.
    """
    self_path = os.path.split(os.path.realpath(__file__))[0]
    return_code = os.system(f"bash {self_path}/run_test_dynamic_embedding_standalone.sh GPU")
    if return_code != 0:
        # On failure, surface the error lines from the run log before asserting.
        # (Constant command: the pointless f-string prefix was removed, F541.)
        os.system("echo '\n**************** Log ****************'")
        os.system(f"grep -E 'ERROR|Error|error' {self_path}/dynamic_embedding.log")
    assert return_code == 0
|
Loading…
Reference in New Issue