您的位置:首页 > 其它

ZooKeeper 实现分布式锁

2015-05-13 20:55 363 查看
package com.concurrent;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.data.Stat;

public class DistributedLock implements Lock, Watcher {
private ZooKeeper zk;
private String root = "/locks";// 根
private String lockName;// 竞争资源的标志
private String waitNode;// 等待前一个锁
private String myZnode;// 当前锁
private CountDownLatch latch;// 计数器
private int sessionTimeout = 30000;
private List<Exception> exception = new ArrayList<Exception>();

/**
* 创建分布式锁,使用前请确认config配置的zookeeper服务可用
*
* @param config
*            127.0.0.1:2181
* @param lockName
*            竞争资源标志,lockName中不能包含单词lock
*/
public DistributedLock(String config, String lockName) {
this.lockName = lockName;
// 创建一个与服务器的连接
try {
zk = new ZooKeeper(config, sessionTimeout, this);
Stat stat = zk.exists(root, false);
if (stat == null) {
// 创建根节点
zk.create(root, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
}
} catch (IOException e) {
exception.add(e);
} catch (KeeperException e) {
exception.add(e);
} catch (InterruptedException e) {
exception.add(e);
}
}

/**
* zookeeper节点的监视器
*/
public void process(WatchedEvent event) {
if (this.latch != null) {
this.latch.countDown();
}
}

public void lock() {
if (exception.size() > 0) {
throw new LockException(exception.get(0));
}
try {
if (this.tryLock()) {
System.out.println("Thread " + Thread.currentThread().getId() + " " + myZnode + " get lock true");
return;
} else {
waitForLock(waitNode, sessionTimeout);// 等待锁
}
} catch (KeeperException e) {
throw new LockException(e);
} catch (InterruptedException e) {
throw new LockException(e);
}
}

public boolean tryLock() {
try {
String splitStr = "_lock_";
if (lockName.contains(splitStr))
throw new LockException("lockName can not contains \\u000B");
// 创建临时子节点
myZnode = zk.create(root + "/" + lockName + splitStr, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
System.out.println(myZnode + " is created ");
// 取出所有子节点
List<String> subNodes = zk.getChildren(root, false);
// 取出所有lockName的锁
List<String> lockObjNodes = new ArrayList<String>();
for (String node : subNodes) {
String _node = node.split(splitStr)[0];
if (_node.equals(lockName)) {
lockObjNodes.add(node);
}
}
Collections.sort(lockObjNodes);
System.out.println(myZnode + "==" + lockObjNodes.get(0));
if (myZnode.equals(root + "/" + lockObjNodes.get(0))) {
// 如果是最小的节点,则表示取得锁
return true;
}
// 如果不是最小的节点,找到比自己小1的节点
String subMyZnode = myZnode.substring(myZnode.lastIndexOf("/") + 1);
waitNode = lockObjNodes.get(Collections.binarySearch(lockObjNodes, subMyZnode) - 1);
} catch (KeeperException e) {
throw new LockException(e);
} catch (InterruptedException e) {
throw new LockException(e);
}
return false;
}

public boolean tryLock(long time, TimeUnit unit) {
try {
if (this.tryLock()) {
return true;
}
return waitForLock(waitNode, time);
} catch (Exception e) {
e.printStackTrace();
}
return false;
}

private boolean waitForLock(String lower, long waitTime) throws InterruptedException, KeeperException {
Stat stat = zk.exists(root + "/" + lower, true);
// 判断比自己小一个数的节点是否存在,如果不存在则无需等待锁,同时注册监听
if (stat != null) {
System.out.println("Thread " + Thread.currentThread().getId() + " waiting for " + root + "/" + lower);
this.latch = new CountDownLatch(1);
this.latch.await(waitTime, TimeUnit.MILLISECONDS);
this.latch = null;
}
return true;
}

public void unlock() {
try {
System.out.println("unlock " + myZnode);
zk.delete(myZnode, -1);
myZnode = null;
zk.close();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (KeeperException e) {
e.printStackTrace();
}
}

public void lockInterruptibly() throws InterruptedException {
this.lock();
}

public Condition newCondition() {
return null;
}

public class LockException extends RuntimeException {
private static final long serialVersionUID = 1L;

public LockException(String e) {
super(e);
}

public LockException(Exception e) {
super(e);
}
}
}


package com.concurrent;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Minimal concurrency harness. It starts one thread per task and releases them all at once through a
 * start gate. It then waits for every thread to finish and prints the min/max/average duration (ms)
 * of the successful tasks plus the number of tasks that threw.
 *
 * <p>All work happens inside the constructor, which blocks until every task has finished.
 */
public class ConcurrentTest {
    private final CountDownLatch startSignal = new CountDownLatch(1);// start gate
    private final CountDownLatch doneSignal;// finish gate, one count per task
    private final CopyOnWriteArrayList<Long> list = new CopyOnWriteArrayList<Long>();// durations of successful tasks
    private final AtomicInteger err = new AtomicInteger();// number of tasks that threw an Exception
    private final ConcurrentTask[] task;

    /**
     * Runs all given tasks concurrently and prints their timing statistics.
     *
     * @param task the tasks to run, one thread each; may be empty
     * @throws IllegalArgumentException if {@code task} is {@code null}
     */
    public ConcurrentTest(ConcurrentTask... task) {
        if (task == null) {
            // Report the misuse to the caller instead of terminating the whole JVM.
            throw new IllegalArgumentException("task can not null");
        }
        this.task = task;
        doneSignal = new CountDownLatch(task.length);
        start();
    }

    /**
     * Creates all worker threads, opens the start gate, waits for all workers and prints statistics.
     */
    private void start() {
        createThread();
        // Opening the gate releases every worker blocked in startSignal.await().
        startSignal.countDown();
        try {
            doneSignal.await();
        } catch (InterruptedException e) {
            // Restore the flag; the statistics below then cover only the tasks finished so far.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        getExeTime();
    }

    /**
     * Starts one thread per task. Each thread waits at the start gate, times its task and always
     * counts down the finish gate.
     */
    private void createThread() {
        for (final ConcurrentTask t : task) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        startSignal.await();
                        long begin = System.currentTimeMillis();
                        t.run();
                        list.add(System.currentTimeMillis() - begin);
                    } catch (Exception e) {
                        // A failing task is counted, not propagated: the harness measures all tasks.
                        err.getAndIncrement();
                    } finally {
                        // Release the finish gate even if the task threw an Error; otherwise start() hangs.
                        doneSignal.countDown();
                    }
                }
            }).start();
        }
    }

    /**
     * Prints min/max/average duration (ms) of the successful tasks and the failure count.
     */
    private void getExeTime() {
        // Snapshot once so that size, min, max and sum all describe the same samples.
        List<Long> samples = new ArrayList<Long>(list);
        if (samples.isEmpty()) {
            System.out.println("no successful task to measure");
            System.out.println("err: " + err.get());
            return;
        }
        Collections.sort(samples);
        long sum = 0L;
        for (long t : samples) {
            sum += t;
        }
        System.out.println("min: " + samples.get(0));
        System.out.println("max: " + samples.get(samples.size() - 1));
        System.out.println("avg: " + sum / samples.size());
        System.out.println("err: " + err.get());
    }

    /** A unit of work executed concurrently by {@link ConcurrentTest}. */
    public interface ConcurrentTask {
        void run();
    }
}


