HOOOS

源码剖析:CyclicBarrier 如何实现多线程同步?

0 59 并发小能手 Java并发CyclicBarrier源码分析
Apple

你好,我是你的源码剖析向导“并发小能手”。今天咱们来聊聊 Java 并发工具类中的 CyclicBarrier,看看它是如何实现多线程同步的。

CyclicBarrier 是什么?

CyclicBarrier,字面意思是“循环的屏障”。你可以把它想象成一个赛跑比赛中的起跑线。多个运动员(线程)必须在起跑线前等待,直到所有运动员都准备好了,裁判(CyclicBarrier)才会发出起跑信号(所有线程开始执行)。

与 CountDownLatch 不同,CyclicBarrier 可以重复使用。就像一场比赛可以进行多轮,每轮比赛开始前,运动员们都要在起跑线前重新集合。

CyclicBarrier 的应用场景

CyclicBarrier 适用于这样一种场景:多个线程协同工作,每个线程完成一部分任务,所有线程都完成后,才能进行下一步。常见的例子有:

  • 多线程计算:例如,一个大型计算任务可以分解成多个子任务,每个线程负责一个子任务。所有子任务完成后,才能进行最终结果的汇总。
  • 并行迭代算法:例如,在某些图像处理算法中,图像的不同区域可以由不同的线程并行处理,每个迭代步骤都需要所有线程完成后才能开始。
  • 模拟并发测试: 可以利用CyclicBarrier让多个线程在同一时刻执行,达到并发测试目的。

CyclicBarrier 源码分析

好了,说了这么多,咱们来扒一扒 CyclicBarrier 的源码,看看它是如何实现的。我们主要关注以下几个核心方法:

  • CyclicBarrier(int parties):构造方法,指定参与同步的线程数量。
  • CyclicBarrier(int parties, Runnable barrierAction):构造方法,除了指定线程数量,还指定一个 Runnable,当所有线程都到达屏障时,会执行这个 Runnable。
  • await():等待其他线程到达屏障。如果当前线程不是最后一个到达屏障的线程,它会阻塞,直到最后一个线程到达。
  • await(long timeout, TimeUnit unit):带有超时的等待。如果在指定时间内,其他线程没有全部到达屏障,当前线程会抛出 TimeoutException。

成员变量

public class CyclicBarrier {

    /**
     *  内部类 Generation,用于表示 CyclicBarrier 的“代”。
     *  当 CyclicBarrier 被重置时,会创建一个新的 Generation 对象。
     */
    private static class Generation {
        boolean broken = false; // 标记当前“代”是否被打破
    }

    /**
     *  锁对象,用于控制对 CyclicBarrier 内部状态的访问。
     *  CyclicBarrier 内部使用 ReentrantLock 来实现同步。
     */
    private final ReentrantLock lock = new ReentrantLock();

    /**
     *  条件变量,用于阻塞和唤醒线程。
     *  当线程调用 await() 方法时,如果不是最后一个到达的线程,它会在这个条件变量上等待。
     */
    private final Condition trip = lock.newCondition();

    /**
     *  参与同步的线程数量。
     */
    private final int parties;

    /**
     *  当所有线程都到达屏障时,要执行的 Runnable。
     */
    private final Runnable barrierCommand;

    /**
     *  当前 CyclicBarrier 的“代”。
     */
    private Generation generation = new Generation();

    /**
     *  当前“代”中,还未到达屏障的线程数量。
     */
    private int count;

    // ...
}

构造方法

    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

构造方法很简单,就是初始化 partiescountbarrierCommandparties 表示参与同步的线程数量,count 初始化为 parties,表示当前“代”中还未到达屏障的线程数量。barrierCommand 是可选的,当所有线程都到达屏障时,会执行这个 Runnable。

await() 方法

    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L); // 调用内部的 dowait 方法
        } catch (TimeoutException toe) {
            throw new Error(toe); // 不可能发生,因为没有设置超时
        }
    }

await() 方法直接调用了内部的 dowait() 方法,并传入 false0L,表示不进行超时等待。

dowait() 方法

    private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException, TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock(); // 获取锁
        try {
            final Generation g = generation; // 获取当前“代”

            if (g.broken) // 如果当前“代”已经被打破,抛出 BrokenBarrierException
                throw new BrokenBarrierException();

            if (Thread.interrupted()) { // 如果当前线程被中断
                breakBarrier(); // 打破屏障
                throw new InterruptedException(); // 抛出 InterruptedException
            }

            int index = --count; // 将 count 减 1,表示当前线程到达屏障
            if (index == 0) {  // 如果当前线程是最后一个到达屏障的线程
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run(); // 执行 barrierCommand
                    ranAction = true;
                    nextGeneration(); // 开启下一“代”
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier(); // 如果 barrierCommand 执行失败,打破屏障
                }
            }

            // 如果当前线程不是最后一个到达屏障的线程,循环等待
            for (;;) {
                try {
                    if (!timed)
                        trip.await(); // 如果不进行超时等待,直接在 trip 上等待
                    else if (nanos > 0)
                        nanos = trip.awaitNanos(nanos); // 如果进行超时等待,在 trip 上等待指定时间
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier(); // 如果当前线程被中断,并且当前“代”没有被打破,打破屏障
                        throw ie; // 抛出 InterruptedException
                    } else {
                        // 如果当前“代”已经被打破,或者当前线程在等待期间“代”发生了变化
                        // 说明其他线程已经唤醒了当前线程,或者屏障已经被重置
                        Thread.currentThread().interrupt(); // 重新设置中断状态
                    }
                }

                if (g.broken) // 如果当前“代”已经被打破,抛出 BrokenBarrierException
                    throw new BrokenBarrierException();

                if (g != generation) // 如果“代”发生了变化,说明屏障已经被重置,返回 index
                    return index;

                if (timed && nanos <= 0L) { // 如果进行超时等待,并且超时时间已到
                    breakBarrier(); // 打破屏障
                    throw new TimeoutException(); // 抛出 TimeoutException
                }
            }
        } finally {
            lock.unlock(); // 释放锁
        }
    }

