59 lines
1.4 KiB
Scala
59 lines
1.4 KiB
Scala
import java.util.concurrent._
|
|
import scala.util.DynamicVariable
|
|
|
|
import org.scalameter._
|
|
|
|
package object reductions {
|
|
/** Fork/join pool built with the default `ForkJoinPool` constructor.
  *
  * `DefaultTaskScheduler.schedule` executes tasks on this pool when it is called
  * from a thread that is not a `ForkJoinWorkerThread`.
  */
val forkJoinPool: ForkJoinPool = new ForkJoinPool()
|
|
|
|
/** Strategy for starting computations as fork/join tasks. */
abstract class TaskScheduler {

  /** Starts `body` as a task.
    *
    * @param body computation to run, passed by name
    * @tparam T   result type of `body`
    * @return the `ForkJoinTask` handle for the started computation
    */
  def schedule[T](body: => T): ForkJoinTask[T]

  /** Evaluates two computations and returns both results as a pair.
    *
    * `taskB` is handed to this scheduler's `schedule`, `taskA` is then evaluated
    * on the calling thread, and finally the scheduled task is joined.
    *
    * @param taskA computation evaluated on the calling thread
    * @param taskB computation passed to `schedule`
    * @return `(result of taskA, result of taskB)`
    */
  def parallel[A, B](taskA: => A, taskB: => B): (A, B) = {
    // Go through this instance's own `schedule` rather than the globally
    // installed scheduler, so a TaskScheduler that is used directly (without
    // being installed as the current scheduler) still runs taskB itself.
    val right = schedule(taskB)
    val left = taskA
    (left, right.join())
  }
}
|
|
|
|
/** `TaskScheduler` that wraps each body in a `RecursiveTask`. */
class DefaultTaskScheduler extends TaskScheduler {

  /** Wraps `body` in a `RecursiveTask`, starts it, and returns the task.
    *
    * @param body computation to run, passed by name
    * @tparam T   result type of `body`
    * @return the started `RecursiveTask`
    */
  def schedule[T](body: => T): ForkJoinTask[T] = {
    val job = new RecursiveTask[T] {
      def compute: T = body
    }
    Thread.currentThread match {
      // Already on a fork/join worker thread: fork from here.
      case _: ForkJoinWorkerThread => job.fork()
      // Any other thread: hand the task to the shared forkJoinPool.
      case _ => forkJoinPool.execute(job)
    }
    job
  }
}
|
|
|
|
/** Dynamically scoped holder of the current `TaskScheduler`.
  *
  * Initialised with a `DefaultTaskScheduler`; `task` and the two-argument
  * `parallel` read `scheduler.value` to pick the scheduler they delegate to.
  */
val scheduler: DynamicVariable[TaskScheduler] =
  new DynamicVariable[TaskScheduler](new DefaultTaskScheduler)
|
|
|
|
/** Starts `body` through the current scheduler (`scheduler.value`).
  *
  * @param body computation to run, passed by name
  * @tparam T   result type of `body`
  * @return the `ForkJoinTask` returned by the scheduler's `schedule`
  */
def task[T](body: => T): ForkJoinTask[T] =
  scheduler.value.schedule(body)
|
|
|
|
/** Evaluates two computations via the current scheduler's `parallel`.
  *
  * @param taskA first computation, passed by name
  * @param taskB second computation, passed by name
  * @return the pair of results produced by `scheduler.value.parallel`
  */
def parallel[A, B](taskA: => A, taskB: => B): (A, B) =
  scheduler.value.parallel(taskA, taskB)
|
|
|
|
/** Evaluates four computations and returns their results as a 4-tuple.
  *
  * The first three are started with `task`; the fourth is evaluated on the
  * calling thread, after which the three started tasks are joined in order.
  *
  * @return `(result of taskA, result of taskB, result of taskC, result of taskD)`
  */
def parallel[A, B, C, D](taskA: => A, taskB: => B, taskC: => C, taskD: => D): (A, B, C, D) = {
  val first = task(taskA)
  val second = task(taskB)
  val third = task(taskC)
  // Evaluated inline on the calling thread while the other three are in flight.
  val fourth = taskD
  (first.join(), second.join(), third.join(), fourth)
}
|
|
|
|
/** Implicitly converts a `(Key[T], T)` pair to a `KeyValue` with an unchecked cast.
  *
  * Workaround for Dotty's handling of the existential type `KeyValue`.
  *
  * @param kv the key/value pair to coerce
  * @return the same object, typed as `KeyValue`
  */
implicit def keyValueCoerce[T](kv: (Key[T], T)): KeyValue =
  kv.asInstanceOf[KeyValue]
|
|
}
|