HOOOS

Java 并发编程进阶:CountDownLatch 的实战应用与深度解析

0 81 老码农 Java并发编程CountDownLatch
Apple

你好,我是老码农,很高兴又和大家见面了。今天咱们聊聊 Java 并发编程中一个非常实用的工具——CountDownLatch。 相信不少小伙伴对它已经有所了解,但咱们的目标是不仅要“知其然”,更要“知其所以然”,深入挖掘它的应用场景,掌握它在实际项目中的实战技巧,帮助大家提升解决并发问题的能力。

1. 什么是 CountDownLatch?

CountDownLatch 是 Java 并发包 java.util.concurrent 下的一个同步辅助工具,它允许一个或多个线程等待其他线程完成操作。你可以把它想象成一个倒计时锁,它有一个初始计数器,每当一个线程完成它的工作,计数器的值就会减 1。当计数器减到 0 时,所有等待的线程都会被释放,继续执行它们自己的任务。

1.1. 核心概念

  • 计数器 (Counter)CountDownLatch 的核心,初始化时设置一个整数值,代表需要等待的任务数量。
  • countDown() 方法:用于递减计数器。当一个线程完成一个任务后,调用此方法。
  • await() 方法:使当前线程等待,直到计数器的值为 0。这个方法有多个重载版本,可以设置超时时间。

1.2. 内部实现

CountDownLatch 内部使用 AbstractQueuedSynchronizer (AQS) 来实现同步。AQS 是一个用于构建锁和同步器的框架,它使用一个 int 类型的变量来表示同步状态。CountDownLatch 将计数器的值存储在这个同步状态中,await() 方法实际上调用了 AQS 的 acquireShared() 方法,而 countDown() 方法则调用了 AQS 的 releaseShared() 方法。

2. CountDownLatch 的基本用法

咱们先来个简单的例子,直观感受一下 CountDownLatch 的用法。

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CountDownLatchExample {

    public static void main(String[] args) throws InterruptedException {
        int workerCount = 3; // 定义需要执行的线程数量
        CountDownLatch latch = new CountDownLatch(workerCount); // 初始化 CountDownLatch,计数器为 workerCount
        ExecutorService executor = Executors.newFixedThreadPool(workerCount); // 创建线程池

        // 模拟三个工人干活
        for (int i = 0; i < workerCount; i++) {
            final int workerId = i + 1; // 给每个工人一个 ID
            executor.submit(() -> {
                try {
                    System.out.println("工人 " + workerId + " 开始工作...");
                    Thread.sleep((long) (Math.random() * 2000)); // 模拟工作时间
                    System.out.println("工人 " + workerId + " 完成工作.");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown(); // 工人完成工作后,计数器减 1
                }
            });
        }

        System.out.println("主线程等待所有工人完成工作...");
        latch.await(); // 主线程等待,直到计数器变为 0
        System.out.println("所有工人完成工作,主线程继续执行.");

        executor.shutdown(); // 关闭线程池
    }
}

在这个例子中:

  1. 我们创建了一个 CountDownLatch 对象,计数器的初始值为 workerCount (3)。
  2. 我们创建了一个线程池,模拟了 3 个工人同时工作。
  3. 每个工人完成工作后,调用 latch.countDown() 方法,计数器减 1。
  4. 主线程调用 latch.await() 方法,阻塞等待,直到计数器变为 0。
  5. 当所有工人都完成工作后,主线程才继续执行。

运行结果类似这样:

工人 1 开始工作...
工人 2 开始工作...
工人 3 开始工作...
主线程等待所有工人完成工作...
工人 3 完成工作.
工人 2 完成工作.
工人 1 完成工作.
所有工人完成工作,主线程继续执行.

通过这个例子,你可以清晰地看到 CountDownLatch 的作用:协调多个线程的执行,确保它们在特定时刻同步。

3. CountDownLatch 的应用场景

CountDownLatch 是一种非常灵活的同步工具,在很多场景下都能发挥重要作用。

3.1. 系统启动时的服务加载

这是 CountDownLatch 最常见的应用场景之一。在系统启动时,往往需要加载多个服务或组件。这些服务的加载顺序可能有依赖关系,或者它们需要并行加载以提高启动速度。CountDownLatch 可以用来协调这些服务的加载,确保所有依赖的服务都加载完毕后,主服务才能启动。

3.1.1. 场景分析

假设一个系统依赖于数据库、缓存和消息队列这三个服务。为了保证系统的正常运行,需要先启动数据库,然后启动缓存和消息队列,最后启动主服务。

3.1.2. 代码实现