dowait() 方法是 CyclicBarrier 的核心逻辑。我们来一步步分析:

  1. 获取锁:首先,获取 lock,保证对 CyclicBarrier 内部状态的互斥访问。
  2. 检查“代”是否被打破:如果当前“代”已经被打破(g.brokentrue),说明有其他线程在等待过程中发生了中断或超时,抛出 BrokenBarrierException
  3. 检查线程是否被中断:如果当前线程被中断,调用 breakBarrier() 方法打破屏障,并抛出 InterruptedException
  4. count 减 1:将 count 减 1,表示当前线程到达屏障。index 变量保存了当前线程到达屏障前的 count 值。
  5. 判断是否是最后一个到达的线程:如果 index 为 0,说明当前线程是最后一个到达屏障的线程。
    • 执行 barrierCommand:如果 barrierCommand 不为空,执行它。注意,barrierCommand 的执行是在持有锁的情况下进行的,这意味着在 barrierCommand 执行期间,其他线程无法进入 await() 方法。
    • 开启下一“代”:调用 nextGeneration() 方法,开启下一“代”。
    • 返回 0:返回 0,表示当前线程是最后一个到达屏障的线程。
    • 异常处理:如果 barrierCommand 执行过程中发生异常,调用 breakBarrier() 方法打破屏障。
  6. 非最后一个到达的线程:如果 index 不为 0,说明当前线程不是最后一个到达屏障的线程,进入循环等待。
    • 等待:根据 timed 参数,选择调用 trip.await()trip.awaitNanos(nanos) 方法,在 trip 条件变量上等待。
    • 中断处理:如果在等待过程中,当前线程被中断,检查当前“代”是否被打破。如果未被打破,调用 breakBarrier() 方法打破屏障,并抛出 InterruptedException;否则,重新设置中断状态。
    • 检查“代”是否被打破:如果在等待过程中,当前“代”被打破,抛出 BrokenBarrierException
    • 检查“代”是否变化:如果“代”发生了变化,说明屏障已经被重置,返回 index
    • 超时处理:如果进行超时等待,并且超时时间已到,调用 breakBarrier() 方法打破屏障,并抛出 TimeoutException
  7. 释放锁:最后,释放 lock

breakBarrier() 方法

    private void breakBarrier() {
        generation.broken = true; // 将当前“代”标记为被打破
        count = parties; // 重置 count
        trip.signalAll(); // 唤醒所有在 trip 上等待的线程
    }

breakBarrier() 方法用于打破屏障。它将当前“代”标记为被打破,重置 count,并唤醒所有在 trip 上等待的线程。这些线程被唤醒后,会抛出 BrokenBarrierException

nextGeneration() 方法

    private void nextGeneration() {
        // 唤醒所有在 trip 上等待的线程,让它们进入下一“代”
        trip.signalAll();
        // 重置 count
        count = parties;
        // 创建新的 Generation 对象
        generation = new Generation();
    }

nextGeneration() 方法用于开启下一“代”。它首先唤醒所有在 trip 上等待的线程,然后重置 count,并创建一个新的 Generation 对象。

CyclicBarrier 与 ReentrantLock 和 Condition

从源码中可以看出,CyclicBarrier 的实现依赖于 ReentrantLock 和 Condition。ReentrantLock 用于保证对 CyclicBarrier 内部状态的互斥访问,Condition 用于阻塞和唤醒线程。

CyclicBarrier 内部维护了一个 count 变量,表示还未到达屏障的线程数量。当一个线程调用 await() 方法时,count 会减 1。如果 count 变为 0,说明所有线程都到达了屏障,此时会执行 barrierCommand(如果有),然后唤醒所有在 trip 条件变量上等待的线程。如果 count 不为 0,当前线程会在 trip 条件变量上等待。

总结

CyclicBarrier 是一个非常有用的并发工具类,它可以让多个线程在某个点上同步。它的实现基于 ReentrantLock 和 Condition,通过维护一个计数器 count 和一个“代”的概念,实现了线程的阻塞、唤醒和循环使用。

理解 CyclicBarrier 的源码,可以帮助你更好地理解它的工作原理,并在实际开发中正确地使用它。同时,也能让你对 Java 并发编程有更深入的认识。

希望这篇源码剖析对你有所帮助。如果你还有其他问题,欢迎随时提问!

思考题

  1. CyclicBarrier 的 reset() 方法有什么作用?它的实现原理是什么?(提示:在源码中找找看。)
  2. CyclicBarrier 和 CountDownLatch 有什么区别?分别适用于什么场景?
  3. 如果让你自己实现一个 CyclicBarrier,你会怎么设计?

期待你的思考和回答!

点评评价

captcha
健康