!46671 Support new MapParameter in construct.

Merge pull request !46671 from Margaret_wangrui/map_parameter
This commit is contained in:
i-robot 2022-12-12 01:34:20 +00:00 committed by Gitee
commit 44b31ff9c1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 81 additions and 67 deletions

View File

@ -189,6 +189,12 @@ void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *devic
continue;
}
const auto &abstract = value_node->abstract();
if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
CreateDeviceAddressByMapTensorNode(device_context, value_node, 0);
continue;
}
const auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {

View File

@ -272,7 +272,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
}
auto shape_temp = tensor->shape();
if (IsDynamic(shape_temp)) {
if (IsDynamic(shape_temp) && !tensor->isa<tensor::MapTensor>()) {
auto base_shape = tensor->base_shape_ptr();
MS_EXCEPTION_IF_NULL(base_shape);
if (base_shape->cast<abstract::ShapePtr>() == nullptr) {

View File

@ -136,6 +136,7 @@ def test_map_parameter_get():
Description: Test get api for MapParameter.
Expectation: get api works as expected.
"""
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
@ -165,6 +166,7 @@ def test_map_parameter_put():
Description: Test put api for MapParameter.
Expectation: put api works as expected.
"""
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
@ -184,6 +186,7 @@ def test_map_parameter_erase():
Description: Test erase api for MapParameter.
Expectation: erase api works as expected.
"""
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
@ -202,7 +205,9 @@ def test_basic_operations():
Description: Test MapParameter basic operations.
Expectation: MapParameter works as expected.
"""
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros', name='my_map')
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros',
name='my_map')
assert m.name == 'my_map'
assert m.requires_grad
@ -265,6 +270,7 @@ def test_export_update_api():
Description: Test export update api for MapParameter.
Expectation: Export update api works as expected.
"""
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data1 = m1.export_data(incremental=False)
print("data1:", data1)
@ -373,11 +379,12 @@ def test_map_parameter_in_construct():
self.value_tensor = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
self.new_key_tensor = Tensor([3, 4], dtype=ms.int32)
self.new_value_tensor = Tensor([[3, 3], [4, 4]], dtype=ms.float32)
self.map_tensor = MapParameter(key_tensor=self.key_tensor, value_tensor=self.value_tensor)
def construct(self):
new_map_tensor = MapParameter(self.key_tensor, self.value_tensor, self.default_value)
new_map_tensor.put(self.new_key_tensor, self.new_value_tensor)
return new_map_tensor
return new_map_tensor, self.map_tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
@ -395,6 +402,7 @@ def test_map_parameter_get_data_api():
Description: Test get_data api for MapParameter.
Expectation: get_data api works as expected.
"""
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')