forked from mindspore-Ecosystem/mindspore
!46671 Support new MapParameter in construct.
Merge pull request !46671 from Margaret_wangrui/map_parameter
This commit is contained in:
commit
44b31ff9c1
|
@ -189,6 +189,12 @@ void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *devic
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto &abstract = value_node->abstract();
|
||||||
|
if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
|
||||||
|
CreateDeviceAddressByMapTensorNode(device_context, value_node, 0);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const auto &node_value = value_node->value();
|
const auto &node_value = value_node->value();
|
||||||
MS_EXCEPTION_IF_NULL(node_value);
|
MS_EXCEPTION_IF_NULL(node_value);
|
||||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
||||||
|
|
|
@ -272,7 +272,7 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
||||||
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
|
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
|
||||||
}
|
}
|
||||||
auto shape_temp = tensor->shape();
|
auto shape_temp = tensor->shape();
|
||||||
if (IsDynamic(shape_temp)) {
|
if (IsDynamic(shape_temp) && !tensor->isa<tensor::MapTensor>()) {
|
||||||
auto base_shape = tensor->base_shape_ptr();
|
auto base_shape = tensor->base_shape_ptr();
|
||||||
MS_EXCEPTION_IF_NULL(base_shape);
|
MS_EXCEPTION_IF_NULL(base_shape);
|
||||||
if (base_shape->cast<abstract::ShapePtr>() == nullptr) {
|
if (base_shape->cast<abstract::ShapePtr>() == nullptr) {
|
||||||
|
|
|
@ -136,24 +136,25 @@ def test_map_parameter_get():
|
||||||
Description: Test get api for MapParameter.
|
Description: Test get api for MapParameter.
|
||||||
Expectation: get api works as expected.
|
Expectation: get api works as expected.
|
||||||
"""
|
"""
|
||||||
keys = Tensor([1, 2], dtype=ms.int32)
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
keys = Tensor([1, 2], dtype=ms.int32)
|
||||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
key = Tensor([3], dtype=ms.int32)
|
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||||
get_value = map_tensor.get(key)
|
key = Tensor([3], dtype=ms.int32)
|
||||||
print("get_value:", get_value)
|
get_value = map_tensor.get(key)
|
||||||
data1 = map_tensor.export_data(incremental=False)
|
print("get_value:", get_value)
|
||||||
print("data1:", data1)
|
data1 = map_tensor.export_data(incremental=False)
|
||||||
|
print("data1:", data1)
|
||||||
|
|
||||||
map_tensor.put(Tensor([3], dtype=ms.int32), Tensor([[3, 3]], dtype=ms.float32))
|
map_tensor.put(Tensor([3], dtype=ms.int32), Tensor([[3, 3]], dtype=ms.float32))
|
||||||
data2 = map_tensor.export_data(incremental=False)
|
data2 = map_tensor.export_data(incremental=False)
|
||||||
print("data2:", data2)
|
print("data2:", data2)
|
||||||
map_tensor[Tensor([1, 2, 3], dtype=ms.int32)] = Tensor([[11, 11], [22, 22], [33, 33]], dtype=ms.float32)
|
map_tensor[Tensor([1, 2, 3], dtype=ms.int32)] = Tensor([[11, 11], [22, 22], [33, 33]], dtype=ms.float32)
|
||||||
data3 = map_tensor.export_data(incremental=False)
|
data3 = map_tensor.export_data(incremental=False)
|
||||||
print("data3:", data3)
|
print("data3:", data3)
|
||||||
map_tensor.erase(Tensor([1, 2, 3], dtype=ms.int32))
|
map_tensor.erase(Tensor([1, 2, 3], dtype=ms.int32))
|
||||||
data4 = map_tensor.export_data(incremental=False)
|
data4 = map_tensor.export_data(incremental=False)
|
||||||
print("data4:", data4)
|
print("data4:", data4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -165,14 +166,15 @@ def test_map_parameter_put():
|
||||||
Description: Test put api for MapParameter.
|
Description: Test put api for MapParameter.
|
||||||
Expectation: put api works as expected.
|
Expectation: put api works as expected.
|
||||||
"""
|
"""
|
||||||
keys = Tensor([1, 2], dtype=ms.int32)
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
keys = Tensor([1, 2], dtype=ms.int32)
|
||||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
key = Tensor([3], dtype=ms.int32)
|
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||||
value = Tensor([[4, 5]], dtype=ms.float32)
|
key = Tensor([3], dtype=ms.int32)
|
||||||
map_tensor.put(key, value)
|
value = Tensor([[4, 5]], dtype=ms.float32)
|
||||||
data1 = map_tensor.export_data(incremental=False)
|
map_tensor.put(key, value)
|
||||||
print("data1:", data1)
|
data1 = map_tensor.export_data(incremental=False)
|
||||||
|
print("data1:", data1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -184,13 +186,14 @@ def test_map_parameter_erase():
|
||||||
Description: Test erase api for MapParameter.
|
Description: Test erase api for MapParameter.
|
||||||
Expectation: erase api works as expected.
|
Expectation: erase api works as expected.
|
||||||
"""
|
"""
|
||||||
keys = Tensor([1, 2], dtype=ms.int32)
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
keys = Tensor([1, 2], dtype=ms.int32)
|
||||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
key = Tensor([2], dtype=ms.int32)
|
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||||
map_tensor.erase(key)
|
key = Tensor([2], dtype=ms.int32)
|
||||||
data1 = map_tensor.export_data(incremental=False)
|
map_tensor.erase(key)
|
||||||
print("data1:", data1)
|
data1 = map_tensor.export_data(incremental=False)
|
||||||
|
print("data1:", data1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -202,24 +205,26 @@ def test_basic_operations():
|
||||||
Description: Test MapParameter basic operations.
|
Description: Test MapParameter basic operations.
|
||||||
Expectation: MapParameter works as expected.
|
Expectation: MapParameter works as expected.
|
||||||
"""
|
"""
|
||||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros', name='my_map')
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
assert m.name == 'my_map'
|
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros',
|
||||||
assert m.requires_grad
|
name='my_map')
|
||||||
|
assert m.name == 'my_map'
|
||||||
|
assert m.requires_grad
|
||||||
|
|
||||||
t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||||
assert t.dtype == ms.float32
|
assert t.dtype == ms.float32
|
||||||
assert t.shape == (3, 2)
|
assert t.shape == (3, 2)
|
||||||
assert np.allclose(t.asnumpy(), 0)
|
assert np.allclose(t.asnumpy(), 0)
|
||||||
|
|
||||||
t = m[Tensor([1, 2, 3], dtype=ms.int32)]
|
t = m[Tensor([1, 2, 3], dtype=ms.int32)]
|
||||||
assert t.dtype == ms.float32
|
assert t.dtype == ms.float32
|
||||||
assert t.shape == (3, 2)
|
assert t.shape == (3, 2)
|
||||||
assert np.allclose(t.asnumpy(), 0)
|
assert np.allclose(t.asnumpy(), 0)
|
||||||
|
|
||||||
m.put(Tensor([1, 2, 3], dtype=ms.int32), Tensor([[1, 1], [2, 2], [3, 3]], dtype=ms.float32))
|
m.put(Tensor([1, 2, 3], dtype=ms.int32), Tensor([[1, 1], [2, 2], [3, 3]], dtype=ms.float32))
|
||||||
data = m.export_data()
|
data = m.export_data()
|
||||||
print(m)
|
print(m)
|
||||||
print("data:", data)
|
print("data:", data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -265,19 +270,20 @@ def test_export_update_api():
|
||||||
Description: Test export update api for MapParameter.
|
Description: Test export update api for MapParameter.
|
||||||
Expectation: Export update api works as expected.
|
Expectation: Export update api works as expected.
|
||||||
"""
|
"""
|
||||||
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
data1 = m1.export_data(incremental=False)
|
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||||
print("data1:", data1)
|
data1 = m1.export_data(incremental=False)
|
||||||
m1.import_data(data1)
|
print("data1:", data1)
|
||||||
|
m1.import_data(data1)
|
||||||
|
|
||||||
keys = Tensor([1, 2], dtype=ms.int32)
|
keys = Tensor([1, 2], dtype=ms.int32)
|
||||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||||
data2 = m2.export_data(incremental=False)
|
data2 = m2.export_data(incremental=False)
|
||||||
print("data2:", data2)
|
print("data2:", data2)
|
||||||
m1.import_data(data2)
|
m1.import_data(data2)
|
||||||
new_data1 = m1.export_data(incremental=False)
|
new_data1 = m1.export_data(incremental=False)
|
||||||
print("new_data1:", new_data1)
|
print("new_data1:", new_data1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -373,11 +379,12 @@ def test_map_parameter_in_construct():
|
||||||
self.value_tensor = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
self.value_tensor = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
self.new_key_tensor = Tensor([3, 4], dtype=ms.int32)
|
self.new_key_tensor = Tensor([3, 4], dtype=ms.int32)
|
||||||
self.new_value_tensor = Tensor([[3, 3], [4, 4]], dtype=ms.float32)
|
self.new_value_tensor = Tensor([[3, 3], [4, 4]], dtype=ms.float32)
|
||||||
|
self.map_tensor = MapParameter(key_tensor=self.key_tensor, value_tensor=self.value_tensor)
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
new_map_tensor = MapParameter(self.key_tensor, self.value_tensor, self.default_value)
|
new_map_tensor = MapParameter(self.key_tensor, self.value_tensor, self.default_value)
|
||||||
new_map_tensor.put(self.new_key_tensor, self.new_value_tensor)
|
new_map_tensor.put(self.new_key_tensor, self.new_value_tensor)
|
||||||
return new_map_tensor
|
return new_map_tensor, self.map_tensor
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
|
@ -395,12 +402,13 @@ def test_map_parameter_get_data_api():
|
||||||
Description: Test get_data api for MapParameter.
|
Description: Test get_data api for MapParameter.
|
||||||
Expectation: get_data api works as expected.
|
Expectation: get_data api works as expected.
|
||||||
"""
|
"""
|
||||||
keys = Tensor([1, 2], dtype=ms.int32)
|
if not 'SAULT_ENV_TYPE' in os.environ or not "CUDA10" in os.environ['SAULT_ENV_TYPE']:
|
||||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
keys = Tensor([1, 2], dtype=ms.int32)
|
||||||
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||||
[the_keys, the_values] = map_tensor.get_data()
|
map_tensor = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||||
print("the_keys:", the_keys)
|
[the_keys, the_values] = map_tensor.get_data()
|
||||||
print("the_values:", the_values)
|
print("the_keys:", the_keys)
|
||||||
|
print("the_values:", the_values)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
|
Loading…
Reference in New Issue