# PyTorch bucket DataLoader: batching variable-length sequences with a custom collate function
from torch.nn.utils.rnn import pack_sequence
from torch.utils.data import DataLoader
def my_collate(batch):
    """Collate ``(sequence, target)`` samples into a packed batch.

    Intended as the ``collate_fn`` of a ``DataLoader`` whose samples have
    variable length and therefore cannot be stacked into a single tensor.

    Args:
        batch: list of ``(sequence, target)`` tuples as yielded by the
            dataset. Each ``sequence`` is handed to ``pack_sequence``, which
            expects tensors of shape ``(length, *)``. The dataset is assumed
            to produce such tensors; confirm against the dataset.

    Returns:
        A two-element list ``[data, targets]``. ``data`` is the
        ``PackedSequence`` built from all sequences in the batch. ``targets``
        is a plain list of the targets in the original batch order.
    """
    sequences = [item[0] for item in batch]
    # enforce_sorted=False lets pack_sequence sort by length internally and
    # record the permutation, so samples need not arrive length-sorted.
    data = pack_sequence(sequences, enforce_sorted=False)
    targets = [item[1] for item in batch]
    return [data, targets]
# ...
# Later in your code, when you define your DataLoader, pass the custom collate function.
# NOTE(review): `dataset`, `batch_size` and `shuffle` are placeholders that
# are not defined in the visible code; define them before this statement.
# They are passed by keyword so they cannot bind to the wrong DataLoader
# parameter by position.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=my_collate,  # packs variable-length sequences per batch
    pin_memory=True,  # copy batches into pinned host memory (faster GPU transfer)
)