paint-brush
Cách cải thiện tính năng song song hóa của bộ tải dữ liệu Torch bằng Torch.multiprocessingtừ tác giả@pixelperfectionist
540 lượt đọc
540 lượt đọc

Cách cải thiện tính năng song song hóa của bộ tải dữ liệu Torch bằng Torch.multiprocessing

từ tác giả Prerak Mody13m2024/06/10
Read on Terminal Reader

dài quá đọc không nổi

Trình tải dữ liệu PyTorch là một công cụ giúp tải và xử lý trước dữ liệu một cách hiệu quả để đào tạo các mô hình học sâu. Trong bài đăng này, chúng tôi khám phá cách chúng tôi có thể tăng tốc quá trình này bằng cách sử dụng trình tải dữ liệu tùy chỉnh cùng với torch.multiprocessing. Chúng tôi thử nghiệm tải nhiều lát 2D từ tập dữ liệu quét y tế 3D.
featured image - Cách cải thiện tính năng song song hóa của bộ tải dữ liệu Torch bằng Torch.multiprocessing
Prerak Mody HackerNoon profile picture
0-item

Giới thiệu

DataLoader của PyTorch ( torch.utils.data.Dataloader ) đã là một công cụ hữu ích để tải và xử lý trước dữ liệu một cách hiệu quả nhằm đào tạo các mô hình học sâu. Theo mặc định, PyTorch sử dụng quy trình một công nhân ( num_workers=0 ), nhưng người dùng có thể chỉ định số cao hơn để tận dụng tính song song và tăng tốc độ tải dữ liệu.


Tuy nhiên, vì đây là trình tải dữ liệu có mục đích chung và mặc dù cung cấp tính năng song song nhưng nó vẫn không phù hợp với một số trường hợp sử dụng tùy chỉnh nhất định. Trong bài đăng này, chúng tôi khám phá cách chúng tôi có thể tăng tốc độ tải nhiều lát 2D từ tập dữ liệu quét y tế 3D bằng cách sử dụng torch.multiprocessing() .


Chúng tôi muốn trích xuất một tập hợp các lát cắt từ bản quét 3D của mỗi bệnh nhân. Những bệnh nhân này là một phần của một tập dữ liệu lớn.



torch.utils.data.Dataset của chúng tôi

Tôi tưởng tượng một trường hợp sử dụng trong đó cung cấp một bộ ảnh quét 3D cho bệnh nhân (tức là P1, P2, P3, …) và danh sách các lát cắt tương ứng; Mục tiêu của chúng tôi là xây dựng một trình tải dữ liệu tạo ra một lát cắt trong mỗi lần lặp . Hãy kiểm tra mã Python bên dưới nơi chúng tôi xây dựng tập dữ liệu ngọn đuốc có tên myDataset và chuyển nó vào torch.utils.data.Dataloader() .


 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-torchdataloader-py import tqdm import time import torch # v1.12.1 import numpy as np ################################################## # myDataset ################################################## def getPatientArray(patientName): # return patients 3D scan def getPatientSliceArray(patientName, sliceId, patientArray=None): # return patientArray and a slice class myDataset(torch.utils.data.Dataset): def __init__(self, patientSlicesList, patientsInMemory=1): ... self.patientObj = {} # To store one patients 3D array. More patients lead to more memory usage. def _managePatientObj(self, patientName): if len(self.patientObj) > self.patientsInMemory: self.patientObj.pop(list(self.patientObj.keys())[0]) def __getitem__(self, idx): # Step 0 - Init patientName, sliceId = ... # Step 1 - Get patient slice array patientArrayThis = self.patientObj.get(patientName, None) patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis) if patientArray is not None: self.patientObj[patientName] = patientArray self._managePatientObj(patientName) return patientSliceArray, [patientName, sliceId] ################################################## # Main ################################################## if __name__ == '__main__': # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2.1 - Create dataset and dataloader dataset = myDataset(patientSlicesList) dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, num_workers=4) # Step 2.2 - Iterate over dataloader print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs)) for i, (patientSliceArray, meta) in enumerate(dataloader): print (' - [main] meta: ', meta) pbar.update(patientSliceArray.shape[0])


Mối quan tâm chính trong trường hợp sử dụng của chúng tôi là các bản quét y tế 3D có kích thước lớn ( được mô phỏng ở đây bằng thao tác time.sleep() ) và do đó

  • đọc chúng từ đĩa có thể tốn nhiều thời gian

  • và một tập dữ liệu lớn về quét 3D trong hầu hết các trường hợp không thể đọc trước vào bộ nhớ


