您的位置:首页 > 编程语言 > Go语言

使用 golang 实现类似 pthread_barrier_t 语义的 barrier 对象 3ff0

2015-01-18 23:20 645 查看
看到golang标准库sync package WaitGroup 类型, 本以为是golang 版本的 barrier 对象实现,看到文档给出的使用示例:

 var wg sync.WaitGroup
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// Increment the WaitGroup counter.
wg.Add(1)
// Launch a goroutine to fetch the URL.
go func(url string) {
// Decrement the counter when the goroutine completes.
defer wg.Done()
// Fetch the URL.
http.Get(url)
}(url)
}
// Wait for all HTTP fetches to complete.
wg.Wait()

可以看出WaitGroup 类型主要用于某个goroutine(调用Wait() 方法的那个), 等待个数不定goroutine(内部调用Done() 方法),

Add 方法对内部计数,添加或减少,Done方法其实是Add(-1);

与pthread_barrier_t 有着语义上的差别,pthread_barrier_wait() 的调用者之间互相等待,就好比5名队员(线程)参加跨栏比赛,使用 pthread_barrier_init 初始化最后一个参数为5, 五个队员都是好基友, 定了规矩, 不管谁先到栏杆, 都要等队友,直到最后一名队员跨过栏时,然后同一起步点再次出发。下面时使用pthread_barrier_t 的简单示例 5个线程,每个线程拥有一个私有数组,及增量数字:

#define _GNU_SOURCE

#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define NTHR 5
#define NARR 6
#define INLOOPS 1000
#define OUTLOOPS 10
#define err_abort(code,text) do { \
char errbuf[128] = {0};         \
fprintf (stderr, "%s at \"%s\":%d: %s\n", \
(text), __FILE__, __LINE__, strerror_r(code,errbuf,128)); \
abort (); \
} while (0)

typedef struct thrArg {
pthread_t   tid;
int         incr;
int         arr[NARR];
}thrArg;

pthread_barrier_t   barrier;
thrArg  thrs[NTHR];

void *thrFunc (void *arg)
{
thrArg *self = (thrArg*)arg;
int j, i, k, status;

for (i = 0; i < OUTLOOPS; i++) {
status = pthread_barrier_wait (&barrier);
if (status > 0)
err_abort (status, "wait on barrier");
//每个线程迭代 INLOOPS 次,对自己的内部数组arr 成员加上 自己的增量值
for (j = 0; j < INLOOPS; j++)
for (k = 0; k < NARR; k++)
self->arr[k] += self->incr;
//先执行完迭代的线程在此等待,直到最后一个到达
status = pthread_barrier_wait (&barrier);
if (status > 0)
err_abort (status, "wait on barrier");
//最后一个到达的线程,把所有线程的内部增量加1
//此时其他先到的线程阻塞在第一次wait调用处,所以最后一个到达的线程
//可以排他性地访问所有线程的内部状态,if 语句执行完后,跳到第一次wait处,
//其他阻塞在第一次wait处的线程,得到释放,大家一块使用新的增量做计算
if (status == PTHREAD_BARRIER_SERIAL_THREAD ) {
int i;
for (i = 0; i < NTHR; i++)
thrs[i].incr += 1;
}
}
return NULL;
}

int main (int arg, char *argv[])
{
int i, j;
int status;

pthread_barrier_init (&barrier, NULL, NTHR);

for (i = 0; i < NTHR; i++) {
thrs[i].incr = i;
for (j = 0; j < NARR; j++)
thrs[i].arr[j] = j + 1;

status = pthread_create (&thrs[i].tid,
NULL, thrFunc, (void*)&thrs[i]);
if (status != 0)
err_abort (status, "create thread");
}

for (i = 0; i < NTHR; i++) {
status = pthread_join (thrs[i].tid, NULL);
if (status != 0)
err_abort (status, "join thread");

printf ("%02d: (%d) ", i, thrs[i].incr);

for (j = 0; j < NARR; j++)
printf ("%010u ", thrs[i].arr[j]);
printf ("\n");
}
pthread_barrier_destroy (&barrier);
return 0;
}

怎么用golang 来表达上述c 代码,需要实现pthread_barrier_t 等价语义的的 barrier 对象,可以使用golang 已有的mutex, cond

对象实现 barrier:

package main
import (
"fmt"
"sync"
)
type Barrier struct{
lock  sync.Mutex
cond  sync.Cond
threshold  int    //总的等待个数
count      int    //还剩多少没有到达barrier,即没有完成wait调用个数
cycle      bool   //用于重初始化下一个wait 周期,
}
func NewBarrier(n  int) *Barrier{
b := &Barrier{threshold: n, count: n}
b.cond.L = &b.lock
return b
}
//last == true ,说明最有一个到达
func (b *Barrier)Wait()(last bool){
b.lock.Lock()
defer  b.lock.Unlock()
cycle :=  b.cycle
b.count--
//最后一个到达负责,重初始化count 计数,cycle 变量翻转,
if b.count == 0 {
b.cycle  =  !b.cycle
b.count = b.threshold
b.cond.Broadcast()
last = true
}else{
for cycle == b.cycle {
b.cond.Wait()
}
}
return
}
type thrArg struct{
incr  int
arr   [narr]int
}
var (
thrs  [nthr]thrArg
wg   sync.WaitGroup
barrier = NewBarrier(nthr)
)
const (
outloops = 10
inloops  = 1000
nthr  = 5
narr  = 6
)

func thrFunc(arg  *thrArg){
defer wg.Done()
for i := 0; i < outloops; i++{
barrier.Wait()
for j := 0; j < inloops; j++{
for k:= 0; k < narr; k++{
arg.arr[k] += arg.incr
}
}
if barrier.Wait() {
for i := 0; i < nthr; i++{
thrs[i].incr += 1
}
}
}
}

func  main(){
for
3ff0
i:= 0; i < nthr; i++{
thrs[i].incr =  i
for j := 0; j < narr; j++{
thrs[i].arr[j] = j + 1
}
wg.Add(1)
go thrFunc(&thrs[i])
}
wg.Wait()
//所有goroutine完成,main goroutine,检查最后的结果
for i := 0; i < nthr; i++{
fmt.Printf("%02d: (%d) ", i, thrs[i].incr)
for j := 0; j < narr; j++{
fmt.Printf ("%010d ", thrs[i].arr[j]);
}
fmt.Println()
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  golang c linux pthread 线程