package com.concurrent;

import com.concurrent.ConcurrentTest.ConcurrentTask;

/**
 * Manual demo of {@link DistributedLock}; needs a ZooKeeper server at {@code 127.0.0.1:2181}.
 *
 * <p>One thread holds lock {@code test1} for three seconds. Then six concurrent tasks each acquire and
 * release lock {@code test2} through {@link ConcurrentTest}.
 */
public class ZkTest {
    private static final String ZK_ADDRESS = "127.0.0.1:2181";

    public static void main(String[] args) {
        Runnable task1 = new Runnable() {
            public void run() {
                runLocked("test1", 3000L, "===");
            }
        };
        new Thread(task1).start();
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e1) {
            Thread.currentThread().interrupt();
            e1.printStackTrace();
        }
        // The task is stateless, so one instance can fill every slot.
        ConcurrentTask task3 = new ConcurrentTask() {
            public void run() {
                runLocked("test2", 0L, "");
            }
        };
        ConcurrentTask[] tasks = new ConcurrentTask[6];
        for (int i = 0; i < tasks.length; i++) {
            tasks[i] = task3;
        }
        new ConcurrentTest(tasks);
    }

    /**
     * Acquires the named lock, optionally holds it, prints a "running" line and always releases it.
     *
     * @param lockName   resource name passed to {@link DistributedLock}
     * @param holdMillis how long to hold the lock before printing; 0 for no delay
     * @param prefix     text printed in front of the "Thread ... running" line
     */
    private static void runLocked(String lockName, long holdMillis, String prefix) {
        DistributedLock lock = null;
        try {
            lock = new DistributedLock(ZK_ADDRESS, lockName);
            lock.lock();
            if (holdMillis > 0) {
                Thread.sleep(holdMillis);
            }
            System.out.println(prefix + "Thread " + Thread.currentThread().getId() + " running");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // The constructor may throw an unchecked exception, leaving lock null.
            if (lock != null) {
                lock.unlock();
            }
        }
    }
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: