paint-brush
Как улучшить распараллеливание загрузчиков данных Torch с помощью Torch.multiprocessingк@pixelperfectionist
540 чтения
540 чтения

Как улучшить распараллеливание загрузчиков данных Torch с помощью Torch.multiprocessing

к Prerak Mody13m2024/06/10
Read on Terminal Reader

Слишком долго; Читать

Загрузчик данных PyTorch — это инструмент для эффективной загрузки и предварительной обработки данных для обучения моделей глубокого обучения. В этом посте мы рассмотрим, как можно ускорить этот процесс, используя наш собственный загрузчик данных вместе с torch.multiprocessing. Мы экспериментируем с загрузкой нескольких 2D-срезов из набора данных 3D-медицинских сканирований.
featured image - Как улучшить распараллеливание загрузчиков данных Torch с помощью Torch.multiprocessing
Prerak Mody HackerNoon profile picture
0-item

Введение

DataLoader PyTorch ( torch.utils.data.Dataloader ) уже является полезным инструментом для эффективной загрузки и предварительной обработки данных для обучения моделей глубокого обучения. По умолчанию PyTorch использует процесс с одним рабочим процессом ( num_workers=0 ), но пользователи могут указать большее число, чтобы использовать параллелизм и ускорить загрузку данных.


Однако, поскольку это загрузчик данных общего назначения и хотя он предлагает распараллеливание, он все равно не подходит для определенных случаев использования. В этом посте мы исследуем, как можно ускорить загрузку нескольких 2D-срезов из набора данных 3D-медицинских сканирований с помощью torch.multiprocessing() .


Мы хотим извлечь набор срезов из 3D-сканирования каждого пациента. Эти пациенты являются частью большого набора данных.



Наш torch.utils.data.Dataset

Я представляю себе вариант использования, в котором дан набор 3D-сканов пациентов (т. е. P1, P2, P3,…) и список соответствующих срезов; Наша цель — создать загрузчик данных, который выводит срез на каждой итерации . Проверьте приведенный ниже код Python , где мы создаем набор данных Torch с именем myDataset и передаем его в torch.utils.data.Dataloader() .


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


Основная проблема нашего варианта использования заключается в том, что медицинские 3D-сканы имеют большой размер ( здесь эмулируется операцией time.sleep() ) и, следовательно,

  • чтение их с диска может занять много времени

  • а большой набор данных 3D-сканирований в большинстве случаев невозможно предварительно прочитать в памяти


В идеале мы должны читать каждый скан пациента только один раз для всех связанных с ним срезов. Но поскольку данные разбиваются с помощью torch.utils.data.dataloader(myDataset, batch_size=b, workers=n) на рабочие в зависимости от размера пакета, у разных рабочих есть возможность прочитать данные пациента дважды ( проверьте изображение и журнал). ниже ).

Torch разделяет загрузку набора данных на каждого работника в зависимости от размера пакета (в данном случае = 3). Благодаря этому каждого пациента осматривают несколько работников.


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


Подводя итог, вот проблемы с существующей реализацией torch.utils.data.Dataloader

  • Каждому из рабочих передается копия myDataset() (ссылка: фонарик v1.2. 0 ), а поскольку у них нет общей памяти, это приводит к двойному чтению 3D-скана пациента с диска.


  • Более того, поскольку факел последовательно проходит по patientSliceList ( см. изображение ниже ), естественное перетасовка между комбинациями ( PatientId, SliceId ) невозможна. ( Примечание: можно перетасовать, но это предполагает сохранение выходных данных в памяти )


Стандартный torch.utils.data.Dataloader() имеет внутреннюю очередь, которая глобально управляет тем, как выходные данные извлекаются из рабочих процессов. Даже если данные готовы конкретным исполнителем, он не может их вывести, поскольку он должен учитывать эту глобальную очередь.