Lý tưởng nhất là chúng ta chỉ nên đọc mỗi bản quét của bệnh nhân một lần cho tất cả các lát cắt liên quan đến nó. Nhưng vì dữ liệu được chia bởi torch.utils.data.dataloader(myDataset, batch_size=b, workers=n) thành các công nhân tùy thuộc vào kích thước lô, nên có khả năng các công nhân khác nhau sẽ đọc bệnh nhân hai lần ( kiểm tra hình ảnh và nhật ký dưới ).

Torch chia việc tải tập dữ liệu vào từng công nhân tùy thuộc vào kích thước lô (= 3, trong trường hợp này). Do đó, mỗi bệnh nhân được đọc bởi nhiều công nhân.


 - [main] Iterating over (my) dataloader... - [main] --------------------------------------- Epoch 1/3 - [getPatientArray()][worker=3] Loading volumes for patient: P2 - [getPatientArray()][worker=1] Loading volumes for patient: P1 - [getPatientArray()][worker=2] Loading volumes for patient: P2 - [getPatientArray()][worker=0] Loading volumes for patient: P1 - [getPatientArray()][worker=3] Loading volumes for patient: P3 - [main] meta: [('P1', 'P1', 'P1'), tensor([45, 62, 32])] - [getPatientArray()][worker=1] Loading volumes for patient: P2 - [main] meta: [('P1', 'P1', 'P2'), tensor([21, 69, 13])] - [main] meta: [('P2', 'P2', 'P2'), tensor([23, 87, 54])] - [main] meta: [('P2', 'P3', 'P3'), tensor([ 5, 34, 56])] - [getPatientArray()][worker=2] Loading volumes for patient: P4 - [getPatientArray()][worker=0] Loading volumes for patient: P3 - [getPatientArray()][worker=1] Loading volumes for patient: P4 - [main] meta: [('P3', 'P3', 'P3'), tensor([78, 90, 12])] - [main] meta: [('P4', 'P4', 'P4'), tensor([34, 56, 78])] - [main] meta: [('P4', 'P4'), tensor([90, 12])]


Tóm lại, đây là các vấn đề với việc triển khai torch.utils.data.Dataloader hiện có

  • Mỗi công nhân được chuyển một bản sao của myDataset() (Tham khảo: ngọn đuốc v1.2. 0 ) và vì họ không có bất kỳ bộ nhớ chung nào nên điều này dẫn đến việc đọc bản quét 3D của bệnh nhân trên đĩa đôi.


  • Hơn nữa, vì ngọn đuốc tuần tự lặp lại trên patientSliceList ( xem hình ảnh bên dưới ), nên không thể xáo trộn tự nhiên giữa các tổ hợp ( PatientId, sliceId). ( Lưu ý: người ta có thể xáo trộn, nhưng điều đó liên quan đến việc lưu trữ kết quả đầu ra trong bộ nhớ )


Torch.utils.data.Dataloader() tiêu chuẩn có một hàng đợi nội bộ quản lý toàn cầu cách trích xuất đầu ra từ các công nhân. Ngay cả khi dữ liệu đã sẵn sàng bởi một nhân viên cụ thể, nó không thể xuất dữ liệu đó vì nó phải tôn trọng hàng đợi toàn cầu này.



