Procházet zdrojové kódy

fix: Threadpool has now thread Ids

Matthias Ladkau před 4 roky
rodič
revize
6f33f4f03e
2 změnil soubory, kde provedl 40 přidání a 23 odebrání
  1. 36 19
      pools/threadpool.go
  2. 4 4
      pools/threadpool_test.go

+ 36 - 19
pools/threadpool.go

@@ -7,9 +7,6 @@
  * For further information see: http://creativecommons.org/publicdomain/zero/1.0/
  */
 
-/*
-Package pools contains object pooling utilities.
-*/
 package pools
 
 import (
@@ -33,9 +30,10 @@ Task is a task which should be run in a thread.
 type Task interface {
 
 	/*
-		Run the task.
+		Run the task. The function gets the unique thread ID of the worker
+		which executes the task.
 	*/
-	Run() error
+	Run(tid uint64) error
 
 	/*
 		HandleError handles an error which occurred during the run method.
@@ -124,7 +122,9 @@ type ThreadPool struct {
 
 	// Worker regulation
 
-	workerIDCount uint64                       // Id counter for worker tasks
+	workerIDCount uint64      // Id counter for worker tasks
+	workerIDLock  *sync.Mutex // Lock for ID generation
+
 	workerMap     map[uint64]*ThreadPoolWorker // Map of all workers
 	workerIdleMap map[uint64]*ThreadPoolWorker // Map of all idle workers
 	workerMapLock *sync.Mutex                  // Lock for worker map
@@ -156,7 +156,7 @@ NewThreadPoolWithQueue creates a new thread pool with a specific task queue.
 */
 func NewThreadPoolWithQueue(q TaskQueue) *ThreadPool {
 	return &ThreadPool{q, &sync.Mutex{},
-		0, make(map[uint64]*ThreadPoolWorker),
+		1, &sync.Mutex{}, make(map[uint64]*ThreadPoolWorker),
 		make(map[uint64]*ThreadPoolWorker), &sync.Mutex{},
 		0, sync.NewCond(&sync.Mutex{}), &sync.Mutex{},
 		math.MaxInt32, func() {}, false, 0, func() {}, false}
@@ -254,6 +254,21 @@ func (tp *ThreadPool) getTask() Task {
 	return nil
 }
 
+/*
+NewThreadID creates a new thread ID unique to this pool.
+*/
+func (tp *ThreadPool) NewThreadID() uint64 {
+
+	tp.workerIDLock.Lock()
+
+	res := tp.workerIDCount
+	tp.workerIDCount++
+
+	tp.workerIDLock.Unlock()
+
+	return res
+}
+
 /*
 SetWorkerCount sets the worker count of this pool. If the wait flag is true then
 this call will return after the pool has reached the requested worker count.
@@ -279,10 +294,10 @@ func (tp *ThreadPool) SetWorkerCount(count int, wait bool) {
 		tp.workerKill = 0
 
 		for len(tp.workerMap) != count {
-			worker := &ThreadPoolWorker{tp.workerIDCount, tp}
+			tid := tp.NewThreadID()
+			worker := &ThreadPoolWorker{tid, tp}
 			go worker.run()
-			tp.workerMap[tp.workerIDCount] = worker
-			tp.workerIDCount++
+			tp.workerMap[tid] = worker
 		}
 
 		tp.workerMapLock.Unlock()
@@ -455,6 +470,14 @@ run lets this worker run tasks.
 */
 func (w *ThreadPoolWorker) run() {
 
+	defer func() {
+		// Remove worker from workerMap
+
+		w.pool.workerMapLock.Lock()
+		delete(w.pool.workerMap, w.id)
+		w.pool.workerMapLock.Unlock()
+	}()
+
 	for true {
 
 		// Try to get the next task
@@ -480,7 +503,7 @@ func (w *ThreadPoolWorker) run() {
 
 		// Run the task
 
-		if err := task.Run(); err != nil {
+		if err := task.Run(w.id); err != nil {
 			task.HandleError(err)
 		}
 
@@ -490,12 +513,6 @@ func (w *ThreadPoolWorker) run() {
 			w.pool.workerMapLock.Unlock()
 		}
 	}
-
-	// Remove worker from workerMap
-
-	w.pool.workerMapLock.Lock()
-	delete(w.pool.workerMap, w.id)
-	w.pool.workerMapLock.Unlock()
 }
 
 /*
@@ -508,10 +525,10 @@ type idleTask struct {
 /*
 Run the idle task.
 */
-func (t *idleTask) Run() error {
+func (t *idleTask) Run(tid uint64) error {
 	t.tp.newTaskCond.L.Lock()
+	defer t.tp.newTaskCond.L.Unlock()
 	t.tp.newTaskCond.Wait()
-	t.tp.newTaskCond.L.Unlock()
 	return nil
 }
 

+ 4 - 4
pools/threadpool_test.go

@@ -23,7 +23,7 @@ type testTask struct {
 	errorHandler func(e error)
 }
 
-func (t *testTask) Run() error {
+func (t *testTask) Run(tid uint64) error {
 	return t.task()
 }
 
@@ -77,21 +77,21 @@ func TestDefaultTaskQueue(t *testing.T) {
 
 	// Execute the functions
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 	if res := tq.Size(); res != 2 {
 		t.Error("Unexpected result: ", res)
 		return
 	}
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 	if res := tq.Size(); res != 1 {
 		t.Error("Unexpected result: ", res)
 		return
 	}
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 	if res := tq.Size(); res != 0 {
 		t.Error("Unexpected result: ", res)