您的位置:首页 > 理论基础 > 数据结构算法

基于稀疏矩阵数据结构的相关人物搜索

2016-11-25 13:00 330 查看
实体的词向量是利用word2vec训练得到的,每个实体的词向量长度为100,训练得到的结果见附件vertexName-vec.txt

最初采用 HashMap 存储所有词对的相似度,但出现了 OOM(内存溢出)等问题;于是改用稀疏矩阵(十字链表)只保存相似度大于 0.6 的词对,内存占用和查询速度都明显改善。

稀疏矩阵构建及查询:

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.PropertyConfigurator;

import com.ifeng.commen.Utils.FileUtil;
import com.ifeng.matrix.OLNode;
import com.ifeng.matrix.SparseMatrixInfo;
import com.ifeng.matrix.TripleNode;
import com.ifeng.wuyg.Utils.Compare;

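/**
 * Purpose:
 *     Finds the words related to a given word. Similarities are kept in a
 *     sparse matrix whose storage structure is an orthogonal (cross-linked) list.
 */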
public class RelatedWordsWithSparseMatrix implements Serializable {

// Logger for this class. It was previously bound to InnerProductWithSparseMatrix,
// a class that is not declared in this file, so it now names the enclosing class.
static Log LOG = LogFactory.getLog(RelatedWordsWithSparseMatrix.class);

// Shared file helper; main uses it to read the word-vector file.
static FileUtil fileUtil = new FileUtil();

/**
 * One parsed row of the word-vector file: the word itself plus its numeric
 * vector components.
 *
 * Kept as a (non-static) inner class because main instantiates it through an
 * outer instance ({@code outer.new Matrix()}).
 */
class Matrix {
    /** The word (entity name) this vector belongs to. */
    String wordDic;
    /** Vector components in the order they appeared on the input row. */
    ArrayList<Double> vectors = new ArrayList<Double>();

    /** @return the word this vector belongs to */
    public String getWordDic() {
        return wordDic;
    }

    /** @param wordDic the word this vector belongs to */
    public void setWordDic(String wordDic) {
        this.wordDic = wordDic;
    }

    /** @return the vector components (the live list, not a copy) */
    public ArrayList<Double> getVectors() {
        return vectors;
    }

    /** @param vectors the vector components; stored by reference */
    public void setVectors(ArrayList<Double> vectors) {
        this.vectors = vectors;
    }
}

/**
 * Serializable (word, score) pair exposed through plain accessors.
 * It is not referenced by any other code visible in this file.
 */
class PointerInfo implements Serializable {
    /** The word this entry refers to. */
    String word;
    /** The score attached to {@link #word}. */
    double score;

    /** @return the stored word */
    public String getWord() { return this.word; }

    /** @return the stored score */
    public double getScore() { return this.score; }

    /** @param word the word to store */
    public void setWord(String word) { this.word = word; }

    /** @param score the score to store */
    public void setScore(double score) { this.score = score; }
}

/**
 * Interactive entry point.
 *
 * Loads the word-vector file, assigns every word a dense 0-based index,
 * builds the sparse similarity matrix and then runs two console loops:
 * first "related words of one word" (limited to the requested TopN),
 * then "similarity of a word pair". Entering {@code -1} leaves a loop.
 *
 * @param args unused
 * @throws IOException            declared by the matrix builder
 * @throws ClassNotFoundException declared for compatibility; not thrown by the visible code
 */
public static void main(String[] args) throws IOException,
        ClassNotFoundException {

    String rootpath = "./";

    PropertyConfigurator.configure(rootpath + "conf/log4j.properties");

    String filepath = rootpath + "database/vertexName-vec.txt";

    String content = fileUtil.Read(filepath, "utf-8");

    // index -> parsed row, plus both directions of the word <-> index dictionary
    HashMap<Integer, Matrix> wordMap = new HashMap<Integer, Matrix>();
    HashMap<Integer, String> dic2WordMap = new HashMap<Integer, String>();
    HashMap<String, Integer> word2dicMap = new HashMap<String, Integer>();
    int index = 0;

    for (String rawRow : content.split("\n")) {
        // Drop a trailing '\r' (CRLF files) and skip blank lines so they do
        // not turn into an empty "word" with an empty vector.
        String row = rawRow.trim();
        if (row.isEmpty()) {
            continue;
        }

        ArrayList<String> map_l = new ArrayList<String>(Arrays.asList(row.split(" ")));
        String word_l = map_l.get(0);

        // Matrix is an inner class of this class, so it needs an outer instance.
        Matrix matrix = new RelatedWordsWithSparseMatrix().new Matrix();

        word2dicMap.put(word_l, index);
        dic2WordMap.put(index, word_l);

        matrix.setWordDic(word_l);
        matrix.setVectors(str2Double(map_l));
        wordMap.put(index, matrix);

        index++;
    }

    SparseMatrixInfo sparseMatrixInfo = computeWithWordInSparseMatrix(wordMap, wordMap);

    // One Scanner for the whole session: creating a new Scanner on System.in
    // for every prompt can lose input that an earlier Scanner already buffered.
    Scanner scanner = new Scanner(System.in);

    while (true) {
        System.err.println("请输入检索节点:");

        String query = scanner.nextLine();

        if ("-1".equals(query)) {
            break;
        }

        if (!word2dicMap.containsKey(query)) {
            System.err.println("不包含该词语!");
            continue;
        }

        int Rindex = word2dicMap.get(query);

        OLNode Rnode = sparseMatrixInfo.getRhead()[Rindex];

        System.err.println("想要得到的TopN:");
        String topS = scanner.nextLine();

        int topN;
        try {
            topN = Integer.parseInt(topS.trim());
        } catch (NumberFormatException e) {
            System.err.println("TopN必须是整数:" + topS);
            continue;
        }

        HashMap<String, Double> relatedMap = new HashMap<String, Double>();

        // Walk the row chain; the head node carries no data, hence the null check.
        while (null != Rnode) {
            TripleNode tripleNode = Rnode.getData();
            if (null != tripleNode) {
                relatedMap.put(dic2WordMap.get(tripleNode.getColIndex()), (Double) tripleNode.getValue());
            }
            Rnode = Rnode.getRight();
        }
        ArrayList<String> keyList = new ArrayList<String>(relatedMap.keySet());

        // NOTE(review): assumes Compare orders keys by descending score, as the
        // sample output in the article suggests - confirm against Compare.
        Collections.sort(keyList, new Compare(relatedMap));

        // Actually honour the requested TopN (it used to be read and ignored).
        int limit = Math.min(topN, keyList.size());
        for (int i = 0; i < limit; i++) {
            String key = keyList.get(i);
            System.err.println(key + "#DIV#" + relatedMap.get(key) + "\t");
        }

    }

    // Look up the similarity stored for a pair of words A and B.
    while (true) {

        System.err.println("输入A词语:");
        String query_A = scanner.nextLine();
        if ("-1".equals(query_A)) {
            break;
        }

        System.err.println("输入B词语:");
        String query_B = scanner.nextLine();
        if ("-1".equals(query_B)) {
            break;
        }

        if (!word2dicMap.containsKey(query_A) || !word2dicMap.containsKey(query_B)) {
            System.err.println("query_A:" + query_A + "\tquery_B:" + query_B);
            continue;
        }

        int Rindex = word2dicMap.get(query_A);
        int Cindex = word2dicMap.get(query_B);

        TripleNode tripleNode = (TripleNode) sparseMatrixInfo.getNode_indexMap().get(Rindex + "_" + Cindex);
        if (null == tripleNode) {
            // The builder only stores pairs of distinct words scoring above 0.6.
            System.err.println("未存储该词对的相关度(相似度不高于0.6或两词相同)");
            continue;
        }
        System.err.println(tripleNode.getValue());

    }

}

/**
 * Interactive brute-force lookup: for each query word read from standard
 * input, scores it against every word in {@code mapAll}, lets the user
 * inspect individual scores, then prints the first (at most) 30 entries of
 * the sorted result. Entering {@code -1} leaves the current prompt loop.
 *
 * @param map_l  word -> vector for the words that may be queried
 * @param mapAll word -> vector for the candidate words to score against
 * @throws IOException declared for compatibility; not thrown by the visible code
 */
public static void computeSingleRelationNodes(
        HashMap<String, Matrix> map_l, HashMap<String, Matrix> mapAll)
        throws IOException {

    // A single Scanner for the whole session avoids losing buffered input.
    Scanner scanner = new Scanner(System.in);

    while (true) {
        System.err.println("请输入query:");
        String query = scanner.nextLine();

        if ("-1".equals(query)) {
            break;
        }

        Matrix matrix = map_l.get(query);
        if (matrix == null) {
            // Unknown word: report it instead of failing with a NullPointerException.
            System.err.println("不包含该词语!");
            continue;
        }

        System.err.println(query + "相关节点:");

        HashMap<String, Double> relatedMap = new HashMap<String, Double>();

        for (String word_r : mapAll.keySet()) {
            if (query.equals(word_r)) {
                continue;
            }
            double value = getInnerProductValue(matrix.getVectors(),
                    mapAll.get(word_r).getVectors());
            relatedMap.put(word_r, value);
        }

        ArrayList<String> keyList = new ArrayList<String>(relatedMap.keySet());

        Collections.sort(keyList, new Compare(relatedMap));

        while (true) {
            System.err.println("请输入需要检查的节点:");
            String query2 = scanner.nextLine();
            if ("-1".equals(query2)) {
                break;
            }
            System.err.println(relatedMap.get(query2));
        }

        // Bounded by the list size so small vocabularies do not overrun the list.
        int limit = Math.min(30, keyList.size());
        for (int i = 0; i < limit; i++) {
            String key = keyList.get(i);
            System.err.println(key + "\t" + relatedMap.get(key));
        }

    }

}

/**
 * Builds the sparse similarity matrix with a worker pool.
 *
 * The keys of {@code map_l} are split into contiguous slices, one per worker.
 * Each worker scores its words against every word of {@code mapAll} and, for
 * scores above 0.6, appends a node to the row and column chains and records
 * it in the matrix's "row_col" index map (as the single-threaded builder does).
 * All structural updates happen while holding the matrix's monitor.
 *
 * NOTE(review): keys are used directly as array indices, so they are assumed
 * to be dense 0-based indices as built by main - confirm for other callers.
 *
 * @param map_l  index -> vector for the row words
 * @param mapAll index -> vector for the column words
 * @return the populated sparse matrix
 * @throws IOException an {@code InterruptedIOException} if the calling thread
 *                     is interrupted while waiting for the workers
 */
public static SparseMatrixInfo computeWithWord_MultiThread(
        final HashMap<Integer, Matrix> map_l, final HashMap<Integer, Matrix> mapAll)
        throws IOException {

    final SparseMatrixInfo sparseMatrixInfo = initMatrix(map_l.size(), mapAll.size());

    ArrayList<Integer> keyList = new ArrayList<Integer>(map_l.keySet());

    int threadNum = 1;

    ExecutorService service = Executors.newFixedThreadPool(threadNum);

    int p_unit = keyList.size() / threadNum;

    for (int i = 0; i < threadNum; i++) {
        int s_index = i * p_unit;
        // The last worker also takes the remainder of the integer division.
        int e_index = (i == threadNum - 1) ? keyList.size() : s_index + p_unit;

        final List<Integer> wordList = keyList.subList(s_index, e_index);

        service.execute(new Runnable() {

            @Override
            public void run() {
                for (Integer word_l : wordList) {

                    LOG.info(Thread.currentThread().getName() + "\t" + word_l);

                    Matrix matrix = map_l.get(word_l);

                    for (Integer word_r : mapAll.keySet()) {

                        if (word_l.equals(word_r)) {
                            continue;
                        }

                        double value = getInnerProductValue(matrix.getVectors(),
                                mapAll.get(word_r).getVectors());

                        if (value > 0.6) {
                            // Explicit shared lock: the matrix being mutated.
                            synchronized (sparseMatrixInfo) {
                                sparseMatrixInfo.increaseNumsByOne();

                                TripleNode tripleNode = new TripleNode(word_l, word_r, value);

                                OLNode newNode = new OLNode(tripleNode);

                                // Keep the pair index in step with the single-threaded builder.
                                sparseMatrixInfo.getNode_indexMap().put(word_l + "_" + word_r, tripleNode);

                                // Append at the tail of the row chain.
                                OLNode Rnode = sparseMatrixInfo.getRhead()[word_l];
                                while (null != Rnode.getRight()) {
                                    Rnode = Rnode.getRight();
                                }
                                Rnode.setRight(newNode);

                                // Append at the tail of the column chain.
                                OLNode Cnode = sparseMatrixInfo.getChead()[word_r];
                                while (null != Cnode.getDown()) {
                                    Cnode = Cnode.getDown();
                                }
                                Cnode.setDown(newNode);
                            }
                        }
                    }
                }
            }
        });

    }

    service.shutdown();

    try {
        // Returns as soon as the workers finish; only reports every three
        // minutes while they are still running.
        while (!service.awaitTermination(3, java.util.concurrent.TimeUnit.MINUTES)) {
            System.err.println("有线程未执行完毕");
        }
    } catch (InterruptedException e) {
        // Stop the workers, keep the interrupt flag set, and surface the
        // condition through the already-declared IOException.
        service.shutdownNow();
        Thread.currentThread().interrupt();
        java.io.InterruptedIOException aborted =
                new java.io.InterruptedIOException("interrupted while building the sparse matrix");
        aborted.initCause(e);
        throw aborted;
    }
    System.err.println("线程执行完毕!");

    return sparseMatrixInfo;

}
/**
 * Creates an empty sparse matrix of the given dimensions.
 *
 * The non-zero count starts at 0 and every slot of the column-head and
 * row-head arrays is filled with a fresh, data-less {@code OLNode}.
 *
 * @param row number of rows (size of the row-head array)
 * @param col number of columns (size of the column-head array)
 * @return the initialised matrix descriptor
 * @author wuyg1
 * @date 2016-11-24
 */
public static SparseMatrixInfo initMatrix(int row,int col){
    // Build both head arrays up front, column heads first.
    OLNode[] columnHeads = new OLNode[col];
    OLNode[] rowHeads = new OLNode[row];

    int c = 0;
    while (c < columnHeads.length) {
        columnHeads[c++] = new OLNode();
    }

    int r = 0;
    while (r < rowHeads.length) {
        rowHeads[r++] = new OLNode();
    }

    SparseMatrixInfo info = new SparseMatrixInfo();
    info.setRows(row);
    info.setCols(col);
    info.setNums(0);
    info.setChead(columnHeads);
    info.setRhead(rowHeads);
    return info;
}

/**
 * Builds the sparse similarity matrix on the calling thread.
 *
 * Every word of {@code map_l} is scored against every other word of
 * {@code mapAll}. Pairs scoring above 0.6 become a node that is inserted at
 * the front of both its row chain and its column chain, and is also recorded
 * in the matrix's index map under the key {@code row + "_" + col}.
 *
 * NOTE(review): keys are used directly as array indices, so they are assumed
 * to be dense 0-based indices as built by main - confirm for other callers.
 *
 * @param map_l  index -> vector for the row words
 * @param mapAll index -> vector for the column words
 * @return the populated sparse matrix
 * @throws IOException declared for compatibility; not thrown by the visible code
 */
public static SparseMatrixInfo computeWithWordInSparseMatrix(
        HashMap<Integer, Matrix> map_l, HashMap<Integer, Matrix> mapAll)
        throws IOException {

    int count = 0;

    SparseMatrixInfo sparseMatrixInfo = initMatrix(map_l.size(), mapAll.size());

    for (Integer word_l : map_l.keySet()) {

        Matrix matrix = map_l.get(word_l);

        LOG.info(word_l + "\t" + count);

        for (Integer word_r : mapAll.keySet()) {

            // A word is never related to itself.
            if (word_l.equals(word_r)) {
                continue;
            }

            double value = getInnerProductValue(matrix.getVectors(),
                    mapAll.get(word_r).getVectors());

            if (value > 0.6) {

                sparseMatrixInfo.increaseNumsByOne();

                TripleNode tripleNode = new TripleNode(word_l, word_r, value);

                OLNode newNode = new OLNode(tripleNode);

                sparseMatrixInfo.getNode_indexMap().put(word_l + "_" + word_r, tripleNode);

                // Head insertion into the row chain.
                OLNode Rnode = sparseMatrixInfo.getRhead()[word_l];
                OLNode tempRnode = Rnode.getRight();
                Rnode.setRight(newNode);
                newNode.setRight(tempRnode);

                // Head insertion into the column chain. The new node must point
                // down at the previous first node of THIS column; linking it to
                // the row successor (the old code) corrupted every column list.
                OLNode Cnode = sparseMatrixInfo.getChead()[word_r];
                OLNode tempCnode = Cnode.getDown();
                Cnode.setDown(newNode);
                newNode.setDown(tempCnode);
            }
        }

        count++;

    }

    return sparseMatrixInfo;
}

/**
 * Computes the full (dense) similarity table keyed by word.
 *
 * For every word of {@code map_l} the result holds a map from every other
 * word of {@code mapAll} to its score. Nothing is filtered, so the result
 * grows with the product of the two vocabularies; progress is printed to
 * standard error, one line per row word.
 *
 * @param map_l  word -> vector for the row words
 * @param mapAll word -> vector for the column words
 * @return word -> (other word -> score)
 * @throws IOException declared for compatibility; not thrown by the visible code
 */
public static HashMap<String, HashMap<String, Double>> computeWithWord(
        HashMap<String, Matrix> map_l, HashMap<String, Matrix> mapAll)
        throws IOException {

    HashMap<String, HashMap<String, Double>> map =
            new HashMap<String, HashMap<String, Double>>();

    int count = 0;

    for (String word_l : map_l.keySet()) {

        Matrix matrix = map_l.get(word_l);

        System.err.println(word_l + "\t" + count);

        HashMap<String, Double> relatedMap = new HashMap<String, Double>();

        for (String word_r : mapAll.keySet()) {

            // Skip the word itself.
            if (word_l.equals(word_r)) {
                continue;
            }

            double value = getInnerProductValue(matrix.getVectors(),
                    mapAll.get(word_r).getVectors());

            relatedMap.put(word_r, value);
        }

        map.put(word_l, relatedMap);

        count++;

    }

    return map;
}

/**
 * Computes, for every word index of {@code map_l}, the first (at most) 30
 * related word indices after sorting with {@code Compare}, together with
 * their scores. Progress is printed to standard error, one line per row word.
 *
 * The inner maps are {@code LinkedHashMap}s so that iterating them yields
 * the entries in sorted order.
 *
 * @param map_l  index -> vector for the row words
 * @param mapAll index -> vector for the column words
 * @return index -> (related index -> score), limited to 30 entries per row
 * @throws IOException declared for compatibility; not thrown by the visible code
 */
public static HashMap<Integer, HashMap<Integer, Double>> computeWithDicID(
        HashMap<Integer, Matrix> map_l, HashMap<Integer, Matrix> mapAll)
        throws IOException {

    HashMap<Integer, HashMap<Integer, Double>> map =
            new HashMap<Integer, HashMap<Integer, Double>>();

    int count = 0;

    for (Integer word_l : map_l.keySet()) {

        Matrix matrix = map_l.get(word_l);

        System.err.println(word_l + "\t" + count);

        HashMap<Integer, Double> relatedMap = new HashMap<Integer, Double>();

        for (Integer word_r : mapAll.keySet()) {

            if (word_l.equals(word_r)) {
                continue;
            }

            double value = getInnerProductValue(matrix.getVectors(),
                    mapAll.get(word_r).getVectors());

            relatedMap.put(word_r, value);
        }

        ArrayList<Integer> keyList = new ArrayList<Integer>(relatedMap.keySet());

        Collections.sort(keyList, new Compare(relatedMap));

        HashMap<Integer, Double> relatedMap_new = new java.util.LinkedHashMap<Integer, Double>();

        // Take the first entries of the SORTED key list. The old code looked
        // up relatedMap.get(i), i.e. treated the rank i as a word index and
        // ignored the sort result entirely.
        int limit = Math.min(30, keyList.size());
        for (int i = 0; i < limit; i++) {
            Integer key = keyList.get(i);
            relatedMap_new.put(key, relatedMap.get(key));
        }

        map.put(word_l, relatedMap_new);

        count++;

    }

    return map;
}

/**
 * Parses the numeric part of a tokenised vector row.
 *
 * Element 0 is the word itself and is skipped; every following token is
 * parsed as a double. Blank tokens (produced by consecutive spaces or a
 * stray carriage return) are ignored instead of raising
 * {@code NumberFormatException}.
 *
 * @param list tokens of one row: the word followed by its components
 * @return the parsed components in order; empty if there are none
 * @throws NumberFormatException if a non-blank token is not a valid double
 */
public static ArrayList<Double> str2Double(ArrayList<String> list) {
    ArrayList<Double> reList = new ArrayList<Double>(Math.max(list.size() - 1, 0));

    for (int index = 1; index < list.size(); index++) {
        String token = list.get(index).trim();
        if (token.isEmpty()) {
            continue;
        }
        reList.add(Double.valueOf(token));
    }
    return reList;
}

/**
 * Computes the cosine similarity of two vectors: their dot product divided
 * by the product of their Euclidean norms. (Despite the method name, the
 * result is normalised, not a raw inner product.)
 *
 * If the vectors differ in length only the common prefix is used. When
 * either prefix has zero norm (including empty input) the similarity is
 * undefined; 0 is returned instead of NaN so callers can compare and sort
 * the result safely.
 *
 * @param innermap_l  left vector
 * @param innerList_R right vector
 * @return cosine similarity of the common prefix, or 0 if it is undefined
 * @author wuyg1
 * @date 2016-09-22
 */
public static double getInnerProductValue(ArrayList<Double> innermap_l,
        ArrayList<Double> innerList_R) {

    int length = Math.min(innermap_l.size(), innerList_R.size());

    double dot = 0;
    double Lsum = 0;
    double Rsum = 0;

    for (int index = 0; index < length; index++) {
        // Unbox once per element instead of once per use.
        double left = innermap_l.get(index);
        double right = innerList_R.get(index);
        dot += left * right;
        Lsum += left * left;
        Rsum += right * right;
    }

    double norm = Math.sqrt(Lsum) * Math.sqrt(Rsum);
    if (norm == 0) {
        return 0;
    }

    return dot / norm;

}
}
稀疏矩阵结构(SparseMatrixInfo):

import java.util.HashMap;

/**
 * Descriptor of a sparse matrix stored as an orthogonal (cross-linked) list:
 * its dimensions, the number of stored elements, one head node per row and
 * per column, and a direct index from position to stored element.
 */
public class SparseMatrixInfo {
    /**
     * Position index of every stored (non-zero) element. The builder code
     * keys it as {@code rowIndex + "_" + colIndex}.
     */
    private HashMap<String, TripleNode> node_indexMap = new HashMap<String, TripleNode>();

    /** Number of columns of the original matrix. */
    private int cols;
    /** Number of rows of the original matrix. */
    private int rows;
    /** Number of non-zero elements of the original matrix. */
    private int nums;
    /** Row heads: element i starts the chain of row i (followed via "right"). */
    private OLNode[] rhead;
    /** Column heads: element j starts the chain of column j (followed via "down"). */
    private OLNode[] chead;

    /** Creates an empty descriptor; dimensions and heads are set afterwards. */
    public SparseMatrixInfo() {
    }

    /** @return the number of columns */
    public int getCols() {
        return cols;
    }

    /** @param cols the number of columns */
    public void setCols(int cols) {
        this.cols = cols;
    }

    /** @return the number of rows */
    public int getRows() {
        return rows;
    }

    /** @param rows the number of rows */
    public void setRows(int rows) {
        this.rows = rows;
    }

    /**
     * Synchronized like the increment/decrement methods so a reader always
     * sees the latest count written under the same lock.
     *
     * @return the number of stored elements
     */
    public synchronized int getNums() {
        return nums;
    }

    /** @param nums the number of stored elements */
    public synchronized void setNums(int nums) {
        this.nums = nums;
    }

    /** @return the row-head array (the live array, not a copy) */
    public synchronized OLNode[] getRhead() {
        return rhead;
    }

    /** @param rhead the row-head array */
    public synchronized void setRhead(OLNode[] rhead) {
        this.rhead = rhead;
    }

    /** @return the column-head array (the live array, not a copy) */
    public synchronized OLNode[] getChead() {
        return chead;
    }

    /** @param chead the column-head array */
    public synchronized void setChead(OLNode[] chead) {
        this.chead = chead;
    }

    /** Increments the column count by one. */
    public void increaseColByOne(){
        this.cols = this.cols + 1;
    }

    /** Decrements the column count by one. */
    public void decreaseColByOne(){
        this.cols = this.cols - 1;
    }

    /** Increments the row count by one. */
    public void increaseRowsByOne(){
        this.rows = this.rows + 1;
    }

    /** Decrements the row count by one. */
    public void decreaseRowsByOne(){
        this.rows = this.rows - 1;
    }

    /** Increments the stored-element count by one. */
    public synchronized void increaseNumsByOne(){
        this.nums = this.nums + 1;
    }

    /** Decrements the stored-element count by one. */
    public synchronized void decreaseNumsByOne(){
        this.nums = this.nums - 1;
    }

    /** @return the position index (the live map, not a copy) */
    public HashMap<String, TripleNode> getNode_indexMap() {
        return node_indexMap;
    }

    /** @param node_indexMap the position index to use from now on */
    public void setNode_indexMap(HashMap<String, TripleNode> node_indexMap) {
        this.node_indexMap = node_indexMap;
    }
}


三元组结构(TripleNode),记录非零元素在矩阵中的行下标、列下标及其值:

/**
 * A (row, column, value) triple describing one non-zero matrix element.
 *
 * @param <T> type of the stored value
 */
public class TripleNode<T> {
    int rowIndex; // row index of the non-zero element
    int colIndex; // column index of the non-zero element
    T value;      // value of the non-zero element

    /**
     * @param i     row index
     * @param j     column index
     * @param value element value
     */
    public TripleNode(int i, int j, T value){
        this.rowIndex = i;
        this.colIndex = j;
        this.value = value;
    }

    /** @return the row index */
    public int getRowIndex() {
        return rowIndex;
    }

    /** @param rowIndex the row index */
    public void setRowIndex(int rowIndex) {
        this.rowIndex = rowIndex;
    }

    /** @return the column index */
    public int getColIndex() {
        return colIndex;
    }

    /** @param colIndex the column index */
    public void setColIndex(int colIndex) {
        this.colIndex = colIndex;
    }

    /** @return the element value */
    public T getValue() {
        return value;
    }

    /** @param value the element value */
    public void setValue(T value) {
        this.value = value;
    }

    @Override
    public String toString() {
        return "TripleNode [rowIndex=" + rowIndex + ", colIndex=" + colIndex
                + ", value=" + value + "]";
    }

}


十字链表节点结构(OLNode),每个节点持有一个三元组以及指向同行、同列下一个节点的指针:

/**
 * Node of the orthogonal (cross-linked) list: carries one triple and links
 * to the next node in the same row and in the same column.
 */
public class OLNode {

    /** Payload: row index, column index and value of the element. */
    private TripleNode data;
    /** Next node along the row chain. */
    private OLNode right;
    /** Next node along the column chain. */
    private OLNode down;

    /** Creates a node with no payload and no links. */
    public OLNode() {
        this(null, null, null);
    }

    /**
     * Creates an unlinked node carrying the given payload.
     *
     * @param data the triple to store
     */
    public OLNode(TripleNode data) {
        this(data, null, null);
    }

    /**
     * Creates a fully specified node.
     *
     * @param data  the triple to store
     * @param right next node in the same row
     * @param down  next node in the same column
     */
    public OLNode(TripleNode data, OLNode right, OLNode down) {
        super();
        this.data = data;
        this.right = right;
        this.down = down;
    }

    /** @return the stored triple */
    public TripleNode getData() {
        return this.data;
    }

    /** @param data the triple to store */
    public void setData(TripleNode data) {
        this.data = data;
    }

    /** @return the next node in the same row */
    public synchronized OLNode getRight() {
        return this.right;
    }

    /** @param right the next node in the same row */
    public synchronized void setRight(OLNode right) {
        this.right = right;
    }

    /** @return the next node in the same column */
    public synchronized OLNode getDown() {
        return this.down;
    }

    /** @param down the next node in the same column */
    public synchronized void setDown(OLNode down) {
        this.down = down;
    }

}


搜索结果:

请输入检索节点:

毛泽东

想要得到的TopN:

10

叶剑英#DIV#0.8786766021887125

华国锋#DIV#0.8691403116586236

邓小平#DIV#0.8560495212960941

陈毅#DIV#0.8417411983853618

萧劲光#DIV#0.8167936577713354

徐海东#DIV#0.7978155750998943

张春桥#DIV#0.7965857537574821

陈赓#DIV#0.7885015060295482

李克农#DIV#0.7820805260018616

毛远新#DIV#0.773292512123925
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: