你好,我是你的源码剖析向导“并发小能手”。今天咱们来聊聊 Java 并发工具类中的 CyclicBarrier,看看它是如何实现多线程同步的。
CyclicBarrier 是什么?
CyclicBarrier,字面意思是“循环的屏障”。你可以把它想象成一个赛跑比赛中的起跑线。多个运动员(线程)必须在起跑线前等待,直到所有运动员都准备好了,裁判(CyclicBarrier)才会发出起跑信号(所有线程开始执行)。
与 CountDownLatch 不同,CyclicBarrier 可以重复使用。就像一场比赛可以进行多轮,每轮比赛开始前,运动员们都要在起跑线前重新集合。
CyclicBarrier 的应用场景
CyclicBarrier 适用于这样一种场景:多个线程协同工作,每个线程完成一部分任务,所有线程都完成后,才能进行下一步。常见的例子有:
- 多线程计算:例如,一个大型计算任务可以分解成多个子任务,每个线程负责一个子任务。所有子任务完成后,才能进行最终结果的汇总。
- 并行迭代算法:例如,在某些图像处理算法中,图像的不同区域可以由不同的线程并行处理,每个迭代步骤都需要所有线程完成后才能开始。
- 模拟并发测试: 可以利用CyclicBarrier让多个线程在同一时刻执行,达到并发测试目的。
CyclicBarrier 源码分析
好了,说了这么多,咱们来扒一扒 CyclicBarrier 的源码,看看它是如何实现的。我们主要关注以下几个核心方法:
CyclicBarrier(int parties)
:构造方法,指定参与同步的线程数量。CyclicBarrier(int parties, Runnable barrierAction)
:构造方法,除了指定线程数量,还指定一个 Runnable,当所有线程都到达屏障时,会执行这个 Runnable。await()
:等待其他线程到达屏障。如果当前线程不是最后一个到达屏障的线程,它会阻塞,直到最后一个线程到达。await(long timeout, TimeUnit unit)
:带有超时的等待。如果在指定时间内,其他线程没有全部到达屏障,当前线程会抛出 TimeoutException。
成员变量
public class CyclicBarrier {
/**
* 内部类 Generation,用于表示 CyclicBarrier 的“代”。
* 当 CyclicBarrier 被重置时,会创建一个新的 Generation 对象。
*/
private static class Generation {
boolean broken = false; // 标记当前“代”是否被打破
}
/**
* 锁对象,用于控制对 CyclicBarrier 内部状态的访问。
* CyclicBarrier 内部使用 ReentrantLock 来实现同步。
*/
private final ReentrantLock lock = new ReentrantLock();
/**
* 条件变量,用于阻塞和唤醒线程。
* 当线程调用 await() 方法时,如果不是最后一个到达的线程,它会在这个条件变量上等待。
*/
private final Condition trip = lock.newCondition();
/**
* 参与同步的线程数量。
*/
private final int parties;
/**
* 当所有线程都到达屏障时,要执行的 Runnable。
*/
private final Runnable barrierCommand;
/**
* 当前 CyclicBarrier 的“代”。
*/
private Generation generation = new Generation();
/**
* 当前“代”中,还未到达屏障的线程数量。
*/
private int count;
// ...
}
构造方法
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
public CyclicBarrier(int parties) {
this(parties, null);
}
构造方法很简单,就是初始化 parties
、count
和 barrierCommand
。parties
表示参与同步的线程数量,count
初始化为 parties
,表示当前“代”中还未到达屏障的线程数量。barrierCommand
是可选的,当所有线程都到达屏障时,会执行这个 Runnable。
await() 方法
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L); // 调用内部的 dowait 方法
} catch (TimeoutException toe) {
throw new Error(toe); // 不可能发生,因为没有设置超时
}
}
await()
方法直接调用了内部的 dowait()
方法,并传入 false
和 0L
,表示不进行超时等待。
dowait() 方法
private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException, TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock(); // 获取锁
try {
final Generation g = generation; // 获取当前“代”
if (g.broken) // 如果当前“代”已经被打破,抛出 BrokenBarrierException
throw new BrokenBarrierException();
if (Thread.interrupted()) { // 如果当前线程被中断
breakBarrier(); // 打破屏障
throw new InterruptedException(); // 抛出 InterruptedException
}
int index = --count; // 将 count 减 1,表示当前线程到达屏障
if (index == 0) { // 如果当前线程是最后一个到达屏障的线程
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run(); // 执行 barrierCommand
ranAction = true;
nextGeneration(); // 开启下一“代”
return 0;
} finally {
if (!ranAction)
breakBarrier(); // 如果 barrierCommand 执行失败,打破屏障
}
}
// 如果当前线程不是最后一个到达屏障的线程,循环等待
for (;;) {
try {
if (!timed)
trip.await(); // 如果不进行超时等待,直接在 trip 上等待
else if (nanos > 0)
nanos = trip.awaitNanos(nanos); // 如果进行超时等待,在 trip 上等待指定时间
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier(); // 如果当前线程被中断,并且当前“代”没有被打破,打破屏障
throw ie; // 抛出 InterruptedException
} else {
// 如果当前“代”已经被打破,或者当前线程在等待期间“代”发生了变化
// 说明其他线程已经唤醒了当前线程,或者屏障已经被重置
Thread.currentThread().interrupt(); // 重新设置中断状态
}
}
if (g.broken) // 如果当前“代”已经被打破,抛出 BrokenBarrierException
throw new BrokenBarrierException();
if (g != generation) // 如果“代”发生了变化,说明屏障已经被重置,返回 index
return index;
if (timed && nanos <= 0L) { // 如果进行超时等待,并且超时时间已到
breakBarrier(); // 打破屏障
throw new TimeoutException(); // 抛出 TimeoutException
}
}
} finally {
lock.unlock(); // 释放锁
}
}
dowait()
方法是 CyclicBarrier 的核心逻辑。我们来一步步分析:
- 获取锁:首先,获取
lock
,保证对 CyclicBarrier 内部状态的互斥访问。 - 检查“代”是否被打破:如果当前“代”已经被打破(
g.broken
为true
),说明有其他线程在等待过程中发生了中断或超时,抛出BrokenBarrierException
。 - 检查线程是否被中断:如果当前线程被中断,调用
breakBarrier()
方法打破屏障,并抛出InterruptedException
。 - count 减 1:将
count
减 1,表示当前线程到达屏障。index
变量保存了当前线程到达屏障前的count
值。 - 判断是否是最后一个到达的线程:如果
index
为 0,说明当前线程是最后一个到达屏障的线程。- 执行 barrierCommand:如果
barrierCommand
不为空,执行它。注意,barrierCommand
的执行是在持有锁的情况下进行的,这意味着在barrierCommand
执行期间,其他线程无法进入await()
方法。 - 开启下一“代”:调用
nextGeneration()
方法,开启下一“代”。 - 返回 0:返回 0,表示当前线程是最后一个到达屏障的线程。
- 异常处理:如果
barrierCommand
执行过程中发生异常,调用breakBarrier()
方法打破屏障。
- 执行 barrierCommand:如果
- 非最后一个到达的线程:如果
index
不为 0,说明当前线程不是最后一个到达屏障的线程,进入循环等待。- 等待:根据
timed
参数,选择调用trip.await()
或trip.awaitNanos(nanos)
方法,在trip
条件变量上等待。 - 中断处理:如果在等待过程中,当前线程被中断,检查当前“代”是否被打破。如果未被打破,调用
breakBarrier()
方法打破屏障,并抛出InterruptedException
;否则,重新设置中断状态。 - 检查“代”是否被打破:如果在等待过程中,当前“代”被打破,抛出
BrokenBarrierException
。 - 检查“代”是否变化:如果“代”发生了变化,说明屏障已经被重置,返回
index
。 - 超时处理:如果进行超时等待,并且超时时间已到,调用
breakBarrier()
方法打破屏障,并抛出TimeoutException
。
- 等待:根据
- 释放锁:最后,释放
lock
。
breakBarrier() 方法
private void breakBarrier() {
generation.broken = true; // 将当前“代”标记为被打破
count = parties; // 重置 count
trip.signalAll(); // 唤醒所有在 trip 上等待的线程
}
breakBarrier()
方法用于打破屏障。它将当前“代”标记为被打破,重置 count
,并唤醒所有在 trip
上等待的线程。这些线程被唤醒后,会抛出 BrokenBarrierException
。
nextGeneration() 方法
private void nextGeneration() {
// 唤醒所有在 trip 上等待的线程,让它们进入下一“代”
trip.signalAll();
// 重置 count
count = parties;
// 创建新的 Generation 对象
generation = new Generation();
}
nextGeneration()
方法用于开启下一“代”。它首先唤醒所有在 trip
上等待的线程,然后重置 count
,并创建一个新的 Generation
对象。
CyclicBarrier 与 ReentrantLock 和 Condition
从源码中可以看出,CyclicBarrier 的实现依赖于 ReentrantLock 和 Condition。ReentrantLock 用于保证对 CyclicBarrier 内部状态的互斥访问,Condition 用于阻塞和唤醒线程。
CyclicBarrier 内部维护了一个 count
变量,表示还未到达屏障的线程数量。当一个线程调用 await()
方法时,count
会减 1。如果 count
变为 0,说明所有线程都到达了屏障,此时会执行 barrierCommand
(如果有),然后唤醒所有在 trip
条件变量上等待的线程。如果 count
不为 0,当前线程会在 trip
条件变量上等待。
总结
CyclicBarrier 是一个非常有用的并发工具类,它可以让多个线程在某个点上同步。它的实现基于 ReentrantLock 和 Condition,通过维护一个计数器 count
和一个“代”的概念,实现了线程的阻塞、唤醒和循环使用。
理解 CyclicBarrier 的源码,可以帮助你更好地理解它的工作原理,并在实际开发中正确地使用它。同时,也能让你对 Java 并发编程有更深入的认识。
希望这篇源码剖析对你有所帮助。如果你还有其他问题,欢迎随时提问!
思考题
- CyclicBarrier 的
reset()
方法有什么作用?它的实现原理是什么?(提示:在源码中找找看。) - CyclicBarrier 和 CountDownLatch 有什么区别?分别适用于什么场景?
- 如果让你自己实现一个 CyclicBarrier,你会怎么设计?
期待你的思考和回答!