From 05a83fb43b009391ef98ffdf30506c722821723c Mon Sep 17 00:00:00 2001 From: ArchUSB Date: Wed, 18 Mar 2020 17:05:51 +0100 Subject: [PATCH] HW 4 Done --- src/main/scala/barneshut/Simulator.scala | 17 +++- src/main/scala/barneshut/package.scala | 101 ++++++++++++++++++----- 2 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/main/scala/barneshut/Simulator.scala b/src/main/scala/barneshut/Simulator.scala index 7fdafb2..ebee1d7 100644 --- a/src/main/scala/barneshut/Simulator.scala +++ b/src/main/scala/barneshut/Simulator.scala @@ -12,11 +12,20 @@ import scala.collection.parallel.CollectionConverters._ class Simulator(val taskSupport: TaskSupport, val timeStats: TimeStatistics) { def updateBoundaries(boundaries: Boundaries, body: Body): Boundaries = { - ??? + boundaries.minX = Math.min(boundaries.minX, body.x); + boundaries.minY = Math.min(boundaries.minY, body.y); + boundaries.maxX = Math.max(boundaries.maxX, body.x); + boundaries.maxY = Math.max(boundaries.maxY, body.y); + boundaries } def mergeBoundaries(a: Boundaries, b: Boundaries): Boundaries = { - ??? + val bnd = new Boundaries + bnd.minX = Math.min(a.minX, b.minX) + bnd.minY = Math.min(a.minY, b.minY) + bnd.maxX = Math.max(a.maxX, b.maxX) + bnd.maxY = Math.max(a.maxY, b.maxY) + bnd } def computeBoundaries(bodies: coll.Seq[Body]): Boundaries = timeStats.timed("boundaries") { @@ -28,7 +37,7 @@ class Simulator(val taskSupport: TaskSupport, val timeStats: TimeStatistics) { def computeSectorMatrix(bodies: coll.Seq[Body], boundaries: Boundaries): SectorMatrix = timeStats.timed("matrix") { val parBodies = bodies.par parBodies.tasksupport = taskSupport - ??? + parBodies.aggregate(new SectorMatrix(boundaries, SECTOR_PRECISION))((accSM, b) => accSM += b, (sm1, sm2) => sm1.combine(sm2)) } def computeQuad(sectorMatrix: SectorMatrix): Quad = timeStats.timed("quad") { @@ -38,7 +47,7 @@ class Simulator(val taskSupport: TaskSupport, val timeStats: TimeStatistics) { def updateBodies(bodies: coll.Seq[Body], quad: Quad): coll.Seq[Body] = timeStats.timed("update") { val parBodies = bodies.par parBodies.tasksupport = taskSupport - ??? + parBodies.map(_.updated(quad)).seq } def eliminateOutliers(bodies: coll.Seq[Body], sectorMatrix: SectorMatrix, quad: Quad): coll.Seq[Body] = timeStats.timed("eliminate") { diff --git a/src/main/scala/barneshut/package.scala b/src/main/scala/barneshut/package.scala index 20b05b4..d9352d7 100644 --- a/src/main/scala/barneshut/package.scala +++ b/src/main/scala/barneshut/package.scala @@ -46,34 +46,71 @@ package object barneshut { } case class Empty(centerX: Float, centerY: Float, size: Float) extends Quad { - def massX: Float = ??? - def massY: Float = ??? - def mass: Float = ??? - def total: Int = ??? - def insert(b: Body): Quad = ??? + def massX: Float = centerX + def massY: Float = centerY + def mass: Float = 0 + def total: Int = 0 + def insert(b: Body): Quad = Leaf(centerX, centerY, size, Seq(b)) } case class Fork( nw: Quad, ne: Quad, sw: Quad, se: Quad ) extends Quad { - val centerX: Float = ??? - val centerY: Float = ??? - val size: Float = ??? - val mass: Float = ??? - val massX: Float = ??? - val massY: Float = ??? - val total: Int = ??? + + val centerX: Float = (nw.centerX + ne.centerX)/2 + val centerY: Float = (nw.centerY + sw.centerY)/2 + val size: Float = nw.size + ne.size + val mass: Float = Seq(nw, ne, sw, se).map(_.mass).sum + val massX: Float = if(mass == 0) centerX else Seq(nw, ne, sw, se).map(q => q.mass * q.massX).sum/mass + val massY: Float = if(mass == 0) centerY else Seq(nw, ne, sw, se).map(q => q.mass * q.massY).sum/mass + val total: Int = Seq(nw, ne, sw, se).map(_.total).sum; def insert(b: Body): Fork = { - ??? + if(b.x <= centerX){ //W + if(b.y <= centerY){ // NW + Fork(nw.insert(b), ne, sw, se) + } + else{ //SW + Fork(nw, ne, sw.insert(b), se) + } + } + else{ //EAST + if(b.y <= centerY){ //NE + Fork(nw, ne.insert(b), sw, se) + } + else{ //SE + Fork(nw, ne, sw, se.insert(b)) + } + } } } case class Leaf(centerX: Float, centerY: Float, size: Float, bodies: coll.Seq[Body]) extends Quad { - val (mass, massX, massY) = (??? : Float, ??? : Float, ??? : Float) - val total: Int = ??? - def insert(b: Body): Quad = ??? + val mass = bodies.map(_.mass).sum + val massX = bodies.map(b => b.mass * b.x).sum / mass + val massY = bodies.map(b => b.mass * b.y).sum / mass + val total: Int = bodies.length + def insert(b: Body): Quad = { + if(size > minimumSize){ + val wCorr = centerX - size/4 + val eCorr = centerX + size/4 + val nCorr = centerY - size/4 + val sCorr = centerY + size/4 + + val newSize = size/2 + + val fork = Fork( + Empty(wCorr, nCorr, newSize), + Empty(eCorr, nCorr, newSize), + Empty(wCorr, sCorr, newSize), + Empty(eCorr, sCorr, newSize)) + (bodies :+ b).foldLeft(fork)((f, b) => f.insert(b)) + } + else { + Leaf(centerX, centerY, size, bodies :+ b) + } + } } def minimumSize = 0.00001f @@ -123,9 +160,16 @@ package object barneshut { def traverse(quad: Quad): Unit = (quad: Quad) match { case Empty(_, _, _) => // no force - case Leaf(_, _, _, bodies) => + case Leaf(_, _, _, bodies) => bodies.foreach(b => addForce(b.mass, b.x, b.y)) // add force contribution of each body by calling addForce - case Fork(nw, ne, sw, se) => + case Fork(nw, ne, sw, se) => if(quad.size / distance(x, y, quad.centerX, quad.centerY) < theta) + addForce(quad.mass, quad.massX, quad.massY) + else { + traverse(nw) + traverse(ne) + traverse(sw) + traverse(se) + } // see if node is far enough from the body, // or recursion is needed } @@ -150,14 +194,31 @@ package object barneshut { for (i <- 0 until matrix.length) matrix(i) = new ConcBuffer def +=(b: Body): SectorMatrix = { - ??? + val x = Math.max(Math.min(b.x, boundaries.maxX), boundaries.minX) + val y = Math.max(Math.min(b.y, boundaries.maxY), boundaries.minY) + + //Distance from top-left corner + val dx = x - boundaries.minX; + val dy = y - boundaries.minY; + + //Corresponding sector + val xSect = (dx / sectorSize).toInt + val ySect = (dy / sectorSize).toInt + + apply(xSect, ySect) += b + this } def apply(x: Int, y: Int) = matrix(y * sectorPrecision + x) def combine(that: SectorMatrix): SectorMatrix = { - ??? + var i = 0; + while(i < matrix.length){ + matrix(i) = matrix(i).combine(that.matrix(i)); + i += 1 + } + this } def toQuad(parallelism: Int): Quad = {