AnthonyZero's Bolg

JUC-Fork/Join框架

Fork/Join框架

Fork/Join框架是Java7提供的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。

Oracle的官方给出的定义是:Fork/Join框架是一个实现了ExecutorService接口的多线程处理器。它可以把一个大的任务划分为若干个小的任务并发执行,充分利用可用的资源,进而提高应用的执行效率。Fork/Join的运行流程图如下:

体现出算法中分而治之思想(将一个难以直接解决的大问题,分割成一些规模较小的相同问题,以便各个击破,分而治之。)

Alt text

Fork/Join框架的设计

Fork/Join框架的设计分为两步:

步骤一分割任务: 首先我们需要有一个fork类来把大任务分割成子任务,有可能子任务还是很大,所以还需要不停的分割,直到分割出的子任务足够小。
步骤一执行任务并合并结果: 分割的子任务分别放在双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从这个队列里拿数据,然后合并这些数据。

Fork/Join使用两个类来完成以上两件事情:

  1. ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:
    • RecursiveAction:用于没有返回结果的任务。
    • RecursiveTask:用于有返回结果的任务。
  2. ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行。
任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

Java8中java.util.Arrays的parallelSort()方法和java.util.streams包中封装的方法也都用到了Fork/Join

工作窃取算法

Fork/Join在实现上,大任务分割出若干互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里面,每一个队列都会创建一个单独的线程来消费执行队列中的任务,线程和队列一 一对应。但是某些线程会提前消费完自己的任务。而有些线程没有及时消费完任务,这个时候,完成了任务的线程就会去窃取那些没有消费完成的线程的队列任务,(这时候多线程会访问同一个队列)为了减少线程竞争,Fork/Join使用双端队列来存取子任务,分配给这个队列的线程会一直从头取得一个任务然后执行,而窃取线程总是从队列的尾端拉取任务执行。

优点充分利用线程进行并行计算,减少了线程间的竞争。缺点在双端队列只有一个任务时,还是会存在竞争,并且创建了多个线程多个双端队列消耗了更多的系统资源。

使用体验

通过使用Fork/Join框架计算两个数之间的所有数之和

public class CountTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 2; //阈值
    private int start;
    private int end;
    public CountTask(int start, int end) {
        this.start = start;
        this.end = end;
    }
    @Override
    protected Integer compute() {
        int sum = 0;
        //任务是否可以继续划分为 子任务
        boolean canContinue = (end - start) >= THRESHOLD;
        if (!canContinue) {
            for (int i = start; i<= end; i++) {
                sum += i;
            }
            return sum;
        } else {
            //任务大于阈值 就继续分裂为两个子任务计算
            int middle = (start + end) / 2;
            CountTask leftTask = new CountTask(start, middle);
            CountTask rightTask = new CountTask(middle+1, end);
            //执行子任务
            leftTask.fork(); //子任务fork时 又会进入compute方法
            rightTask.fork();
            return leftTask.join() + rightTask.join(); //子任务执行完得到结果
        }
    }

    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        //生成一个任务 负责执行1到10的和计算
        CountTask task = new CountTask(1, 10);
        ForkJoinTask<Integer> result = forkJoinPool.submit(task);
        try {
            System.out.println(result.get()); //最后结果
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }
}