Lưu ý: Người ta cũng có thể chỉ cần trả lại một loạt các lát cắt từ bản quét 3D của mỗi bệnh nhân. Nhưng nếu chúng ta cũng muốn trả về các mảng 3D phụ thuộc vào lát cắt (ví dụ: các mạng sàng lọc tương tác ( xem Hình 1 của tác phẩm này ), thì điều này sẽ làm tăng đáng kể dung lượng bộ nhớ của bộ nạp dữ liệu của bạn.



Sử dụng torch.multiprocessing

Để ngăn chặn nhiều lần đọc bản quét bệnh nhân , lý tưởng nhất là chúng tôi cần mỗi bệnh nhân ( hãy tưởng tượng 8 bệnh nhân ) được đọc bởi một nhân viên cụ thể.

Ở đây, mỗi nhân viên tập trung vào việc đọc (các) bệnh nhân.


Để đạt được điều này, chúng tôi sử dụng các công cụ nội bộ tương tự như lớp trình tải dữ liệu torch (tức là torch.multiprocessing() ) nhưng có một chút khác biệt. Kiểm tra hình và quy trình làm việc bên dưới để biết trình tải dữ liệu tùy chỉnh của chúng tôi - myDataloader

Ở đây, hàng đợi đầu ra (dưới cùng) chứa đầu ra của mỗi công nhân. Mỗi nhân viên chỉ nhận được thông tin đầu vào (hàng đợi đầu vào hiển thị ở trên cùng) cho một nhóm bệnh nhân cụ thể. Do đó, điều này ngăn cản việc đọc nhiều lần bản quét 3D của bệnh nhân.



 # check full code here: https://gist.github.com/prerakmody/0c5e9263d42b2fab26a48dfb6b818cca#file-mydataloader-py class myDataloader: def __init__(self, patientSlicesList, numWorkers, batchSize) -> None: ... self._initWorkers() def _initWorkers(self): # Step 1 - Initialize vas self.workerProcesses = [] self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)] self.workerOutputQueue = torchMP.Queue() for workerId in range(self.numWorkers): p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue)) p.start() def fillInputQueues(self): """ This function allows to split patients and slices across workers. One can implement custom logic here. """ patientNames = list(self.patientSlicesList.keys()) for workerId in range(self.numWorkers): idxs = ... for patientName in patientNames[idxs]: for sliceId in self.patientSlicesList[patientName]: self.workerInputQueues[workerId].put((patientName, sliceId)) def emptyAllQueues(self): # empties the self.workerInputQueues and self.workerOutputQueue def __iter__(self): try: # Step 0 - Init self.fillInputQueues() # once for each epoch batchArray, batchMeta = [], [] # Step 1 - Continuously yield results while True: if not self.workerOutputQueue.empty(): # Step 2.1 - Get data point patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT) # Step 2.2 - Append to batch ... # Step 2.3 - Yield batch if len(batchArray) == self.batchSize: batchArray = collate_tensor_fn(batchArray) yield batchArray, batchMeta batchArray, batchMeta = [], [] # Step 3 - End condition if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty(): break except GeneratorExit: self.emptyAllQueues() except KeyboardInterrupt: self.closeProcesses() except: traceback.print_exc() def closeProcesses(self): pass if __name__ == "__main__": # Step 1 - Setup patient slices (fixed count of slices per patient) patientSlicesList = { 'P1': [45, 62, 32, 21, 69] , 'P2': [13, 23, 87, 54, 5] , 'P3': [34, 56, 78, 90, 12] , 'P4': [34, 56, 78, 90, 12] , 'P5': [45, 62, 32, 21, 69] , 'P6': [13, 23, 87, 54, 5] , 'P7': [34, 56, 78, 90, 12] , 'P8': [34, 56, 78, 90, 12, 21] } workerCount, batchSize, epochs = 4, 1, 3 # Step 2 - Create new dataloader dataloaderNew = None try: dataloaderNew = myDataloader(patientSlicesList, numWorkers=workerCount, batchSize=batchSize) print ('\n - [main] Iterating over (my) dataloader...') for epochId in range(epochs): with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar: for i, (X, meta) in enumerate(dataloaderNew): print (' - [main] {}'.format(meta.tolist())) pbar.update(X.shape[0]) dataloaderNew.closeProcesses() except KeyboardInterrupt: if dataloader is not None: dataloader.closeProcesses() except: traceback.print_exc() if dataloaderNew is not None: dataloaderNew.closeProcesses()


Đoạn mã trên ( thay vào đó là 8 bệnh nhân ) chứa các chức năng sau

  • __iter__() - Vì myDataloader() là một vòng lặp nên đây thực sự là hàm mà nó lặp lại.


  • _initWorkers() - Tại đây, chúng tôi tạo các quy trình công nhân của mình với hàng đợi đầu vào riêng lẻ workerInputQueues[workerId] . Điều này được gọi khi lớp được khởi tạo.


  • fillInputQueues() - Hàm này được gọi khi chúng ta bắt đầu vòng lặp ( về cơ bản là ở đầu mỗi kỷ nguyên ). Nó lấp đầy hàng đợi đầu vào của từng công nhân.


  • getSlice() - Đây là hàm logic chính trả về một lát cắt từ ổ đĩa bệnh nhân. Kiểm tra mã ở đây .


  • collate_tensor_fn() - Hàm này được sao chép trực tiếp từ kho lưu trữ torchv1.12.0 và được sử dụng để gộp dữ liệu lại với nhau.


Hiệu suất

Để kiểm tra xem trình tải dữ liệu của chúng tôi có tăng tốc so với tùy chọn mặc định hay không, chúng tôi kiểm tra tốc độ của từng vòng lặp của trình tải dữ liệu bằng cách sử dụng số lượng nhân viên khác nhau . Chúng tôi đã thay đổi hai tham số trong thử nghiệm của mình:


  • Số lượng công nhân : Chúng tôi đã thử nghiệm các quy trình 1, 2, 4 và 8 công nhân.
  • Kích thước lô : Chúng tôi đã đánh giá các kích cỡ lô khác nhau từ 1 đến 8.

Bộ dữ liệu đồ chơi

Trước tiên, chúng tôi thử nghiệm với tập dữ liệu đồ chơi của mình và thấy rằng trình tải dữ liệu của chúng tôi hoạt động nhanh hơn nhiều. Xem hình bên dưới (hoặc sao chép bằng mã này )
Tổng thời gian thấp hơn và số lần lặp/giây cao hơn có nghĩa là trình tải dữ liệu tốt hơn.

Ở đây chúng ta có thể thấy như sau

  • Khi sử dụng một trình chạy duy nhất, cả hai trình tải dữ liệu đều giống nhau.


  • Khi sử dụng các công cụ bổ sung (tức là 2,4,8), cả hai trình tải dữ liệu đều tăng tốc, tuy nhiên, tốc độ tăng tốc cao hơn nhiều trong trình tải dữ liệu tùy chỉnh của chúng tôi.


  • Khi sử dụng kích thước lô 6 (so với 1,2,3,4), hiệu suất sẽ bị ảnh hưởng nhỏ. Điều này là do trong tập dữ liệu đồ chơi của chúng tôi, biến patientSlicesList chứa 5 lát cho mỗi bệnh nhân. Vì vậy, công nhân cần đợi để đọc bệnh nhân thứ hai để thêm vào chỉ mục cuối cùng của lô.

Bộ dữ liệu thế giới thực

Sau đó, chúng tôi đánh giá một tập dữ liệu thực nơi tải bản quét 3D, một lát cắt được trích xuất, một số tiền xử lý bổ sung được thực hiện , sau đó lát cắt và các mảng khác được trả về. Xem hình dưới đây để biết kết quả.


Chúng tôi quan sát thấy rằng tăng số lượng quy trình công nhân (và quy mô lô) thường dẫn đến tải dữ liệu nhanh hơn và do đó có thể dẫn đến việc đào tạo nhanh hơn. Đối với quy mô lô nhỏ hơn (ví dụ: 1 hoặc 2), việc tăng gấp đôi số lượng công nhân sẽ dẫn đến tốc độ nhanh hơn nhiều. Tuy nhiên, khi quy mô lô tăng lên, mức cải thiện cận biên từ việc bổ sung thêm công nhân sẽ giảm đi.

Số lần lặp/giây càng cao thì trình tải dữ liệu càng nhanh.

Tận dụng nguồn tài nguyên

Chúng tôi cũng giám sát việc sử dụng tài nguyên trong quá trình tải dữ liệu với số lượng nhân viên khác nhau. Với số lượng công nhân cao hơn, chúng tôi nhận thấy mức sử dụng CPU và bộ nhớ tăng lên, điều này được dự đoán là do tính song song do các quy trình bổ sung mang lại. Người dùng nên xem xét các hạn chế về phần cứng và tính sẵn có của tài nguyên khi chọn số lượng nhân viên tối ưu.

Bản tóm tắt

  1. Trong bài đăng trên blog này, chúng tôi đã khám phá những hạn chế của DataLoader tiêu chuẩn của PyTorch khi xử lý các tập dữ liệu chứa các bản quét y tế 3D lớn và trình bày một giải pháp tùy chỉnh bằng cách sử dụng torch.multiprocessing để cải thiện hiệu quả tải dữ liệu.


  2. Trong bối cảnh trích xuất lát cắt từ các bản quét y tế 3D này, Trình tải dữ liệu mặc định có thể dẫn đến nhiều lần đọc cùng một bản quét bệnh nhân do nhân viên không chia sẻ bộ nhớ. Sự dư thừa này gây ra sự chậm trễ đáng kể, đặc biệt khi xử lý các tập dữ liệu lớn.


  3. Trình tải dữ liệu tùy chỉnh của chúng tôi phân chia bệnh nhân giữa các nhân viên, đảm bảo rằng mỗi lần quét 3D chỉ được đọc một lần cho mỗi nhân viên. Cách tiếp cận này ngăn chặn việc đọc đĩa dư thừa và tận dụng quá trình xử lý song song để tăng tốc độ tải dữ liệu.


  4. Kiểm tra hiệu suất cho thấy rằng dataLoader tùy chỉnh của chúng tôi thường hoạt động tốt hơn dataLoader tiêu chuẩn, đặc biệt với kích thước lô nhỏ hơn và nhiều quy trình công nhân.


    1. Tuy nhiên, hiệu suất tăng giảm khi kích thước lô lớn hơn.


Trình tải dữ liệu tùy chỉnh của chúng tôi nâng cao hiệu quả tải dữ liệu cho các bộ dữ liệu y tế 3D lớn bằng cách giảm số lần đọc dư thừa và tối đa hóa tính song song. Cải tiến này có thể dẫn đến thời gian đào tạo nhanh hơn và sử dụng tài nguyên phần cứng tốt hơn.


Blog này được viết cùng với đồng nghiệp của tôi Jingnan Jia .