forked from mindspore-Ecosystem/mindspore
!30714 [MD][Offload] Add TypeCast op to offload
Merge pull request !30714 from markuskunej/offload_typecast
This commit is contained in:
commit
edcc6b790d
|
@ -56,8 +56,8 @@ class NodeOffloadPass : public IRTreePass {
|
|||
std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_;
|
||||
/// \brief Vector of supported offload operations
|
||||
const std::set<std::string> supported_ops_{
|
||||
"HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip", "RandomSharpness",
|
||||
"RandomVerticalFlip", "Rescale"};
|
||||
"HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip",
|
||||
"RandomSharpness", "RandomVerticalFlip", "Rescale", "TypeCast"};
|
||||
/// \brief std::map indicating if the map op for the input column is at the end of the pipeline
|
||||
std::map<std::string, bool> end_of_pipeline_;
|
||||
/// \brief bool indicating whether the auto_offload config option is enabled
|
||||
|
|
|
@ -369,6 +369,22 @@ class Normalize(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class TypeCast(nn.Cell):
|
||||
"""
|
||||
Applies TypeCast transform on given input tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, data_type_str):
|
||||
super(TypeCast, self).__init__()
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.data_type = mstype.typing.str_to_type(data_type_str)
|
||||
|
||||
def construct(self, x):
|
||||
|
||||
return self.cast(x, self.data_type)
|
||||
|
||||
|
||||
class OffloadModel():
|
||||
def __init__(self, func, args_names=None):
|
||||
self.func = func
|
||||
|
@ -384,7 +400,8 @@ op_to_model = {
|
|||
"RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]),
|
||||
"RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]),
|
||||
"RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]),
|
||||
"Rescale": OffloadModel(Rescale, ["rescale", "shift"])
|
||||
"Rescale": OffloadModel(Rescale, ["rescale", "shift"]),
|
||||
"TypeCast": OffloadModel(TypeCast, ["data_type"])
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -260,6 +260,28 @@ def test_offload_rescale_op():
|
|||
np.testing.assert_almost_equal(img_0, img_1, decimal=6)
|
||||
|
||||
|
||||
def test_offload_typecast_op():
|
||||
"""
|
||||
Feature: test map offload TypeCast op.
|
||||
Description: Input is image dataset.
|
||||
Expectation: Output should be the same with activated or deactivated offload for TypeCast op.
|
||||
"""
|
||||
# Dataset without offload activated.
|
||||
ds_baseline = ds.ImageFolderDataset(DATA_DIR)
|
||||
ds_baseline = ds_baseline.map(operations=[C.Decode(), C2.TypeCast(mstype.float32)], input_columns="image")
|
||||
ds_baseline = ds_baseline.map(operations=[C2.TypeCast(mstype.int32)], input_columns="label")
|
||||
|
||||
# Dataset with offload activated.
|
||||
ds_offload = ds.ImageFolderDataset(DATA_DIR)
|
||||
ds_offload = ds_offload.map(operations=[C.Decode(), C2.TypeCast(mstype.float32)],
|
||||
input_columns="image", offload=True)
|
||||
ds_offload = ds_offload.map(operations=[C2.TypeCast(mstype.int32)], input_columns="label", offload=True)
|
||||
|
||||
for (img_0, _), (img_1, _) in zip(ds_baseline.create_tuple_iterator(num_epochs=1, output_numpy=True),
|
||||
ds_offload.create_tuple_iterator(num_epochs=1, output_numpy=True)):
|
||||
np.testing.assert_almost_equal(img_0, img_1, decimal=6)
|
||||
|
||||
|
||||
def test_offload_different_column_end_of_pipeline():
|
||||
"""
|
||||
Feature: Test offload end_of_pipeline check.
|
||||
|
@ -327,6 +349,7 @@ if __name__ == "__main__":
|
|||
test_offload_concat_dataset_2()
|
||||
test_offload_normalize_op()
|
||||
test_offload_rescale_op()
|
||||
test_offload_typecast_op()
|
||||
test_offload_different_column_end_of_pipeline()
|
||||
test_offload_not_end_of_pipeline()
|
||||
test_offload_dim_check()
|
||||
|
|
Loading…
Reference in New Issue