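Java Concurrency: How CountDownLatch Works and How to Use It

As the name suggests, CountDownLatch is a latch that opens after a countdown. Like ReentrantLock, it is built on AQS (AbstractQueuedSynchronizer): the count is decremented atomically with CAS on the AQS state, and the AQS wait queue (a doubly linked FIFO queue) handles blocking and waking threads.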

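Introduction

A CountDownLatch is created with a count, the number of countdown steps. Threads that need to wait call its await() method (not to be confused with Object.wait()). The countdown method is countDown(): each call decrements the count by 1, and once the count reaches 0, every thread blocked in await() returns.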
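
As a minimal sketch of the count lifecycle described above: the class name LatchLifecycleSketch and the helper remainingAfter are made up for illustration, and only CountDownLatch, countDown() and getCount() are JDK API.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical illustration class; not part of the original demo.
final class LatchLifecycleSketch {

    // Creates a latch with `initialCount`, calls countDown() `calls` times,
    // and returns the count that is left.
    static long remainingAfter(int initialCount, int calls) {
        CountDownLatch latch = new CountDownLatch(initialCount);
        for (int i = 0; i < calls; i++) {
            latch.countDown(); // has no effect once the count is already 0
        }
        return latch.getCount();
    }

    public static void main(String[] args) {
        System.out.println(remainingAfter(3, 1)); // 2
        System.out.println(remainingAfter(3, 3)); // 0
        System.out.println(remainingAfter(3, 5)); // 0, the count never goes negative
    }
}
```

A latch cannot be reset: once the count has reached 0 it stays open.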

Demo

As in the ReentrantLock article, the example is a multithreaded summation:

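Thread 1 sums the integers from 10 to 100
Thread 2 sums the integers from 101 to 200
Thread 3 adds the results of thread 1 and thread 2

Implementation with CountDownLatch: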

import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {
    private CountDownLatch countDownLatch;

    private int start = 10;
    private int mid = 100;
    private int end = 200;

    // Written by threads 1 and 2, read by thread 3 after await() returns
    private volatile int tmpRes1, tmpRes2;

    // Sum of the integers in [start, end]
    private int add(int start, int end) {
        int sum = 0;
        for (int i = start; i <= end; i++) {
            sum += i;
        }
        return sum;
    }


    private int sum(int a, int b) {
        return a + b;
    }

    public void calculate() {
        // Two worker threads have to finish before thread 3 may continue
        countDownLatch = new CountDownLatch(2);

        Thread thread1 = new Thread(() -> {
            try {
                // Sleep so that thread 3 runs before threads 1 and 2 and
                // blocks in await() because the count is not yet 0
                Thread.sleep(100);
                System.out.println(Thread.currentThread().getName() + " : started");
                tmpRes1 = add(start, mid);
                System.out.println(Thread.currentThread().getName() +
                        " : calculate ans: " + tmpRes1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                // Count down in finally so thread 3 is released even on failure
                countDownLatch.countDown();
            }
        }, "Thread 1");

        Thread thread2 = new Thread(() -> {
            try {
                // Same delay as thread 1, for the same reason
                Thread.sleep(100);
                System.out.println(Thread.currentThread().getName() + " : started");
                tmpRes2 = add(mid + 1, end);
                System.out.println(Thread.currentThread().getName() +
                        " : calculate ans: " + tmpRes2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
        }, "Thread 2");


        Thread thread3 = new Thread(() -> {
            try {
                System.out.println(Thread.currentThread().getName() + " : started");
                // Blocks until threads 1 and 2 have both called countDown()
                countDownLatch.await();
                int ans = sum(tmpRes1, tmpRes2);
                System.out.println(Thread.currentThread().getName() +
                        " : calculate ans: " + ans);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }, "Thread 3");

        thread3.start();
        thread1.start();
        thread2.start();
    }


    public static void main(String[] args) throws InterruptedException {
        CountDownLatchDemo demo = new CountDownLatchDemo();
        demo.calculate();

        Thread.sleep(1000);
    }
}

Output (the relative order of the thread 1 and thread 2 lines can vary between runs)

Thread 3 : started
Thread 1 : started
Thread 2 : started
Thread 1 : calculate ans: 5005
Thread 2 : calculate ans: 15050
Thread 3 : calculate ans: 20055

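The flow above:

  • First create the instance: countDownLatch = new CountDownLatch(2)
  • Each thread being waited on decrements the count when it finishes: countDownLatch.countDown()
  • The thread that must wait for the others to finish before it runs calls countDownLatch.await() to block until then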

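Use cases

There are two typical scenarios: simultaneous start and master-worker cooperation. Both involve two kinds of threads that need to synchronize with each other.

In the simultaneous-start scenario, athlete threads wait for the referee thread to give the start signal; once it is given, all athlete threads start at the same time. The count starts at 1, the athletes call await(), and the main (referee) thread calls countDown().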
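
The simultaneous-start scenario can be sketched roughly as follows. StartingGunSketch, race and the thread names are hypothetical names chosen for this illustration; the calling thread plays the referee.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical illustration of the "simultaneous start" scenario.
final class StartingGunSketch {

    // Starts `runners` athlete threads that wait for one start signal.
    // Returns {athletes running before the signal, athletes that ran after it}.
    static int[] race(int runners) {
        CountDownLatch startSignal = new CountDownLatch(1); // count starts at 1
        AtomicInteger running = new AtomicInteger();
        Thread[] threads = new Thread[runners];
        for (int i = 0; i < runners; i++) {
            threads[i] = new Thread(() -> {
                try {
                    startSignal.await();       // athlete waits for the referee
                    running.incrementAndGet(); // only reachable after the signal
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }, "runner-" + i);
            threads[i].start();
        }
        int before = running.get(); // 0: nobody can get past await() yet
        startSignal.countDown();    // referee gives the start signal
        try {
            for (Thread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new int[] {before, running.get()};
    }

    public static void main(String[] args) {
        int[] result = race(4);
        System.out.println(result[0] + " running before the signal, "
                + result[1] + " after");
    }
}
```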

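In the master-worker scenario, the main thread depends on the workers' results and has to wait for them to finish. The count starts at the number of worker threads; each worker calls countDown() when it finishes, and the main thread calls await() to wait.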
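
A rough sketch of the master-worker scenario, reusing the summation idea from the demo. MasterWorkerSketch and parallelSum are hypothetical names, and the way the range is split into chunks is just one possible choice.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical illustration of the "master-worker" scenario.
final class MasterWorkerSketch {

    // Sums the integers in [from, to] with `workers` threads;
    // the calling (master) thread waits for all of them.
    static long parallelSum(int from, int to, int workers) {
        CountDownLatch done = new CountDownLatch(workers); // one count per worker
        long[] partial = new long[workers];
        int chunk = (to - from + 1 + workers - 1) / workers; // ceiling division
        for (int w = 0; w < workers; w++) {
            final int id = w;
            final int lo = from + w * chunk;
            final int hi = Math.min(to, lo + chunk - 1);
            new Thread(() -> {
                try {
                    long sum = 0;
                    for (int i = lo; i <= hi; i++) {
                        sum += i;
                    }
                    partial[id] = sum;
                } finally {
                    done.countDown(); // count down even if the work fails
                }
            }, "worker-" + w).start();
        }
        try {
            done.await(); // master blocks until every worker has counted down
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        long total = 0;
        for (long p : partial) {
            total += p; // safe to read: countDown() happens-before await() returning
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(parallelSum(10, 200, 2)); // 20055
    }
}
```

Calling countDown() in a finally block keeps the master from waiting forever when a worker throws.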

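How it works

CountDownLatch relies on AQS: the count is decremented atomically with CAS on the AQS state, and the AQS wait queue (a doubly linked FIFO queue) is used to block and wake threads.

AQS

AQS uses a FIFO queue to hold the threads that are waiting for the lock. The head of the queue is called the "sentinel node" or "dummy node" and is not associated with any thread. Every other node is associated with a waiting thread, and each node maintains a wait status, waitStatus.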

private transient volatile Node head;

private transient volatile Node tail;

private volatile int state;

static final class Node {
    static final Node SHARED = new Node();
    static final Node EXCLUSIVE = null;

    /** waitStatus value to indicate thread has cancelled */
    static final int CANCELLED = 1;
    /** waitStatus value to indicate successor's thread needs unparking */
    static final int SIGNAL = -1;
    /** waitStatus value to indicate thread is waiting on condition */
    static final int CONDITION = -2;
    /**
     * waitStatus value to indicate the next acquireShared should
     * unconditionally propagate
     */
    static final int PROPAGATE = -3;

    // 0 (the initial value) or one of CANCELLED, SIGNAL, CONDITION, PROPAGATE
    volatile int waitStatus;

    volatile Node prev;

    volatile Node next;

    // The thread that enqueued this node (null for the head node)
    volatile Thread thread;

    // Link to next node waiting on condition,
    // or the special value SHARED
    Node nextWaiter;
}

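Initializing the counter

CountDownLatch contains an inner class, Sync, that extends AQS and overrides tryAcquireShared() and tryReleaseShared(); what they do is explained below.

As the usage above shows, the constructor requires a count, and that value directly initializes the state variable inside AQS.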

Sync(int count) {
    setState(count);
}

All later decrements and "is the latch open" checks operate on state.
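
To make the role of state concrete, here is a minimal latch built directly on AbstractQueuedSynchronizer. It is a sketch that follows the same structure as the excerpts quoted in this article, not a copy of the JDK class; SimpleLatchSketch and its count() helper are hypothetical names.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

// Hypothetical minimal latch; java.util.concurrent.CountDownLatch is the real one.
final class SimpleLatchSketch {

    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count); // the count lives in the AQS state field
        }

        int count() {
            return getState();
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            return getState() == 0 ? 1 : -1; // acquire succeeds only at 0
        }

        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    return false; // already open
                }
                int next = c - 1;
                if (compareAndSetState(c, next)) {
                    return next == 0; // true only on the transition to 0
                }
            }
        }
    }

    private final Sync sync;

    SimpleLatchSketch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        sync = new Sync(count);
    }

    void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    void countDown() {
        sync.releaseShared(1);
    }

    int getCount() {
        return sync.count();
    }

    public static void main(String[] args) throws InterruptedException {
        SimpleLatchSketch latch = new SimpleLatchSketch(2);
        new Thread(latch::countDown).start();
        new Thread(latch::countDown).start();
        latch.await(); // returns once both threads have counted down
        System.out.println("count = " + latch.getCount()); // count = 0
    }
}
```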

countDown(): decrementing the count

// CountDownLatch: decrement the count by 1
public void countDown() {
    sync.releaseShared(1);
}


// AQS
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) { // true only when this call moves the count to 0
        doReleaseShared();
        return true;
    }
    return false;
}

// CountDownLatch.Sync
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0) // the count is already 0: nothing to release
            return false;
        int nextc = c-1;
        // CAS decrements the count atomically; retry if another thread got in first
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

// AQS: wake the blocked threads
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) { // the queue is non-empty: some thread is blocked
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                // Head is SIGNAL: reset it to 0, then unpark the thread
                // of the head's successor
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue; // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue; // loop on failed CAS
        }
        if (h == head) // head did not change during this pass: done
            break;
    }
}

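The code above is the complete call chain for decrementing the count:

  1. tryReleaseShared attempts the release, i.e. decrements the count
    • If the count is already 0, return false immediately
    • Otherwise decrement the count (the AQS state) by 1 with CAS
    • If state == 0 after the decrement, the release succeeded (the latch is open) and the blocked threads have to be woken
  2. doReleaseShared releases and wakes the blocked threads
    • If the queue is empty, no thread is blocked (no thread has called CountDownLatch#await()), so it exits immediately
    • If the head node's status is SIGNAL, wake the thread of the head's successor; each woken thread that acquires becomes the new head and passes the wake-up on to the next node

Once the count reaches 0, all blocked threads are woken and proceed independently of each other; unlike with an exclusive lock, they do not block one another.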
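
The wake-up of every queued thread can be observed with the public AQS inspection method getQueueLength(). The sketch below uses its own small Sync subclass with the same two overrides, because CountDownLatch does not expose its queue; WakeAllWaitersSketch and run are hypothetical names, and the busy-wait on the queue length is only there to make the demo deterministic.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

// Hypothetical demo: one release on the 1 -> 0 transition wakes every waiter.
final class WakeAllWaitersSketch {

    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);
        }

        @Override
        protected int tryAcquireShared(int acquires) {
            return getState() == 0 ? 1 : -1;
        }

        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    return false;
                }
                if (compareAndSetState(c, c - 1)) {
                    return c == 1;
                }
            }
        }
    }

    // Returns {queued before the release, threads released, queued afterwards}.
    static int[] run(int waiters) {
        Sync sync = new Sync(1);
        AtomicInteger passed = new AtomicInteger();
        Thread[] threads = new Thread[waiters];
        for (int i = 0; i < waiters; i++) {
            threads[i] = new Thread(() -> {
                try {
                    sync.acquireSharedInterruptibly(1); // what await() delegates to
                    passed.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads[i].start();
        }
        // Wait until AQS has enqueued every waiter.
        while (sync.getQueueLength() < waiters) {
            Thread.yield();
        }
        int queuedBefore = sync.getQueueLength();
        sync.releaseShared(1); // what countDown() delegates to
        try {
            for (Thread t : threads) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new int[] {queuedBefore, passed.get(), sync.getQueueLength()};
    }

    public static void main(String[] args) {
        int[] r = run(3);
        System.out.println(r[0] + " queued, " + r[1] + " released, "
                + r[2] + " still queued");
    }
}
```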

await(): blocking until the count reaches 0

// CountDownLatch
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}


// AQS
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted()) // the thread is already interrupted: throw immediately
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0) // the count is not 0 yet: enqueue and block
        doAcquireSharedInterruptibly(arg);
}


// CountDownLatch.Sync: the acquire succeeds only when the count is 0
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

// AQS: enqueue and block
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
    final Node node = addWaiter(Node.SHARED); // enqueue in shared mode
    boolean failed = true;
    try {
        for (;;) {
            // Get the predecessor node
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    // Acquire succeeded: make node the new head and
                    // propagate the wake-up to its successor
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) // park the thread
                    && parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

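The blocking logic:

  • Check whether the state count is 0; if it is, await() returns immediately and the code after it runs
  • If it is greater than 0, the thread has to block until the count reaches 0
  • The current thread is wrapped in a Node and enqueued in the wait queue
  • It then loops, trying to acquire and parking between attempts, until it succeeds (state is 0); it then leaves the queue and continues with the code after await()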
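
Two of the outcomes of await() described above, sketched with the real CountDownLatch: returning immediately when the count is already 0, and leaving with InterruptedException when the waiting thread is interrupted. AwaitOutcomesSketch and its method names are hypothetical.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical illustration of how await() can return.
final class AwaitOutcomesSketch {

    // await() on a latch whose count is already 0 does not block.
    static boolean returnsImmediatelyWhenOpen() {
        CountDownLatch latch = new CountDownLatch(1);
        latch.countDown(); // the count is now 0
        try {
            latch.await(); // tryAcquireShared succeeds, no queueing
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    // A thread waiting on a latch that never opens leaves await()
    // with InterruptedException once it is interrupted.
    static boolean interruptedWhileWaiting() {
        CountDownLatch neverOpened = new CountDownLatch(1);
        boolean[] sawInterrupt = {false};
        Thread waiter = new Thread(() -> {
            try {
                neverOpened.await();
            } catch (InterruptedException e) {
                sawInterrupt[0] = true;
            }
        }, "waiter");
        waiter.start();
        // Works whether the waiter is already parked or has not reached
        // await() yet: await() also checks the interrupt flag on entry.
        waiter.interrupt();
        try {
            waiter.join(); // join makes the write to sawInterrupt visible
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return sawInterrupt[0];
    }

    public static void main(String[] args) {
        System.out.println(returnsImmediatelyWhenOpen()); // true
        System.out.println(interruptedWhileWaiting());    // true
    }
}
```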