Browse Source

fix: Threadpool has now thread Ids

Matthias Ladkau 3 years ago
parent
commit
6f33f4f03e
2 changed files with 40 additions and 23 deletions
  1. 36 19
      pools/threadpool.go
  2. 4 4
      pools/threadpool_test.go

+ 36 - 19
pools/threadpool.go

@@ -7,9 +7,6 @@
  * For further information see: http://creativecommons.org/publicdomain/zero/1.0/
  * For further information see: http://creativecommons.org/publicdomain/zero/1.0/
  */
  */
 
 
-/*
-Package pools contains object pooling utilities.
-*/
 package pools
 package pools
 
 
 import (
 import (
@@ -33,9 +30,10 @@ Task is a task which should be run in a thread.
 type Task interface {
 type Task interface {
 
 
 	/*
 	/*
-		Run the task.
+		Run the task. The function gets the unique thread ID of the worker
+		which executes the task.
 	*/
 	*/
-	Run() error
+	Run(tid uint64) error
 
 
 	/*
 	/*
 		HandleError handles an error which occurred during the run method.
 		HandleError handles an error which occurred during the run method.
@@ -124,7 +122,9 @@ type ThreadPool struct {
 
 
 	// Worker regulation
 	// Worker regulation
 
 
-	workerIDCount uint64                       // Id counter for worker tasks
+	workerIDCount uint64      // Id counter for worker tasks
+	workerIDLock  *sync.Mutex // Lock for ID generation
+
 	workerMap     map[uint64]*ThreadPoolWorker // Map of all workers
 	workerMap     map[uint64]*ThreadPoolWorker // Map of all workers
 	workerIdleMap map[uint64]*ThreadPoolWorker // Map of all idle workers
 	workerIdleMap map[uint64]*ThreadPoolWorker // Map of all idle workers
 	workerMapLock *sync.Mutex                  // Lock for worker map
 	workerMapLock *sync.Mutex                  // Lock for worker map
@@ -156,7 +156,7 @@ NewThreadPoolWithQueue creates a new thread pool with a specific task queue.
 */
 */
 func NewThreadPoolWithQueue(q TaskQueue) *ThreadPool {
 func NewThreadPoolWithQueue(q TaskQueue) *ThreadPool {
 	return &ThreadPool{q, &sync.Mutex{},
 	return &ThreadPool{q, &sync.Mutex{},
-		0, make(map[uint64]*ThreadPoolWorker),
+		1, &sync.Mutex{}, make(map[uint64]*ThreadPoolWorker),
 		make(map[uint64]*ThreadPoolWorker), &sync.Mutex{},
 		make(map[uint64]*ThreadPoolWorker), &sync.Mutex{},
 		0, sync.NewCond(&sync.Mutex{}), &sync.Mutex{},
 		0, sync.NewCond(&sync.Mutex{}), &sync.Mutex{},
 		math.MaxInt32, func() {}, false, 0, func() {}, false}
 		math.MaxInt32, func() {}, false, 0, func() {}, false}
@@ -254,6 +254,21 @@ func (tp *ThreadPool) getTask() Task {
 	return nil
 	return nil
 }
 }
 
 
+/*
+NewThreadID creates a new thread ID unique to this pool.
+*/
+func (tp *ThreadPool) NewThreadID() uint64 {
+
+	tp.workerIDLock.Lock()
+
+	res := tp.workerIDCount
+	tp.workerIDCount++
+
+	tp.workerIDLock.Unlock()
+
+	return res
+}
+
 /*
 /*
 SetWorkerCount sets the worker count of this pool. If the wait flag is true then
 SetWorkerCount sets the worker count of this pool. If the wait flag is true then
 this call will return after the pool has reached the requested worker count.
 this call will return after the pool has reached the requested worker count.
@@ -279,10 +294,10 @@ func (tp *ThreadPool) SetWorkerCount(count int, wait bool) {
 		tp.workerKill = 0
 		tp.workerKill = 0
 
 
 		for len(tp.workerMap) != count {
 		for len(tp.workerMap) != count {
-			worker := &ThreadPoolWorker{tp.workerIDCount, tp}
+			tid := tp.NewThreadID()
+			worker := &ThreadPoolWorker{tid, tp}
 			go worker.run()
 			go worker.run()
-			tp.workerMap[tp.workerIDCount] = worker
-			tp.workerIDCount++
+			tp.workerMap[tid] = worker
 		}
 		}
 
 
 		tp.workerMapLock.Unlock()
 		tp.workerMapLock.Unlock()
@@ -455,6 +470,14 @@ run lets this worker run tasks.
 */
 */
 func (w *ThreadPoolWorker) run() {
 func (w *ThreadPoolWorker) run() {
 
 
+	defer func() {
+		// Remove worker from workerMap
+
+		w.pool.workerMapLock.Lock()
+		delete(w.pool.workerMap, w.id)
+		w.pool.workerMapLock.Unlock()
+	}()
+
 	for true {
 	for true {
 
 
 		// Try to get the next task
 		// Try to get the next task
@@ -480,7 +503,7 @@ func (w *ThreadPoolWorker) run() {
 
 
 		// Run the task
 		// Run the task
 
 
-		if err := task.Run(); err != nil {
+		if err := task.Run(w.id); err != nil {
 			task.HandleError(err)
 			task.HandleError(err)
 		}
 		}
 
 
@@ -490,12 +513,6 @@ func (w *ThreadPoolWorker) run() {
 			w.pool.workerMapLock.Unlock()
 			w.pool.workerMapLock.Unlock()
 		}
 		}
 	}
 	}
-
-	// Remove worker from workerMap
-
-	w.pool.workerMapLock.Lock()
-	delete(w.pool.workerMap, w.id)
-	w.pool.workerMapLock.Unlock()
 }
 }
 
 
 /*
 /*
@@ -508,10 +525,10 @@ type idleTask struct {
 /*
 /*
 Run the idle task.
 Run the idle task.
 */
 */
-func (t *idleTask) Run() error {
+func (t *idleTask) Run(tid uint64) error {
 	t.tp.newTaskCond.L.Lock()
 	t.tp.newTaskCond.L.Lock()
+	defer t.tp.newTaskCond.L.Unlock()
 	t.tp.newTaskCond.Wait()
 	t.tp.newTaskCond.Wait()
-	t.tp.newTaskCond.L.Unlock()
 	return nil
 	return nil
 }
 }
 
 

+ 4 - 4
pools/threadpool_test.go

@@ -23,7 +23,7 @@ type testTask struct {
 	errorHandler func(e error)
 	errorHandler func(e error)
 }
 }
 
 
-func (t *testTask) Run() error {
+func (t *testTask) Run(tid uint64) error {
 	return t.task()
 	return t.task()
 }
 }
 
 
@@ -77,21 +77,21 @@ func TestDefaultTaskQueue(t *testing.T) {
 
 
 	// Execute the functions
 	// Execute the functions
 
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 
 	if res := tq.Size(); res != 2 {
 	if res := tq.Size(); res != 2 {
 		t.Error("Unexpected result: ", res)
 		t.Error("Unexpected result: ", res)
 		return
 		return
 	}
 	}
 
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 
 	if res := tq.Size(); res != 1 {
 	if res := tq.Size(); res != 1 {
 		t.Error("Unexpected result: ", res)
 		t.Error("Unexpected result: ", res)
 		return
 		return
 	}
 	}
 
 
-	tq.Pop().Run()
+	tq.Pop().Run(0)
 
 
 	if res := tq.Size(); res != 0 {
 	if res := tq.Size(); res != 0 {
 		t.Error("Unexpected result: ", res)
 		t.Error("Unexpected result: ", res)