import java.util.concurrent.CountDownLatch;

public class SystemStartup {

    private static final int SERVICE_COUNT = 3;

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(SERVICE_COUNT); // 计数器初始化为 3

        // 模拟数据库服务加载
        new Thread(() -> {
            try {
                System.out.println("数据库服务启动中...");
                Thread.sleep(2000); // 模拟加载时间
                System.out.println("数据库服务启动完成.");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                latch.countDown(); // 数据库服务加载完成,计数器减 1
            }
        }).start();

        // 模拟缓存服务加载
        new Thread(() -> {
            try {
                System.out.println("缓存服务启动中...");
                Thread.sleep(1000); // 模拟加载时间
                System.out.println("缓存服务启动完成.");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                latch.countDown(); // 缓存服务加载完成,计数器减 1
            }
        }).start();

        // 模拟消息队列服务加载
        new Thread(() -> {
            try {
                System.out.println("消息队列服务启动中...");
                Thread.sleep(1500); // 模拟加载时间
                System.out.println("消息队列服务启动完成.");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                latch.countDown(); // 消息队列服务加载完成,计数器减 1
            }
        }).start();

        System.out.println("等待所有服务启动...");
        latch.await(); // 主线程等待,直到所有服务加载完成
        System.out.println("所有服务启动完成,主服务开始启动.");

        // 模拟主服务启动
        System.out.println("主服务启动中...");
        Thread.sleep(500); // 模拟启动时间
        System.out.println("主服务启动完成.");
    }
}

运行结果类似这样:

等待所有服务启动...
数据库服务启动中...
缓存服务启动中...
消息队列服务启动中...
数据库服务启动完成.
缓存服务启动完成.
消息队列服务启动完成.
所有服务启动完成,主服务开始启动.
主服务启动中...
主服务启动完成.

在这个例子中,主线程会等待数据库、缓存和消息队列这三个服务都启动完毕后,才开始启动主服务,保证了系统启动的正确性和可靠性。

3.2. 单元测试中的线程同步

在单元测试中,经常需要测试多线程的代码。CountDownLatch 可以用来控制测试线程的启动和同步,确保测试结果的准确性。

3.2.1. 场景分析

假设有一个 TaskExecutor 类,它会启动多个线程执行任务。我们需要编写单元测试来验证 TaskExecutor 的功能是否正确。

3.2.2. 代码实现

import org.junit.jupiter.api.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class TaskExecutorTest {

    // 模拟 TaskExecutor 类
    static class TaskExecutor {
        private final ExecutorService executor;

        public TaskExecutor(int threadPoolSize) {
            this.executor = Executors.newFixedThreadPool(threadPoolSize);
        }

        public void executeTasks(int taskCount, Runnable task) throws InterruptedException {
            CountDownLatch latch = new CountDownLatch(taskCount);
            for (int i = 0; i < taskCount; i++) {
                executor.submit(() -> {
                    try {
                        task.run();
                    } finally {
                        latch.countDown();
                    }
                });
            }
            latch.await(); // 等待所有任务执行完毕
        }

        public void shutdown() {
            executor.shutdown();
        }
    }

    @Test
    public void testExecuteTasks() throws InterruptedException {
        int threadPoolSize = 3;
        int taskCount = 5;
        TaskExecutor taskExecutor = new TaskExecutor(threadPoolSize);
        CountDownLatch latch = new CountDownLatch(1); // 使用 CountDownLatch 模拟主线程启动

        // 定义一个计数器,用于统计任务执行的次数
        final int[] counter = {0};

        // 定义一个任务,每次执行时计数器加 1
        Runnable task = () -> {
            try {
                Thread.sleep(10); // 模拟任务执行时间
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                synchronized (counter) {
                    counter[0]++;
                }
            }
            latch.countDown();
        };

        // 启动任务
        taskExecutor.executeTasks(taskCount, task);
        latch.await();
        // 验证任务执行的次数是否正确
        assertEquals(taskCount, counter[0]);

        taskExecutor.shutdown();
    }
}

在这个例子中:

  1. 我们创建了一个 TaskExecutor 类,它使用线程池来执行任务。
  2. testExecuteTasks 方法中,我们创建了一个 CountDownLatch 对象,用于同步测试线程。
  3. 我们定义了一个任务,每次执行时计数器加 1。
  4. 我们使用 taskExecutor.executeTasks() 方法来启动任务,并在任务完成后调用 latch.countDown() 方法。
  5. 主线程调用 latch.await() 方法,等待所有任务执行完毕。
  6. 最后,我们使用 assertEquals 方法来验证任务执行的次数是否正确。

