# Concat dataset: pair up samples from several datasets by index.
class ConcatDataset(torch.utils.data.Dataset):
    """Pair several datasets index-by-index.

    Item ``i`` is a tuple holding item ``i`` of every wrapped dataset, in the
    order the datasets were passed to the constructor.  The length is that of
    the shortest wrapped dataset, so the trailing items of longer datasets
    are never produced.

    Note: despite the name, nothing is appended end-to-end here.  This is
    different from ``torch.utils.data.ConcatDataset``, which chains datasets
    one after another.

    Args:
        *datasets: Any number of indexable, sized datasets.
    """

    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets

    def __getitem__(self, i):
        """Return a tuple with item ``i`` of each wrapped dataset.

        Integer indices are resolved against ``len(self)`` (the shortest
        dataset) so that negative indices select the same position in every
        dataset.  Non-integer indices are forwarded untouched.

        Raises:
            IndexError: If an integer index is outside ``[-len, len)``.
        """
        if isinstance(i, int):
            length = len(self)
            requested = i
            if i < 0:
                # Resolve relative to the paired length; letting each dataset
                # resolve a negative index on its own would misalign the items
                # whenever the datasets differ in length.
                i += length
            if not 0 <= i < length:
                raise IndexError(
                    f"index {requested} is out of range for a dataset "
                    f"of length {length}"
                )
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        """Return the length of the shortest wrapped dataset (0 if none)."""
        return min((len(d) for d in self.datasets), default=0)
# Build a single loader over index-aligned pairs drawn from the two image
# folders: every dataset item is (item_from_A, item_from_B).
# NOTE(review): no transform is passed to ImageFolder here -- confirm the
# samples it yields can be batched by the loader's default collate function.
train_loader = torch.utils.data.DataLoader(
    dataset=ConcatDataset(
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B),
    ),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)
# Every batch carries one entry per wrapped dataset.  Assuming each wrapped
# dataset yields (sample, label) items -- TODO confirm -- each entry is itself
# a collated (samples, labels) pair, so unpack both pairs explicitly instead
# of binding a whole pair to a single name.  This also avoids shadowing the
# builtin ``input``.
for i, ((input_a, target_a), (input_b, target_b)) in enumerate(train_loader):
    ...