forked from mindspore-Ecosystem/mindspore
!46671 Support new MapParameter in construct.
Merge pull request !46671 from Margaret_wangrui/map_parameter
This commit is contained in:
commit
44b31ff9c1
|
@ -189,6 +189,12 @@ void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *devic
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto &abstract = value_node->abstract();
|
||||
if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
|
||||
CreateDeviceAddressByMapTensorNode(device_context, value_node, 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
||||
|
|
|
@ -272,7 +272,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|||
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
|
||||
}
|
||||
auto shape_temp = tensor->shape();
|
||||
if (IsDynamic(shape_temp)) {
|
||||
if (IsDynamic(shape_temp) && !tensor->isa<tensor::MapTensor>()) {
|
||||
auto base_shape = tensor->base_shape_ptr();
|
||||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
if (base_shape->cast<abstract::ShapePtr>() == nullptr) {
|
||||
|
|
|
@ -136,6 +136,7 @@ def test_map_parameter_get():
|
|||
Description: Test get api for MapParameter.
|
||||
Expectation: get api works as expected.
|
||||
"""
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
keys = Tensor([1, 2], dtype=ms.int32)
|
||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||
|
@ -165,6 +166,7 @@ def test_map_parameter_put():
|
|||
Description: Test put api for MapParameter.
|
||||
Expectation: put api works as expected.
|
||||
"""
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
keys = Tensor([1, 2], dtype=ms.int32)
|
||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||
|
@ -184,6 +186,7 @@ def test_map_parameter_erase():
|
|||
Description: Test erase api for MapParameter.
|
||||
Expectation: erase api works as expected.
|
||||
"""
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
keys = Tensor([1, 2], dtype=ms.int32)
|
||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||
|
@ -202,7 +205,9 @@ def test_basic_operations():
|
|||
Description: Test MapParameter basic operations.
|
||||
Expectation: MapParameter works as expected.
|
||||
"""
|
||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros', name='my_map')
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros',
|
||||
name='my_map')
|
||||
assert m.name == 'my_map'
|
||||
assert m.requires_grad
|
||||
|
||||
|
@ -265,6 +270,7 @@ def test_export_update_api():
|
|||
Description: Test export update api for MapParameter.
|
||||
Expectation: Export update api works as expected.
|
||||
"""
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
data1 = m1.export_data(incremental=False)
|
||||
print("data1:", data1)
|
||||
|
@ -373,11 +379,12 @@ def test_map_parameter_in_construct():
|
|||
self.value_tensor = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
self.new_key_tensor = Tensor([3, 4], dtype=ms.int32)
|
||||
self.new_value_tensor = Tensor([[3, 3], [4, 4]], dtype=ms.float32)
|
||||
self.map_tensor = MapParameter(key_tensor=self.key_tensor, value_tensor=self.value_tensor)
|
||||
|
||||
def construct(self):
|
||||
new_map_tensor = MapParameter(self.key_tensor, self.value_tensor, self.default_value)
|
||||
new_map_tensor.put(self.new_key_tensor, self.new_value_tensor)
|
||||
return new_map_tensor
|
||||
return new_map_tensor, self.map_tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
|
@ -395,6 +402,7 @@ def test_map_parameter_get_data_api():
|
|||
Description: Test get_data api for MapParameter.
|
||||
Expectation: get_data api works as expected.
|
||||
"""
|
||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||
keys = Tensor([1, 2], dtype=ms.int32)
|
||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||
|
|
Loading…
Reference in New Issue