您的位置:首页 > 编程语言 > Java开发

CountDownLatch原理分析

2017-08-13 00:00 471 查看
闭锁是一种同步工具类,可以延迟线程进度直到其到达最终终止状态。闭锁关闭之后,没有任何线程可以通过,当到达结束状态时,会永远打开,并允许所有线程通过。
闭锁状态包括一个计数器,该计数器初始化为一个正数,表示需要等待的事件数量。countDown 方法递减计数器,表示一个事件已经发生,await 方法等待计数器达到零,表示所有需要等待的事件都已经发生。CountDownLatch 内部也是通过 AQS 来实现的。

CountDownLatch 主要实现方法:

CountDownLatch 构造时传入计数器数目,CountDownLatch 将计数器设置到 AQS 的 state 上;

await 等待,如果计数器为 0, 则成功通过;否则被阻塞,并将线程记录到 AQS 的等待锁队列中,等待唤醒重新尝试;

countDown 计数器减1,通过 CAS 将计数器减1,如果计数器为0,则唤醒等待队列上所有线程,因为 CountDownLatch 此时已经打开,所以所有的线程都可以通过。

CountDownLatch 结构:

public class CountDownLatch {
private final CountDownLatch.Sync sync;

public CountDownLatch(int var1) {
if(var1 < 0) {
throw new IllegalArgumentException("count < 0");
} else {
this.sync = new CountDownLatch.Sync(var1);
}
}

public void await() throws InterruptedException {
this.sync.acquireSharedInterruptibly(1);
}

public boolean await(long var1, TimeUnit var3) throws InterruptedException {
return this.sync.tryAcquireSharedNanos(1, var3.toNanos(var1));
}

public void countDown() {
this.sync.releaseShared(1);
}

public long getCount() {
return (long)this.sync.getCount();
}
}

从 CountDownLatch 中我们可以看到 CountDownLatch 的锁是通过 Sync 这个类完成的,Sync 则继承自 AQS,AQS 是独占锁和共享锁的父类,通过继承 AQS 实现共享锁。

Sync 结构:

private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int var1) {
this.setState(var1);
}

int getCount() {
return this.getState();
}

// 判断是否可用
protected int tryAcquireShared(int var1) {
return this.getState() == 0?1:-1;
}

// 计数器减1,并返回是否释放成功
protected boolean tryReleaseShared(int var1) {
int var2;
int var3;
do {
var2 = this.getState();
if(var2 == 0) {
return false;
}

var3 = var2 - 1;
} while(!this.compareAndSetState(var2, var3));

return var3 == 0;
}
}

Sync 实现了 tryAcquireShared 方法,调用 await 时判断锁是否可用,如果可用,就直接通过;如果不可用,则将线程记录在等待队列上。
Sync 实现了 tryReleaseShared 方法,调用 countDown 方法时,会将 state 减1,如果成功释放则执行 doReleaseShared 方法唤醒队首线程,队首线程唤醒后依次唤醒后续队首的线程,从而做到唤醒所有线程。

AQS 结构:

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements Serializable {
private static final long serialVersionUID = 7373984972572414691L;
private transient volatile AbstractQueuedSynchronizer.Node head;
private transient volatile AbstractQueuedSynchronizer.Node tail;
private volatile int state;
static final long spinForTimeoutThreshold = 1000L;
private static final Unsafe unsafe = Unsafe.getUnsafe();
private static final long stateOffset;
private static final long headOffset;
private static final long tailOffset;
private static final long waitStatusOffset;
private static final long nextOffset;

protected AbstractQueuedSynchronizer() {
}

public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}

// 获取失败会被阻塞
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
// 设置等待头部队列,并传播唤醒下个头部线程
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

// 唤醒等待队列队首线程
private void doReleaseShared() {
/*
* Ensure that a release propagates, even if there are other
* in-progress acquires/releases.  This proceeds in the usual
* way of trying to unparkSuccessor of head if it needs
* signal. But if it does not, status is set to PROPAGATE to
* ensure that upon release, propagation continues.
* Additionally, we must loop in case a new node is added
* while we are doing this. Also, unlike other uses of
* unparkSuccessor, we need to know if CAS to reset status
* fails, if so rechecking.
*/
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue;            // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue;                // loop on failed CAS
}
if (h == head)                   // loop if head changed
break;
}
}

private void doReleaseShared() {
while(true) {
AbstractQueuedSynchronizer.Node var1 = this.head;
if(var1 != null && var1 != this.tail) {
int var2 = var1.waitStatus;
if(var2 == -1) {
if(!compareAndSetWaitStatus(var1, -1, 0)) {
continue;
}

this.unparkSuccessor(var1);
} else if(var2 == 0 && !compareAndSetWaitStatus(var1, 0, -3)) {
continue;
}
}

if(var1 == this.head) {
return;
}
}
}
}

CountDownLatch 调用 await 方法时,调用 AQS 的 acquireSharedInterruptibly 方法,AQS 则调用子类的 tryAcquireShared 方法,如果获取成功,则直接返回;如果获取失败,则调用调用 AQS 的方法 doAcquireSharedInterruptibly 阻塞并加入到等待队列等待唤醒。
CountDownLatch 调用 countDown 方法时,调用 AQS 的 releaseShared 方法,AQS 则调用子类的 tryReleaseShared 方法,如果释放成功,则调用 doReleaseShared 方法唤醒队列首部线程,线程启动后,如果 tryAcquireShared 返回值大于等于 0,则通过 setHeadAndPropagate 方法进行传播,唤醒下一个线程。

CountDownLatch 使用:

public class Main {
public static void main(String[] args) throws InterruptedException {
final int threadsNumber = 4;
final CountDownLatch startGate = new CountDownLatch(1);
final CountDownLatch endGate = new CountDownLatch(threadsNumber);

for (int i = 0; i < threadsNumber; i++) {
Thread thread = new Thread(new Runnable() {
@Override
public void run() {
try {
startGate.await();
endGate.countDown();
System.out.println(this + "end");
} catch (Exception e) {
e.printStackTrace();
}
}
});
thread.start();
}

startGate.countDown();
endGate.await();
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  Java