通过这种方式,我们可以确保测试线程的执行顺序和同步,从而提高单元测试的准确性和可靠性。

3.3. 并发任务的并行执行

CountDownLatch 还可以用来实现并发任务的并行执行。你可以将一个大任务拆分成多个小任务,然后使用多个线程并行执行这些小任务,最后使用 CountDownLatch 来等待所有小任务完成。

3.3.1. 场景分析

假设我们需要处理一个大型的文本文件,将其中的每一行数据进行处理。为了提高处理速度,我们可以将文件分成多个部分,然后使用多个线程并行处理这些部分。

3.3.2. 代码实现

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ParallelProcessing {

    public static void main(String[] args) throws IOException, InterruptedException {
        String filePath = "large_text_file.txt"; // 假设有一个大型文本文件
        int threadCount = 4; // 定义线程数量
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        CountDownLatch latch = new CountDownLatch(threadCount);
        List<String> processedLines = new ArrayList<>();

        // 1. 将文件分成多个部分(这里简化为按行读取)
        List<String> lines = readFile(filePath);
        int linesPerThread = (int) Math.ceil((double) lines.size() / threadCount);

        // 2. 创建并提交任务
        for (int i = 0; i < threadCount; i++) {
            final int startIndex = i * linesPerThread;
            final int endIndex = Math.min(startIndex + linesPerThread, lines.size());

            executor.submit(() -> {
                try {
                    System.out.println(Thread.currentThread().getName() + "处理从" + startIndex + "到" + endIndex + "的行.");
                    for (int j = startIndex; j < endIndex; j++) {
                        String line = lines.get(j);
                        // 模拟处理过程
                        String processedLine = processLine(line);
                        synchronized (processedLines) {
                            processedLines.add(processedLine);
                        }
                    }
                } finally {
                    latch.countDown(); // 任务完成,计数器减 1
                }
            });
        }

        // 3. 等待所有任务完成
        latch.await();
        System.out.println("所有线程处理完成.");

        // 4. 汇总结果
        System.out.println("处理后的总行数: " + processedLines.size());
        executor.shutdown();
    }

    // 模拟读取文件
    private static List<String> readFile(String filePath) throws IOException {
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } 
        return lines;
    }

    // 模拟处理每一行数据
    private static String processLine(String line) {
        // 模拟耗时操作
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return "Processed: " + line;
    }
}

在这个例子中:

  1. 我们首先读取了大型文本文件的所有行。
  2. 然后,我们将文件分成多个部分,每个部分由一个线程处理。
  3. 我们创建了一个 CountDownLatch 对象,计数器的初始值为线程数量。
  4. 每个线程处理完自己的部分后,调用 latch.countDown() 方法。
  5. 主线程调用 latch.await() 方法,等待所有线程完成任务。
  6. 最后,我们汇总所有线程的处理结果。

通过这种方式,我们可以充分利用多核 CPU 的优势,提高数据处理的速度。

4. CountDownLatch 的高级用法与注意事项

除了基本用法和常见场景,CountDownLatch 还有一些高级用法和需要注意的事项,可以帮助你更好地使用它。

4.1. await() 方法的超时控制

await() 方法有多个重载版本,其中一个版本可以设置超时时间。这在某些场景下非常有用,可以避免线程无限期地等待下去,导致程序 hang 住。

4.1.1. 场景分析

在系统启动时,如果某个服务的启动时间过长,导致主线程一直等待,可能会影响系统的可用性。通过设置 await() 方法的超时时间,可以在服务启动超时时,采取一些补救措施,例如打印错误日志、启动备用服务等。

4.1.2. 代码实现

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class CountDownLatchTimeout {

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);

        new Thread(() -> {
            try {
                Thread.sleep(5000); // 模拟耗时操作
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                latch.countDown();
            }
        }).start();

        System.out.println("等待任务完成...");
        boolean success = latch.await(3, TimeUnit.SECONDS); // 设置超时时间为 3 秒
        if (success) {
            System.out.println("任务完成.");
        } else {
            System.out.println("任务超时.");
            // 可以采取一些补救措施,例如打印错误日志、启动备用服务等
        }
    }
}

在这个例子中,我们设置了 await() 方法的超时时间为 3 秒。如果任务在 3 秒内完成,则打印“任务完成”;如果任务超时,则打印“任务超时”,并可以执行相应的补救措施。

4.2. CountDownLatch 的重用问题

CountDownLatch 是一次性的,一旦计数器变为 0,就不能再次重置。如果需要多次使用,可以考虑使用其他同步工具,例如 CyclicBarrier

