DataLoader PyTorch ( torch.utils.data.Dataloader
) уже является полезным инструментом для эффективной загрузки и предварительной обработки данных для обучения моделей глубокого обучения. По умолчанию PyTorch использует процесс с одним рабочим процессом ( num_workers=0
), но пользователи могут указать большее число, чтобы использовать параллелизм и ускорить загрузку данных.
Однако, поскольку это загрузчик данных общего назначения и хотя он предлагает распараллеливание, он все равно не подходит для определенных случаев использования. В этом посте мы исследуем, как можно ускорить загрузку нескольких 2D-срезов из набора данных 3D-медицинских сканирований с помощью torch.multiprocessing()
.
torch.utils.data.Dataset
Я представляю себе вариант использования, в котором дан набор 3D-сканов пациентов (т. е. P1, P2, P3,…) и список соответствующих срезов; Наша цель — создать загрузчик данных, который выводит срез на каждой итерации . Проверьте приведенный ниже код Python , где мы создаем набор данных Torch с именем myDataset
и передаем его в torch.utils.data.Dataloader()
.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
Основная проблема нашего варианта использования заключается в том, что медицинские 3D-сканы имеют большой размер ( здесь эмулируется операцией time.sleep()
) и, следовательно,
чтение их с диска может занять много времени
а большой набор данных 3D-сканирований в большинстве случаев невозможно предварительно прочитать в памяти
В идеале мы должны читать каждый скан пациента только один раз для всех связанных с ним срезов. Но поскольку данные разбиваются с помощью torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
на рабочие в зависимости от размера пакета, у разных рабочих есть возможность прочитать данные пациента дважды ( проверьте изображение и журнал). ниже ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Подводя итог, вот проблемы с существующей реализацией torch.utils.data.Dataloader
myDataset()
(ссылка:
patientSliceList
( см. изображение ниже ), естественное перетасовка между комбинациями ( PatientId, SliceId ) невозможна. ( Примечание: можно перетасовать, но это предполагает сохранение выходных данных в памяти )
Примечание. Можно также просто вернуть несколько срезов из 3D-сканирования каждого пациента. Но если мы хотим также возвращать слайс-зависимые 3D-массивы (например, интерактивные сети уточнения ( см. рис.1 этой работы ), то это сильно увеличивает объем памяти вашего загрузчика данных.
torch.multiprocessing
Чтобы предотвратить многократное чтение сканов пациентов , в идеале нам нужно, чтобы каждого пациента ( представим себе 8 пациентов ) считывал конкретный работник.
Для этого мы используем те же внутренние инструменты, что и класс загрузчика данных torch (т. е. torch.multiprocessing()
), но с небольшой разницей. Проверьте рисунок рабочего процесса и код ниже для нашего пользовательского загрузчика данных — myDataloader
.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
Приведенный выше фрагмент ( с 8 пациентами вместо этого ) содержит следующие функции:
__iter__()
— поскольку myDataloader()
— это цикл, именно эту функцию он фактически выполняет в цикле.
_initWorkers()
— здесь мы создаем наши рабочие процессы с их отдельными входными очередями workerInputQueues[workerId]
. Это вызывается при инициализации класса.
fillInputQueues()
— эта функция вызывается, когда мы начинаем цикл ( по сути, в начале каждой эпохи ). Он заполняет входную очередь отдельного работника.
getSlice()
— это основная логическая функция, которая возвращает срез из тома пациента. Проверьте код здесь .
collate_tensor_fn()
— эта функция напрямую скопирована из репозитория Torch — torchv1.12.0 и используется для пакетной обработки данных.Чтобы проверить, обеспечивает ли наш загрузчик данных ускорение по сравнению с опцией по умолчанию, мы проверяем скорость каждого цикла загрузчика данных , используя различные счетчики рабочих процессов . В наших экспериментах мы варьировали два параметра:
Сначала мы поэкспериментируем с нашим игрушечным набором данных и увидим, что наш загрузчик данных работает намного быстрее. См. рисунок ниже (или воспроизведите его с помощью этого кода ).
Здесь мы можем увидеть следующее
patientSlicesList
содержит 5 срезов на каждого пациента. Таким образом, работнику нужно подождать, чтобы прочитать второго пациента, чтобы добавить его к последнему индексу пакета. Затем мы сравниваем реальный набор данных, в который загружаются 3D-сканы, извлекается срез,
Мы заметили, что
Мы также отслеживали использование ресурсов во время загрузки данных с различным количеством рабочих. При большем количестве рабочих мы наблюдали увеличение использования ЦП и памяти, что ожидаемо из-за параллелизма, введенного дополнительными процессами. Пользователи должны учитывать свои аппаратные ограничения и доступность ресурсов при выборе оптимального количества рабочих.
В этом сообщении блога мы рассмотрели ограничения стандартного DataLoader PyTorch при работе с наборами данных, содержащими большие медицинские 3D-сканы, и представили специальное решение, использующее torch.multiprocessing
для повышения эффективности загрузки данных.
В контексте извлечения срезов из этих медицинских 3D-сканирований загрузчик данных по умолчанию потенциально может привести к многократному чтению одного и того же скана пациента, поскольку рабочие не используют общую память. Эта избыточность вызывает значительные задержки, особенно при работе с большими наборами данных.
Наш специальный загрузчик данных разделяет пациентов между работниками, гарантируя, что каждое 3D-сканирование считывается только один раз для каждого работника. Этот подход предотвращает избыточное чтение с диска и использует параллельную обработку для ускорения загрузки данных.
Тестирование производительности показало, что наш пользовательский dataLoader обычно превосходит стандартный dataLoader, особенно при небольших размерах пакетов и нескольких рабочих процессах.
Наш специальный dataLoader повышает эффективность загрузки данных для больших наборов трехмерных медицинских данных за счет сокращения избыточного чтения и максимального увеличения параллелизма. Это улучшение может привести к сокращению времени обучения и лучшему использованию аппаратных ресурсов.
Этот блог был написан вместе с моим коллегой Цзиннан Цзя .