Parallel Merge Sort with Fork/Join Framework


by Alexander Makeev, September 17th, 2021

Too Long; Didn't Read

ForkJoinPool is one of the ExecutorService implementations for running tasks in parallel. It was designed to simplify parallelism for recursive tasks by breaking a single task into independent subtasks until they are small enough to be executed asynchronously.

In this article, I will show you how to use ForkJoinPool, which is still not widely used by Java developers.


ForkJoinPool is one of the ExecutorService implementations. CompletableFuture's async methods and parallel streams run on its common pool by default. It was designed to simplify parallelism for recursive tasks by breaking a single task into independent subtasks until they are small enough to be executed asynchronously. With this class, you can process a large number of tasks with a small number of threads.
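To see the common pool at work, here is a small sketch (the class and method names are hypothetical). According to the CompletableFuture documentation, async methods without an explicit Executor run in ForkJoinPool.commonPool(), unless its parallelism is below two, in which case a new thread is created for each task. In current JDK implementations, parallel streams also split their work into fork/join tasks on the common pool:

```java
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;

class CommonPoolUsers {

    // Name of the thread that ran a supplyAsync task with no explicit Executor.
    static String asyncThreadName() {
        return CompletableFuture.supplyAsync(() -> Thread.currentThread().getName()).join();
    }

    // A parallel stream pipeline; its work is split into fork/join tasks.
    static long parallelSum(int n) {
        return IntStream.rangeClosed(1, n).parallel().asLongStream().sum();
    }

    public static void main(String[] args) {
        // Typically "ForkJoinPool.commonPool-worker-N" when the common pool parallelism is at least 2.
        System.out.println("supplyAsync ran on: " + asyncThreadName());
        System.out.println("parallel sum: " + parallelSum(1_000)); // parallel sum: 500500
    }
}
```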


Java provides a shared common pool, which you can obtain through the static commonPool() method:


ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
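Besides using the shared common pool, you can create a dedicated pool with an explicit parallelism level, for example to keep long-running work away from the pool that parallel streams and CompletableFuture use by default. A minimal sketch; the parallelism of 4 and the AnswerTask class are arbitrary placeholders:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class CustomPoolDemo {

    // Trivial placeholder task, used only to show submission to a dedicated pool.
    static class AnswerTask extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            return 42;
        }
    }

    static int runOnDedicatedPool(int parallelism) {
        // A pool with an explicit number of worker threads (the value is an assumption).
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            return pool.invoke(new AnswerTask());
        } finally {
            // A pool you create yourself should be shut down when no longer needed;
            // calling shutdown() on the common pool has no effect.
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("result: " + runOnDedicatedPool(4)); // result: 42
    }
}
```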

ForkJoinTask

RecursiveAction and RecursiveTask are the fork/join counterparts of Runnable and Callable, respectively. Both extend the abstract ForkJoinTask class, and each declares an abstract compute() method that subclasses must implement: RecursiveTask<T>#compute() returns a value of the generic type T, while RecursiveAction#compute() returns void.
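To make the two compute() signatures concrete, here is a minimal sketch with two hypothetical tasks that do not split at all: a RecursiveTask that returns the maximum of an array and a RecursiveAction that fills an array in place. In the JDK, compute() is declared protected, so subclasses usually keep it protected (widening it to public, as the examples in this article do, also compiles):

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

class TaskShapes {

    // RecursiveTask<V>: compute() returns a result, much like Callable#call().
    static class MaxTask extends RecursiveTask<Integer> {
        private final int[] arr;

        MaxTask(int[] arr) {
            this.arr = arr;
        }

        @Override
        protected Integer compute() {
            int max = Integer.MIN_VALUE;
            for (int x : arr) {
                max = Math.max(max, x);
            }
            return max;
        }
    }

    // RecursiveAction: compute() returns nothing, much like Runnable#run(); it works by side effects.
    static class FillAction extends RecursiveAction {
        private final int[] arr;
        private final int value;

        FillAction(int[] arr, int value) {
            this.arr = arr;
            this.value = value;
        }

        @Override
        protected void compute() {
            Arrays.fill(arr, value);
        }
    }

    public static void main(String[] args) {
        ForkJoinPool pool = ForkJoinPool.commonPool();
        int[] data = {3, 9, 4};
        int max = pool.invoke(new MaxTask(data)); // invoke() returns the task's result
        pool.invoke(new FillAction(data, max));   // for a RecursiveAction the result is null
        System.out.println(max + " " + Arrays.toString(data)); // 9 [9, 9, 9]
    }
}
```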


To run a task in a ForkJoinPool, use the T invoke(ForkJoinTask<T> task) method, which waits for the task to complete and returns its result:


ForkJoinTask forkJoinTask = new ForkJoinTaskImpl(...);
forkJoinPool.invoke(forkJoinTask);
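invoke() is not the only way to hand a task to a pool. The sketch below (SquareTask is a hypothetical placeholder) contrasts it with submit(), which returns the task as a ForkJoinTask that is also a Future, and execute(), which only schedules the task:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class SubmissionStyles {

    // Hypothetical placeholder task: squares a number.
    static class SquareTask extends RecursiveTask<Long> {
        private final long n;

        SquareTask(long n) {
            this.n = n;
        }

        @Override
        protected Long compute() {
            return n * n;
        }
    }

    static long[] runAll() {
        ForkJoinPool pool = ForkJoinPool.commonPool();

        // 1. invoke(): runs the task and blocks until the result is available.
        long a = pool.invoke(new SquareTask(3));

        // 2. submit(): returns immediately; the returned ForkJoinTask is also a Future.
        ForkJoinTask<Long> submitted = pool.submit(new SquareTask(4));
        long b = submitted.join(); // Future#get() also works, but throws checked exceptions

        // 3. execute(): fire-and-forget; keep a reference to the task to read its result later.
        SquareTask executed = new SquareTask(5);
        pool.execute(executed);
        long c = executed.join();

        return new long[] {a, b, c};
    }

    public static void main(String[] args) {
        long[] r = runAll();
        System.out.println(r[0] + " " + r[1] + " " + r[2]); // 9 16 25
    }
}
```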


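Besides compute(), the key ForkJoinTask methods are fork() and join(): fork() schedules the task for asynchronous execution, and join() returns its result once it is done. In terms of usage, ForkJoinTask#join() is similar to Thread#join(). However, a fork/join worker thread that calls join() does not necessarily go to sleep: while waiting, it can run other pending tasks, for example the very task it is waiting for if no other thread has taken it yet. This fits the pool's "work stealing" scheduling, in which each worker thread has its own task queue and idle workers steal tasks from the queues of busy ones. Together, these strategies make efficient use of a limited number of threads.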
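The usual pattern is to fork() one subtask, compute() the other directly in the current thread, and only then join() the forked one, so that the current thread does useful work instead of waiting. A minimal sketch with a hypothetical SumTask over a range of a long[] array; the THRESHOLD value is an assumption, and below it the range is summed sequentially:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class SumTask extends RecursiveTask<Long> {
    static final int THRESHOLD = 10_000; // assumed cutoff for sequential work

    private final long[] arr;
    private final int lo; // inclusive
    private final int hi; // exclusive

    SumTask(long[] arr, int lo, int hi) {
        this.arr = arr;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected Long compute() {
        if (hi - lo <= THRESHOLD) {
            long sum = 0;
            for (int i = lo; i < hi; i++) sum += arr[i];
            return sum;
        }
        int mid = (lo + hi) >>> 1;
        SumTask left = new SumTask(arr, lo, mid);
        SumTask right = new SumTask(arr, mid, hi);
        left.fork();                     // schedule the left half; another worker may steal it
        long rightSum = right.compute(); // work on the right half in the current thread
        long leftSum = left.join();      // wait for (or help finish) the left half
        return leftSum + rightSum;
    }

    static long sum(long[] arr) {
        return ForkJoinPool.commonPool().invoke(new SumTask(arr, 0, arr.length));
    }

    public static void main(String[] args) {
        long[] arr = new long[1_000_000];
        for (int i = 0; i < arr.length; i++) arr[i] = i + 1;
        System.out.println(sum(arr)); // 500000500000
    }
}
```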

Fibonacci example

Let's look at a simple example of a ForkJoinTask implementation that calculates Fibonacci numbers:

import java.util.concurrent.RecursiveTask;

public class FibonacciTask extends RecursiveTask<Integer> {
    private final int n;

    public FibonacciTask(int n) {
        this.n = n;
    }

    @Override
    public Integer compute() {
        if (n < 2) return n;                     // base case: F(0) = 0, F(1) = 1
        FibonacciTask f1 = new FibonacciTask(n - 1);
        FibonacciTask f2 = new FibonacciTask(n - 2);
        f1.fork();                               // run F(n - 1) asynchronously
        return f2.compute() + f1.join();         // compute F(n - 2) here, then wait for F(n - 1)
    }
}

In this example, FibonacciTask implements the compute() method, which creates two subtasks: it forks f1 to compute F(n - 1) asynchronously and computes f2, that is F(n - 2), directly in the current thread. Then f1.join() waits for the forked task's result, and the two results are added together.
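The task above forks all the way down to n < 2, so it creates an exponential number of tiny tasks. The RecursiveTask Javadoc makes the same point about its own Fibonacci example: the smallest subtasks are too small to be worthwhile, and one should pick a minimum granularity below which the problem is solved sequentially. A minimal sketch of that idea; the class name and the cutoff of 20 are arbitrary assumptions:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class CutoffFibonacciTask extends RecursiveTask<Integer> {
    static final int SEQUENTIAL_CUTOFF = 20; // assumed granularity

    private final int n;

    CutoffFibonacciTask(int n) {
        this.n = n;
    }

    @Override
    protected Integer compute() {
        if (n <= SEQUENTIAL_CUTOFF) {
            return fibSequential(n); // small problem: same recursion, but no task overhead
        }
        CutoffFibonacciTask f1 = new CutoffFibonacciTask(n - 1);
        CutoffFibonacciTask f2 = new CutoffFibonacciTask(n - 2);
        f1.fork();
        return f2.compute() + f1.join();
    }

    // Plain recursive Fibonacci, run entirely in the current thread.
    static int fibSequential(int n) {
        return n < 2 ? n : fibSequential(n - 1) + fibSequential(n - 2);
    }

    public static void main(String[] args) {
        System.out.println(ForkJoinPool.commonPool().invoke(new CutoffFibonacciTask(30))); // 832040
    }
}
```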

Merge Sort

Let's take a look at a more complex example of a ForkJoinTask: merge sort. It is based on the age-old "divide and conquer" principle: divide the source problem into subproblems, solve them recursively, and combine the results.


For this task, we will use RecursiveAction as the ForkJoinTask implementation, since the array is sorted in place and no return value is needed. The constructor takes the int[] arr to be sorted and stores it in a field:


import java.util.concurrent.RecursiveAction;

public class MergeSortAction extends RecursiveAction {

    private final int[] arr;

    public MergeSortAction(int[] arr) {
        this.arr = arr;
    }

    @Override
    public void compute() {
        ...
    }

    private void merge(int[] left, int[] right) {
        ...
    }

}


Next, we implement the merge() method. It accepts two arrays, the left and right parts, which have already been sorted. We merge them by repeatedly comparing their current elements and copying the smaller one into the next position of arr. When one of the arrays is exhausted, the first while loop ends, but the remaining elements of the other array still need to be copied. The two additional while loops take care of this (at most one of them actually copies anything):


private void merge(int[] left, int[] right) {
    int i = 0, j = 0, k = 0;
    while (i < left.length && j < right.length) {
        if (left[i] < right[j])
            arr[k++] = left[i++];
        else
            arr[k++] = right[j++];
    }
    while (i < left.length) {
        arr[k++] = left[i++];
    }
    while (j < right.length) {
        arr[k++] = right[j++];
    }
}
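To see the merge step in isolation, here is a hypothetical standalone version of the same three-loop merge that returns a new array instead of writing into the arr field:

```java
import java.util.Arrays;

class MergeDemo {

    // Same logic as merge(): both inputs must already be sorted.
    static int[] merge(int[] left, int[] right) {
        int[] out = new int[left.length + right.length];
        int i = 0, j = 0, k = 0;
        while (i < left.length && j < right.length) {
            if (left[i] < right[j])
                out[k++] = left[i++];
            else
                out[k++] = right[j++];
        }
        while (i < left.length) out[k++] = left[i++];   // leftovers from left
        while (j < right.length) out[k++] = right[j++]; // leftovers from right
        return out;
    }

    public static void main(String[] args) {
        int[] merged = merge(new int[] {1, 4, 7}, new int[] {2, 3, 9, 10});
        System.out.println(Arrays.toString(merged)); // [1, 2, 3, 4, 7, 9, 10]
    }
}
```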

Next, we write the compute() method, which recursively divides the original array and passes the results to the merge() method. First, we calculate the middle index of the array. Then we split the original array into two parts, left and right, and fill them by copying the original data with System.arraycopy(Object src, int srcPos, Object dest, int destPos, int length):


@Override
public void compute() {
    if (arr.length < 2) return;
    int mid = arr.length / 2;

    int[] left = new int[mid];
    System.arraycopy(arr, 0, left, 0, mid);

    int[] right = new int[arr.length - mid];
    System.arraycopy(arr, mid, right, 0, arr.length - mid);

    ...
}
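The two System.arraycopy calls can also be written with java.util.Arrays.copyOfRange, which allocates and fills the new array in a single call. A small sketch; the helper name split is hypothetical:

```java
import java.util.Arrays;

class SplitDemo {

    // Returns {left, right}: left = arr[0, mid), right = arr[mid, arr.length).
    static int[][] split(int[] arr) {
        int mid = arr.length / 2;
        int[] left = Arrays.copyOfRange(arr, 0, mid);
        int[] right = Arrays.copyOfRange(arr, mid, arr.length);
        return new int[][] {left, right};
    }

    public static void main(String[] args) {
        int[][] parts = split(new int[] {5, 1, 4, 2, 3});
        System.out.println(Arrays.toString(parts[0]) + " " + Arrays.toString(parts[1])); // [5, 1] [4, 2, 3]
    }
}
```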


Now we have two separate arrays. We sort them recursively by creating a new MergeSortAction for each one and passing both tasks to invokeAll(), which forks them and returns once both have completed. Then merge() combines the two sorted halves back into arr:


@Override
public void compute() {
    if (arr.length < 2) return;
    int mid = arr.length / 2;

    int[] left = new int[mid];
    System.arraycopy(arr, 0, left, 0, mid);

    int[] right = new int[arr.length - mid];
    System.arraycopy(arr, mid, right, 0, arr.length - mid);

    invokeAll(new MergeSortAction(left), new MergeSortAction(right));
    merge(left, right);
}
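The version above allocates two new arrays at every level of recursion and creates tasks all the way down to single elements, which adds noticeable overhead for large inputs. One possible variation, sketched here under our own assumptions (the class name RangeMergeSortAction and the THRESHOLD value are hypothetical), sorts index ranges of the original array, reuses a single scratch buffer, and falls back to the sequential Arrays.sort(int[], int, int) for small ranges:

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ThreadLocalRandom;

// Hypothetical variant: sorts arr[lo, hi) in place with one shared scratch buffer.
class RangeMergeSortAction extends RecursiveAction {
    static final int THRESHOLD = 8_192; // assumed cutoff; tune for your hardware

    private final int[] arr;
    private final int[] buf; // scratch space, same length as arr
    private final int lo;    // inclusive
    private final int hi;    // exclusive

    RangeMergeSortAction(int[] arr, int[] buf, int lo, int hi) {
        this.arr = arr;
        this.buf = buf;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo <= THRESHOLD) {
            Arrays.sort(arr, lo, hi); // small range: sort sequentially
            return;
        }
        int mid = (lo + hi) >>> 1;
        invokeAll(new RangeMergeSortAction(arr, buf, lo, mid),
                  new RangeMergeSortAction(arr, buf, mid, hi));
        merge(mid);
    }

    // Merges the sorted runs arr[lo, mid) and arr[mid, hi), using buf[lo, hi) as a copy.
    private void merge(int mid) {
        System.arraycopy(arr, lo, buf, lo, hi - lo);
        int i = lo, j = mid, k = lo;
        while (i < mid && j < hi) {
            arr[k++] = buf[i] <= buf[j] ? buf[i++] : buf[j++];
        }
        while (i < mid) arr[k++] = buf[i++];
        while (j < hi) arr[k++] = buf[j++];
    }

    static void sort(int[] arr) {
        int[] buf = new int[arr.length];
        ForkJoinPool.commonPool().invoke(new RangeMergeSortAction(arr, buf, 0, arr.length));
    }

    public static void main(String[] args) {
        int[] arr = ThreadLocalRandom.current().ints(1_000_000).toArray();
        int[] expected = arr.clone();
        Arrays.sort(expected);
        sort(arr);
        System.out.println("matches Arrays.sort: " + Arrays.equals(arr, expected)); // true
    }
}
```

Sharing buf between tasks is safe in this sketch because sibling tasks always work on disjoint index ranges, and a parent only merges after invokeAll() has returned for both children.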

Speed testing

We’ve finished our parallel merge sort. Let’s test it and compare its performance with the non-parallel version. I used ThreadLocalRandom to generate random numbers and ZonedDateTime to measure the execution time:


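// Assumes JUnit 5 (Jupiter).
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import org.junit.jupiter.api.Test;

// MergeSort, MergeSortImpl and ParallelMergeSortImpl are defined elsewhere in the project (not shown in this article).
class MergeSortTest {

    private final List<MergeSort> mergeSortImpls =
            Arrays.asList(new MergeSortImpl(), new ParallelMergeSortImpl());

    @Test
    void sort() {
        for (MergeSort mergeSort : mergeSortImpls) {
            // 100 million random ints, regenerated for each implementation
            int[] arr = IntStream
                    .range(0, 100_000_000)
                    .map(i -> ThreadLocalRandom.current().nextInt())
                    .toArray();
            ZonedDateTime now = ZonedDateTime.now();
            mergeSort.sort(arr);
            System.out.printf("%s exec time: %dms%n",
                    mergeSort.getClass().getSimpleName(),
                    ChronoUnit.MILLIS.between(now, ZonedDateTime.now()));
            assertTrue(isSorted(arr));
        }
    }

    // Returns true if every element is less than or equal to its successor.
    private boolean isSorted(int[] arr) {
        for (int i = 0; i < arr.length - 1; i++) {
            if (arr[i] > arr[i + 1])
                return false;
        }
        return true;
    }

}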


[Figure: execution results for sorting 100 million numbers on an 8-core CPU]
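The test relies on a MergeSort interface and two implementations, MergeSortImpl and ParallelMergeSortImpl, whose code is not shown in the article. The sketch below is only a guess at their shape, assuming that sort(int[]) sorts the array in place: the sequential version makes plain recursive calls, and the parallel one wraps the same split-and-merge steps as MergeSortAction in a nested RecursiveAction. The small demo times them with System.nanoTime(), a clock intended for measuring elapsed time, instead of ZonedDateTime:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ThreadLocalRandom;

// Small demo: times both hypothetical implementations with System.nanoTime().
class MergeSortDemo {
    public static void main(String[] args) {
        MergeSort[] impls = {new MergeSortImpl(), new ParallelMergeSortImpl()};
        for (MergeSort impl : impls) {
            int[] arr = ThreadLocalRandom.current().ints(1_000_000).toArray();
            long start = System.nanoTime();
            impl.sort(arr);
            long millis = (System.nanoTime() - start) / 1_000_000;
            System.out.printf("%s exec time: %dms%n", impl.getClass().getSimpleName(), millis);
        }
    }
}

// Hypothetical interface: assumes sort(int[]) sorts the array in place.
interface MergeSort {
    void sort(int[] arr);
}

// Hypothetical sequential version: same split-and-merge steps, plain recursion.
class MergeSortImpl implements MergeSort {
    @Override
    public void sort(int[] arr) {
        if (arr.length < 2) return;
        int mid = arr.length / 2;
        int[] left = new int[mid];
        System.arraycopy(arr, 0, left, 0, mid);
        int[] right = new int[arr.length - mid];
        System.arraycopy(arr, mid, right, 0, arr.length - mid);
        sort(left);
        sort(right);
        merge(arr, left, right);
    }

    // Merges two sorted arrays into arr.
    static void merge(int[] arr, int[] left, int[] right) {
        int i = 0, j = 0, k = 0;
        while (i < left.length && j < right.length)
            arr[k++] = left[i] < right[j] ? left[i++] : right[j++];
        while (i < left.length) arr[k++] = left[i++];
        while (j < right.length) arr[k++] = right[j++];
    }
}

// Hypothetical parallel version: the halves are sorted as fork/join tasks.
class ParallelMergeSortImpl implements MergeSort {
    @Override
    public void sort(int[] arr) {
        ForkJoinPool.commonPool().invoke(new Action(arr));
    }

    static class Action extends RecursiveAction {
        private final int[] arr;

        Action(int[] arr) {
            this.arr = arr;
        }

        @Override
        protected void compute() {
            if (arr.length < 2) return;
            int mid = arr.length / 2;
            int[] left = new int[mid];
            System.arraycopy(arr, 0, left, 0, mid);
            int[] right = new int[arr.length - mid];
            System.arraycopy(arr, mid, right, 0, arr.length - mid);
            invokeAll(new Action(left), new Action(right));
            MergeSortImpl.merge(arr, left, right);
        }
    }
}
```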


Conclusion

In this article, I showed you an example of how to use the Fork/Join Framework. I hope you now have a basic idea of how it can speed up recursive, divide-and-conquer algorithms in your applications. The source code is available on GitHub.