Примечание. Можно также просто вернуть несколько срезов из 3D-сканирования каждого пациента. Но если мы хотим также возвращать слайс-зависимые 3D-массивы (например, интерактивные сети уточнения ( см. рис.1 этой работы ), то это сильно увеличивает объем памяти вашего загрузчика данных.



Использование torch.multiprocessing

Чтобы предотвратить многократное чтение сканов пациентов , в идеале нам нужно, чтобы каждого пациента ( представим себе 8 пациентов ) считывал конкретный работник.

Здесь каждый работник сосредоточен на чтении пациентов (группы пациентов).


Для этого мы используем те же внутренние инструменты, что и класс загрузчика данных torch (т. е. torch.multiprocessing() ), но с небольшой разницей. Проверьте рисунок рабочего процесса и код ниже для нашего пользовательского загрузчика данных — myDataloader .

Здесь очередь вывода (внизу) содержит выходные данные каждого работника. Каждый работник получает входную информацию (очередь ввода показана вверху) только для определенного набора пациентов. Таким образом, это предотвращает многократное считывание 3D-скана пациента.



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


Приведенный выше фрагмент ( с 8 пациентами вместо этого ) содержит следующие функции:

  • __iter__() — поскольку myDataloader() — это цикл, именно эту функцию он фактически выполняет в цикле.


  • _initWorkers() — здесь мы создаем наши рабочие процессы с их отдельными входными очередями workerInputQueues[workerId] . Это вызывается при инициализации класса.


  • fillInputQueues() — эта функция вызывается, когда мы начинаем цикл ( по сути, в начале каждой эпохи ). Он заполняет входную очередь отдельного работника.


  • getSlice() — это основная логическая функция, которая возвращает срез из тома пациента. Проверьте код здесь .


  • collate_tensor_fn() — эта функция напрямую скопирована из репозитория Torch — torchv1.12.0 и используется для пакетной обработки данных.


Производительность

Чтобы проверить, обеспечивает ли наш загрузчик данных ускорение по сравнению с опцией по умолчанию, мы проверяем скорость каждого цикла загрузчика данных , используя различные счетчики рабочих процессов . В наших экспериментах мы варьировали два параметра:


  • Количество работников : Мы протестировали 1, 2, 4 и 8 рабочих процессов.
  • Размер партии : Мы оценили различные размеры партий от 1 до 8.

Набор данных игрушек

Сначала мы поэкспериментируем с нашим игрушечным набором данных и увидим, что наш загрузчик данных работает намного быстрее. См. рисунок ниже (или воспроизведите его с помощью этого кода ).
Меньшее общее время и большее количество итераций в секунду означает лучший загрузчик данных.

Здесь мы можем увидеть следующее

  • При использовании одного работника оба загрузчика данных одинаковы.


  • При использовании дополнительных воркеров (т.е. 2,4,8) наблюдается ускорение в обоих загрузчиках данных, однако в нашем пользовательском загрузчике данных ускорение намного выше.


  • При использовании размера пакета 6 (по сравнению с 1,2,3,4) производительность немного снижается. Это связано с тем, что в нашем игрушечном наборе данных переменная patientSlicesList содержит 5 срезов на каждого пациента. Таким образом, работнику нужно подождать, чтобы прочитать второго пациента, чтобы добавить его к последнему индексу пакета.

Реальный набор данных

Затем мы сравниваем реальный набор данных, в который загружаются 3D-сканы, извлекается срез, выполняется некоторая дополнительная предварительная обработка , а затем возвращаются срез и другие массивы. Результаты смотрите на рисунке ниже.


Мы заметили, что увеличение количества рабочих процессов (и размеров пакетов) обычно приводило к более быстрой загрузке данных и, следовательно, может привести к более быстрому обучению. Для меньших размеров пакетов (например, 1 или 2) удвоение числа рабочих приводило к гораздо большему ускорению. Однако по мере увеличения размера партии незначительное улучшение от добавления большего количества рабочих уменьшалось.

Чем выше количество итераций в секунду, тем быстрее загрузчик данных.

Использование ресурсов

Мы также отслеживали использование ресурсов во время загрузки данных с различным количеством рабочих. При большем количестве рабочих мы наблюдали увеличение использования ЦП и памяти, что ожидаемо из-за параллелизма, введенного дополнительными процессами. Пользователи должны учитывать свои аппаратные ограничения и доступность ресурсов при выборе оптимального количества рабочих.

Краткое содержание

  1. В этом сообщении блога мы рассмотрели ограничения стандартного DataLoader PyTorch при работе с наборами данных, содержащими большие медицинские 3D-сканы, и представили специальное решение, использующее torch.multiprocessing для повышения эффективности загрузки данных.


  2. В контексте извлечения срезов из этих медицинских 3D-сканирований загрузчик данных по умолчанию потенциально может привести к многократному чтению одного и того же скана пациента, поскольку рабочие не используют общую память. Эта избыточность вызывает значительные задержки, особенно при работе с большими наборами данных.


  3. Наш специальный загрузчик данных разделяет пациентов между работниками, гарантируя, что каждое 3D-сканирование считывается только один раз для каждого работника. Этот подход предотвращает избыточное чтение с диска и использует параллельную обработку для ускорения загрузки данных.


  4. Тестирование производительности показало, что наш пользовательский dataLoader обычно превосходит стандартный dataLoader, особенно при небольших размерах пакетов и нескольких рабочих процессах.


    1. Однако прирост производительности уменьшался при увеличении размера партии.


Наш специальный dataLoader повышает эффективность загрузки данных для больших наборов трехмерных медицинских данных за счет сокращения избыточного чтения и максимального увеличения параллелизма. Это улучшение может привести к сокращению времени обучения и лучшему использованию аппаратных ресурсов.


Этот блог был написан вместе с моим коллегой Цзиннан Цзя .