Source code for neodroidvision.utilities.torch_utilities.distributing.serialisation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Module metadata; the module docstring is assigned explicitly through ``__doc__``.
__author__ = "heider"
__doc__ = r"""
Created on 5/5/22
"""
import pickle
from typing import Any, List
import torch
# Names exported by ``from <module> import *``.
__all__ = ["to_byte_tensor", "serialise_byte_tensor", "deserialise_byte_tensor"]
def to_byte_tensor(data: Any, *, device: str = "cuda") -> torch.ByteTensor:
    """
    Pickle ``data`` and return the resulting bytes as a 1-D ``uint8`` tensor.

    :param data: Any object accepted by ``pickle.dumps``.
    :param device: Device the byte tensor is moved to; defaults to ``"cuda"``.
    :return: Tensor on ``device`` with one element per byte of ``pickle.dumps(data)``.
    """
    # ``torch.frombuffer`` wants a writable buffer, hence the ``bytearray`` copy.
    # This avoids the deprecated ``torch.ByteStorage.from_buffer`` route.
    pickled = bytearray(pickle.dumps(data))
    return torch.frombuffer(pickled, dtype=torch.uint8).to(device)
def serialise_byte_tensor(
    encoded_data: Any, data: Any, *, device: str = "cuda"
) -> None:
    """
    Pickle ``data`` and write it, size-prefixed, into ``encoded_data`` in place.

    Layout written: index ``0`` receives the payload size, indices ``1..size``
    receive the payload bytes.

    :param encoded_data: Buffer supporting index and slice assignment; presumably a
        ``uint8`` tensor on ``device`` with more than ``size`` elements (verify
        against the caller).
    :param data: Any object accepted by ``pickle.dumps``.
    :param device: Device on which the intermediate byte tensor is created.
    :return: None; ``encoded_data`` is modified in place.
    :raises AssertionError: If the pickled payload is larger than 255 bytes.
    """
    payload = to_byte_tensor(data, device=device)
    payload_size = payload.numel()

    # Explicit raise instead of ``assert`` so the limit is still enforced under
    # ``python -O``. AssertionError is kept so existing callers see the same type.
    if payload_size > 255:
        raise AssertionError("Can't encode data greater than 255 bytes")

    encoded_data[0] = payload_size  # size prefix
    encoded_data[1 : payload_size + 1] = payload  # payload bytes
def deserialise_byte_tensor(size_list, tensor_list) -> List:
    """
    Unpickle one object per ``(size, tensor)`` pair.

    Each tensor is copied to the CPU, turned into raw bytes and truncated to its
    first ``size`` bytes before being passed to ``pickle.loads``.

    NOTE(review): bytes are read from offset 0 of every tensor, so callers
    presumably hand in the payload without a leading size element; confirm.
    ``pickle.loads`` must only be fed data from trusted peers.

    :param size_list: Iterable of payload lengths, one per tensor.
    :param tensor_list: Iterable of byte tensors holding pickled payloads.
    :return: List of the unpickled objects, in input order.
    """
    return [
        pickle.loads(raw.cpu().numpy().tobytes()[:length])
        for length, raw in zip(size_list, tensor_list)
    ]