Der DataLoader von PyTorch ( torch.utils.data.Dataloader
) ist bereits ein nützliches Tool zum effizienten Laden und Vorverarbeiten von Daten für das Training von Deep-Learning-Modellen. Standardmäßig verwendet PyTorch einen Single-Worker-Prozess ( num_workers=0
), aber Benutzer können eine höhere Zahl angeben, um Parallelität zu nutzen und das Laden der Daten zu beschleunigen.
Da es sich jedoch um einen Allzweck-Datenlader handelt und er Parallelisierung bietet, ist er für bestimmte benutzerdefinierte Anwendungsfälle nicht geeignet. In diesem Beitrag untersuchen wir, wie wir das Laden mehrerer 2D-Schnitte aus einem Datensatz mit 3D-medizinischen Scans mithilfe von torch.multiprocessing()
beschleunigen können.
torch.utils.data.Dataset
Ich stelle mir einen Anwendungsfall vor, bei dem ein Satz von 3D-Scans für Patienten (z. B. P1, P2, P3, …) und eine Liste der entsprechenden Slices gegeben sind. Unser Ziel ist es, einen Datenlader zu erstellen, der in jeder Iteration ein Slice ausgibt . Sehen Sie sich den Python-Code unten an, in dem wir einen Torch-Datensatz namens myDataset
erstellen und ihn an torch.utils.data.Dataloader()
übergeben.
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])
Das Hauptproblem bei unserem Anwendungsfall ist, dass medizinische 3D-Scans sehr groß sind ( hier emuliert durch die Operation time.sleep()
) und daher
Das Lesen von der Festplatte kann zeitintensiv sein
und ein großer Datensatz von 3D-Scans kann in den meisten Fällen nicht vorab in den Speicher eingelesen werden
Idealerweise sollten wir jeden Patientenscan für alle damit verbundenen Slices nur einmal lesen. Da die Daten jedoch von torch.utils.data.dataloader(myDataset, batch_size=b, workers=n)
je nach Batchgröße auf verschiedene Worker aufgeteilt werden, besteht die Möglichkeit, dass verschiedene Worker einen Patienten zweimal lesen ( siehe Bild und Protokoll unten ).
- [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]
Zusammenfassend sind hier die Probleme mit der bestehenden Implementierung von torch.utils.data.Dataloader
myDataset()
(Ref:
patientSliceList
schleift ( siehe Abbildung unten ), ist kein natürliches Mischen zwischen (patientId, sliceId)-Kombinationen möglich. ( Hinweis: Man kann mischen, aber dazu müssen die Ausgaben im Speicher abgelegt werden .)
Hinweis: Man könnte auch einfach eine Reihe von Schnitten aus den 3D-Scans aller Patienten zusammen zurückgeben. Wenn wir jedoch auch schnittabhängige 3D-Arrays zurückgeben möchten (z. B. interaktive Verfeinerungsnetzwerke ( siehe Abb. 1 dieser Arbeit ), erhöht dies den Speicherbedarf Ihres Datenladers erheblich.
torch.multiprocessing
Um ein mehrfaches Lesen der Patientenscans zu verhindern , wäre es idealerweise erforderlich, dass jeder Patient ( stellen wir uns 8 Patienten vor ) von einem bestimmten Mitarbeiter gelesen wird.
Um dies zu erreichen, verwenden wir die gleichen internen Tools wie die Torch-Dataloader-Klasse (d. h. torch.multiprocessing()
), jedoch mit einem kleinen Unterschied. Sehen Sie sich die Workflow-Abbildung und den Code unten für unseren benutzerdefinierten Dataloader an - myDataloader
# check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()
Das obige Snippet ( mit 8 Patienten statt ) enthält die folgenden Funktionen
__iter__()
– Da myDataloader()
eine Schleife ist, ist dies die Funktion, die tatsächlich durchlaufen wird.
_initWorkers()
- Hier erstellen wir unsere Worker-Prozesse mit ihren individuellen Eingabewarteschlangen workerInputQueues[workerId]
. Dies wird aufgerufen, wenn die Klasse initialisiert wird.
fillInputQueues()
- Diese Funktion wird aufgerufen, wenn wir die Schleife starten ( im Wesentlichen zu Beginn jeder Epoche ). Sie füllt die Eingabewarteschlange des einzelnen Workers auf.
getSlice()
– Dies ist die Hauptlogikfunktion, die einen Ausschnitt aus einem Patientenvolumen zurückgibt. Überprüfen Sie den Code hier .
collate_tensor_fn()
– Diese Funktion wurde direkt aus dem Torch-Repository ( torchv1.12.0) kopiert und wird zum Batch-Zusammenfassen von Daten verwendet.Um zu testen, ob unser Dataloader im Vergleich zur Standardoption eine Beschleunigung bietet, testen wir die Geschwindigkeit jeder Dataloader-Schleife mit unterschiedlichen Worker-Zahlen . In unseren Experimenten haben wir zwei Parameter variiert:
Wir experimentieren zunächst mit unserem Spielzeugdatensatz und stellen fest, dass unser Datenlader viel schneller arbeitet. Siehe die Abbildung unten (oder reproduzieren Sie sie mit diesem Code ).
Hier sehen wir folgendes
patientSlicesList
in unserem Spielzeugdatensatz 5 Slices pro Patient enthält. Der Worker muss also warten, bis er den zweiten Patienten lesen kann, um ihn dem letzten Index des Batches hinzuzufügen. Anschließend führen wir einen Benchmark mit einem realen Datensatz durch, in den 3D-Scans geladen werden, ein Ausschnitt extrahiert wird,
Wir haben festgestellt, dass
Wir haben auch die Ressourcennutzung während des Ladens der Daten mit unterschiedlichen Worker-Zahlen überwacht. Bei einer höheren Anzahl von Workern haben wir eine erhöhte CPU- und Speichernutzung beobachtet, was aufgrund der durch zusätzliche Prozesse eingeführten Parallelität zu erwarten ist. Benutzer sollten bei der Auswahl der optimalen Worker-Zahl ihre Hardwareeinschränkungen und Ressourcenverfügbarkeit berücksichtigen.
In diesem Blogbeitrag haben wir die Einschränkungen des Standard-DataLoaders von PyTorch beim Umgang mit Datensätzen untersucht, die große medizinische 3D-Scans enthalten, und eine benutzerdefinierte Lösung mit torch.multiprocessing
vorgestellt, um die Effizienz des Datenladens zu verbessern.
Im Zusammenhang mit der Schichtextraktion aus diesen medizinischen 3D-Scans kann der standardmäßige dataLoader möglicherweise zu mehreren Lesevorgängen desselben Patientenscans führen, da die Mitarbeiter den Speicher nicht gemeinsam nutzen. Diese Redundanz führt zu erheblichen Verzögerungen, insbesondere bei der Verarbeitung großer Datensätze.
Unser benutzerdefinierter dataLoader teilt Patienten zwischen Mitarbeitern auf und stellt sicher, dass jeder 3D-Scan nur einmal pro Mitarbeiter gelesen wird. Dieser Ansatz verhindert redundante Festplattenlesevorgänge und nutzt die parallele Verarbeitung, um das Laden der Daten zu beschleunigen.
Leistungstests haben gezeigt, dass unser benutzerdefinierter DataLoader den Standard-DataLoader im Allgemeinen übertrifft, insbesondere bei kleineren Batchgrößen und mehreren Arbeitsprozessen.
Unser benutzerdefinierter dataLoader verbessert die Effizienz beim Laden großer medizinischer 3D-Datensätze, indem er redundante Lesevorgänge reduziert und die Parallelität maximiert. Diese Verbesserung kann zu schnelleren Trainingszeiten und einer besseren Nutzung der Hardwareressourcen führen.
Dieser Blog wurde zusammen mit meinem Kollegen Jingnan Jia geschrieben.