From 9a3514a46894b12637f80ae94b33c69ce1c915af Mon Sep 17 00:00:00 2001 From: Kali Live user Date: Sun, 15 Mar 2020 11:57:14 +0000 Subject: [PATCH] Test --- src/main/scala/kmeans/KMeans.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/main/scala/kmeans/KMeans.scala b/src/main/scala/kmeans/KMeans.scala index ced54ea..1f65829 100644 --- a/src/main/scala/kmeans/KMeans.scala +++ b/src/main/scala/kmeans/KMeans.scala @@ -45,7 +45,10 @@ class KMeans extends KMeansInterface { } def classify(points: ParSeq[Point], means: ParSeq[Point]): ParMap[Point, ParSeq[Point]] = { - ??? + val meansWithPoints = points.groupBy( p => findClosest(p, means) ) + means + .map(m => (m, if (meansWithPoints.contains(m)) List() ++ meansWithPoints(m) else List())) + .toMap } def findAverage(oldMean: Point, points: ParSeq[Point]): Point = if (points.isEmpty) oldMean else { @@ -61,16 +64,20 @@ class KMeans extends KMeansInterface { } def update(classified: ParMap[Point, ParSeq[Point]], oldMeans: ParSeq[Point]): ParSeq[Point] = { - ??? + oldMeans.map( mean => findAverage(mean, classified(mean)) ) } def converged(eta: Double, oldMeans: ParSeq[Point], newMeans: ParSeq[Point]): Boolean = { - ??? + oldMeans + .zip(newMeans) + .map(p => p._1.squareDistance(p._2) <= eta) + .forall( t => t ) } @tailrec final def kMeans(points: ParSeq[Point], means: ParSeq[Point], eta: Double): ParSeq[Point] = { - if (???) kMeans(???, ???, ???) else ??? // your implementation need to be tail recursive + val freshMints = update(classify(points, means), means) + if (!converged(eta)(means, freshMints)) kMeans(points, freshMints, eta) else freshMints // your implementation need to be tail recursive } }