8
8
9
9
package scala .concurrent
10
10
11
+ import java .util .ArrayDeque
11
12
import java .util .concurrent .Executor
12
13
import scala .annotation .tailrec
14
+ import scala .util .control .NonFatal
15
+
16
+ /**
17
+ * Marker trait to indicate that a Runnable is Batchable by BatchingExecutors
18
+ */
19
+ trait Batchable {
20
+ self : Runnable =>
21
+ }
13
22
14
23
/**
15
24
* Mixin trait for an Executor
@@ -39,79 +48,90 @@ import scala.annotation.tailrec
39
48
* WARNING: The underlying Executor's execute-method must not execute the submitted Runnable
40
49
* in the calling thread synchronously. It must enqueue/handoff the Runnable.
41
50
*/
42
- private [concurrent] trait BatchingExecutor extends Executor {
51
+ private [concurrent] trait BatchingExecutor extends Executor {
52
+ private [this ] final val _tasksLocal = new ThreadLocal [Batch ]()
53
+
54
+ private [this ] final class Batch (capacity : Int ) extends ArrayDeque [Runnable ](capacity) with Runnable with BlockContext with (BlockContext => Unit ) {
55
+ private [this ] final var parentBlockContext : BlockContext = _
43
56
44
- // invariant: if "_tasksLocal.get ne null" then we are inside BatchingRunnable.run; if it is null, we are outside
45
- private [this ] val _tasksLocal = new ThreadLocal [List [Runnable ]]()
57
+ def this (r : Runnable , capacity : Int ) = {
58
+ this (capacity)
59
+ addLast(r)
60
+ }
61
+
62
+ final def executor : BatchingExecutor = BatchingExecutor .this
46
63
47
- private class Batch (val initial : List [Runnable ]) extends Runnable with BlockContext {
48
- private [this ] var parentBlockContext : BlockContext = _
49
64
// this method runs in the delegate ExecutionContext's thread
50
- override def run (): Unit = {
51
- require(_tasksLocal.get eq null )
52
-
53
- val prevBlockContext = BlockContext .current
54
- BlockContext .withBlockContext(this ) {
55
- try {
56
- parentBlockContext = prevBlockContext
57
-
58
- @ tailrec def processBatch (batch : List [Runnable ]): Unit = batch match {
59
- case Nil => ()
60
- case head :: tail =>
61
- _tasksLocal set tail
62
- try {
63
- head.run()
64
- } catch {
65
- case t : Throwable =>
66
- // if one task throws, move the
67
- // remaining tasks to another thread
68
- // so we can throw the exception
69
- // up to the invoking executor
70
- val remaining = _tasksLocal.get
71
- _tasksLocal set Nil
72
- unbatchedExecute(new Batch (remaining)) // TODO what if this submission fails?
73
- throw t // rethrow
74
- }
75
- processBatch(_tasksLocal.get) // since head.run() can add entries, always do _tasksLocal.get here
76
- }
77
-
78
- processBatch(initial)
79
- } finally {
80
- _tasksLocal.remove()
81
- parentBlockContext = null
82
- }
65
+ override final def run (): Unit = BlockContext .usingBlockContext(this )(this )
66
+
67
+ override final def apply (prevBlockContext : BlockContext ): Unit = {
68
+ // This invariant needs to hold: require(_tasksLocal.get eq null)
69
+ parentBlockContext = prevBlockContext
70
+ try {
71
+ _tasksLocal.set(this )
72
+ runAll()
73
+ _tasksLocal.remove() // Will be cleared in the throwing-case by runAll()
74
+ } finally {
75
+ parentBlockContext = null
83
76
}
84
77
}
85
78
79
+ @ tailrec private [this ] final def runAll (): Unit = {
80
+ val next = pollLast()
81
+ if (next ne null ) {
82
+ try next.run() catch {
83
+ case t : Throwable =>
84
+ parentBlockContext = null // Need to reset this before re-submitting it
85
+ _tasksLocal.remove() // If unbatchedExecute runs synchronously
86
+ handleRunFailure(t)
87
+ }
88
+ runAll()
89
+ }
90
+ }
91
+
92
+ private [this ] final def handleRunFailure (cause : Throwable ): Nothing =
93
+ if (NonFatal (cause) || cause.isInstanceOf [InterruptedException ]) {
94
+ try unbatchedExecute(this ) catch {
95
+ case inner : Throwable =>
96
+ if (NonFatal (inner)) {
97
+ val e = new ExecutionException (" Non-fatal error occurred and resubmission failed, see suppressed exception." , cause)
98
+ e.addSuppressed(inner)
99
+ throw e
100
+ } else throw inner
101
+ }
102
+ throw cause
103
+ } else throw cause
104
+
86
105
override def blockOn [T ](thunk : => T )(implicit permission : CanAwait ): T = {
87
- // if we know there will be blocking, we don't want to keep tasks queued up because it could deadlock.
88
- {
89
- val tasks = _tasksLocal.get
90
- _tasksLocal set Nil
91
- if ((tasks ne null ) && tasks.nonEmpty )
92
- unbatchedExecute(new Batch (tasks) )
106
+ val pbc = parentBlockContext
107
+ if ( ! isEmpty) { // if we know there will be blocking, we don't want to keep tasks queued up because it could deadlock.
108
+ val b = new Batch (math.max( 4 , this .size))
109
+ b.addAll( this )
110
+ this .clear( )
111
+ unbatchedExecute(b )
93
112
}
94
113
95
- // now delegate the blocking to the previous BC
96
- require(parentBlockContext ne null )
97
- parentBlockContext.blockOn(thunk)
114
+ if (pbc ne null ) pbc.blockOn(thunk) // now delegate the blocking to the previous BC
115
+ else {
116
+ try thunk finally throw new IllegalStateException (" BUG in BatchingExecutor.Batch: parentBlockContext is null" )
117
+ }
98
118
}
99
119
}
100
120
101
121
protected def unbatchedExecute (r : Runnable ): Unit
102
122
103
- override def execute (runnable : Runnable ): Unit = {
104
- if (batchable(runnable)) { // If we can batch the runnable
105
- _tasksLocal.get match {
106
- case null => unbatchedExecute(new Batch (runnable :: Nil )) // If we aren't in batching mode yet, enqueue batch
107
- case some => _tasksLocal.set(runnable :: some) // If we are already in batching mode, add to batch
108
- }
109
- } else unbatchedExecute(runnable) // If not batchable, just delegate to underlying
123
+ private [this ] final def batchedExecute (runnable : Runnable ): Unit = {
124
+ val b = _tasksLocal.get
125
+ if (b ne null ) b.addLast(runnable)
126
+ else unbatchedExecute(new Batch (runnable, 4 ))
110
127
}
111
128
112
- /** Override this to define which runnables will be batched. */
113
- def batchable (runnable : Runnable ): Boolean = runnable match {
114
- case _ : OnCompleteRunnable => true
115
- case _ => false
116
- }
117
- }
129
+ override def execute (runnable : Runnable ): Unit =
130
+ if (batchable(runnable)) batchedExecute(runnable)
131
+ else unbatchedExecute(runnable)
132
+
133
+ /** Override this to define which runnables will be batched.
134
+ * By default it tests the Runnable for being an instance of [Batchable].
135
+ **/
136
+ protected def batchable (runnable : Runnable ): Boolean = runnable.isInstanceOf [Batchable ]
137
+ }
0 commit comments