4.2.1. 场景分析

假设我们需要模拟一个游戏场景,多个玩家需要等待所有玩家都准备好后,才能开始游戏。每局游戏开始前,都需要重置等待状态。

4.2.2. 解决方案

在这种情况下,CyclicBarrier 更加适合。CyclicBarrier 允许一组线程相互等待,直到所有线程都到达某个屏障点,然后所有线程才能继续执行。它可以在多个线程之间重复使用,并且可以重置计数器。

4.3. 异常处理

在使用 CountDownLatch 时,需要注意异常处理。如果线程在等待过程中被中断,await() 方法会抛出 InterruptedException。在使用线程池时,还需要注意异常的传递,避免线程池中的任务发生异常,导致计数器无法递减,从而造成程序死锁。

4.3.1. 场景分析

在系统启动时,如果某个服务加载过程中发生异常,可能会导致 CountDownLatch 的计数器无法递减,从而导致主线程无限期地等待。

4.3.2. 代码实现

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CountDownLatchException {

    public static void main(String[] args) throws InterruptedException {
        int workerCount = 2;
        CountDownLatch latch = new CountDownLatch(workerCount);
        ExecutorService executor = Executors.newFixedThreadPool(workerCount);

        // 模拟一个线程发生异常
        executor.submit(() -> {
            try {
                System.out.println("线程 1 开始工作...");
                Thread.sleep(1000); // 模拟工作时间
                throw new RuntimeException("模拟异常"); // 模拟异常
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                System.err.println("线程 1 发生异常: " + e.getMessage());
            } finally {
                latch.countDown();
            }
        });

        // 模拟一个正常线程
        executor.submit(() -> {
            try {
                System.out.println("线程 2 开始工作...");
                Thread.sleep(2000); // 模拟工作时间
                System.out.println("线程 2 完成工作.");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                latch.countDown();
            }
        });

        latch.await(); // 主线程等待
        System.out.println("所有线程完成工作.");
        executor.shutdown();
    }
}

在这个例子中,线程 1 模拟了异常的发生。为了避免异常导致 CountDownLatch 的计数器无法递减,我们在 finally 块中调用 latch.countDown() 方法,确保计数器能够正确递减。同时,我们在 catch 块中捕获异常,并进行处理,例如打印错误日志。

4.4. 线程安全问题

虽然 CountDownLatch 本身是线程安全的,但在使用时,需要注意对共享资源的访问。如果多个线程需要访问共享资源,需要使用同步机制来保证线程安全,例如 synchronized 关键字、ReentrantLock 等。

4.4.1. 场景分析

在并发任务的并行执行场景中,如果多个线程需要同时修改一个共享的列表,则需要保证线程安全。

4.4.2. 解决方案

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CountDownLatchThreadSafe {

    public static void main(String[] args) throws InterruptedException {
        int threadCount = 3;
        CountDownLatch latch = new CountDownLatch(threadCount);
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        List<String> sharedList = new ArrayList<>();

        // 模拟多个线程向共享列表添加元素
        for (int i = 0; i < threadCount; i++) {
            final int threadId = i;
            executor.submit(() -> {
                try {
                    System.out.println("线程 " + threadId + " 开始工作...");
                    for (int j = 0; j < 5; j++) {
                        // 使用 synchronized 保证线程安全
                        synchronized (sharedList) {
                            sharedList.add("线程 " + threadId + " 添加元素 " + j);
                        }
                        Thread.sleep(10); // 模拟操作时间
                    }
                    System.out.println("线程 " + threadId + " 完成工作.");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown();
                }
            });
        }

        latch.await();
        System.out.println("所有线程完成工作.");
        System.out.println("共享列表内容: " + sharedList); // 打印共享列表内容
        executor.shutdown();
    }
}

在这个例子中,我们使用 synchronized 关键字来保证对 sharedList 的访问是线程安全的。当一个线程访问 sharedList 时,其他线程需要等待,直到该线程释放锁。

5. 总结

CountDownLatch 是一个非常实用的并发工具,它可以帮助我们解决很多实际问题,例如系统启动时的服务加载、单元测试中的线程同步、并发任务的并行执行等。通过学习它的基本用法、应用场景和高级用法,我们可以更好地掌握它,并将其应用到实际项目中。

希望这次分享能够帮助你更深入地理解 CountDownLatch。在实际工作中,要根据具体的场景选择合适的同步工具,并注意异常处理和线程安全问题。 祝你编程愉快!

点评评价

captcha
健康