From 16b66965559195f80cf2741e8a7dbed0cdd4d0a4 Mon Sep 17 00:00:00 2001 From: Kali Live user Date: Sun, 15 Mar 2020 12:02:25 +0000 Subject: [PATCH] Test 2 --- src/main/scala/kmeans/KMeans.scala | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/main/scala/kmeans/KMeans.scala b/src/main/scala/kmeans/KMeans.scala index 1f65829..6e26635 100644 --- a/src/main/scala/kmeans/KMeans.scala +++ b/src/main/scala/kmeans/KMeans.scala @@ -45,10 +45,9 @@ class KMeans extends KMeansInterface { } def classify(points: ParSeq[Point], means: ParSeq[Point]): ParMap[Point, ParSeq[Point]] = { - val meansWithPoints = points.groupBy( p => findClosest(p, means) ) - means - .map(m => (m, if (meansWithPoints.contains(m)) List() ++ meansWithPoints(m) else List())) - .toMap + val pointsMeanMap = points.par.groupBy(findClosest(_, means)) + // So iterate over means get (empty) list and return map + means.par.map(mean => mean -> pointsMeanMap.getOrElse(mean, ParSeq())).toMap } def findAverage(oldMean: Point, points: ParSeq[Point]): Point = if (points.isEmpty) oldMean else { @@ -64,20 +63,21 @@ class KMeans extends KMeansInterface { } def update(classified: ParMap[Point, ParSeq[Point]], oldMeans: ParSeq[Point]): ParSeq[Point] = { - oldMeans.map( mean => findAverage(mean, classified(mean)) ) + oldMeans.par.map(oldMean => findAverage(oldMean, classified(oldMean))) } def converged(eta: Double, oldMeans: ParSeq[Point], newMeans: ParSeq[Point]): Boolean = { - oldMeans - .zip(newMeans) - .map(p => p._1.squareDistance(p._2) <= eta) - .forall( t => t ) + (oldMeans zip newMeans).forall{ + case (oldMean, newMean) => oldMean.squareDistance(newMean) <= eta + } } @tailrec final def kMeans(points: ParSeq[Point], means: ParSeq[Point], eta: Double): ParSeq[Point] = { - val freshMints = update(classify(points, means), means) - if (!converged(eta)(means, freshMints)) kMeans(points, freshMints, eta) else freshMints // your implementation need to be tail recursive + val classified = classify(points, means) + val newMeans = update(classified, means) + + if (!converged(eta, means, newMeans)) kMeans(points, newMeans, eta) else newMeans } }