Test 2
This commit is contained in:
parent
9a3514a468
commit
16b6696555
@ -45,10 +45,9 @@ class KMeans extends KMeansInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def classify(points: ParSeq[Point], means: ParSeq[Point]): ParMap[Point, ParSeq[Point]] = {
|
def classify(points: ParSeq[Point], means: ParSeq[Point]): ParMap[Point, ParSeq[Point]] = {
|
||||||
val meansWithPoints = points.groupBy( p => findClosest(p, means) )
|
val pointsMeanMap = points.par.groupBy(findClosest(_, means))
|
||||||
means
|
// So iterate over means get (empty) list and return map
|
||||||
.map(m => (m, if (meansWithPoints.contains(m)) List() ++ meansWithPoints(m) else List()))
|
means.par.map(mean => mean -> pointsMeanMap.getOrElse(mean, ParSeq())).toMap
|
||||||
.toMap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def findAverage(oldMean: Point, points: ParSeq[Point]): Point = if (points.isEmpty) oldMean else {
|
def findAverage(oldMean: Point, points: ParSeq[Point]): Point = if (points.isEmpty) oldMean else {
|
||||||
@ -64,20 +63,21 @@ class KMeans extends KMeansInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def update(classified: ParMap[Point, ParSeq[Point]], oldMeans: ParSeq[Point]): ParSeq[Point] = {
|
def update(classified: ParMap[Point, ParSeq[Point]], oldMeans: ParSeq[Point]): ParSeq[Point] = {
|
||||||
oldMeans.map( mean => findAverage(mean, classified(mean)) )
|
oldMeans.par.map(oldMean => findAverage(oldMean, classified(oldMean)))
|
||||||
}
|
}
|
||||||
|
|
||||||
def converged(eta: Double, oldMeans: ParSeq[Point], newMeans: ParSeq[Point]): Boolean = {
|
def converged(eta: Double, oldMeans: ParSeq[Point], newMeans: ParSeq[Point]): Boolean = {
|
||||||
oldMeans
|
(oldMeans zip newMeans).forall{
|
||||||
.zip(newMeans)
|
case (oldMean, newMean) => oldMean.squareDistance(newMean) <= eta
|
||||||
.map(p => p._1.squareDistance(p._2) <= eta)
|
}
|
||||||
.forall( t => t )
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@tailrec
|
@tailrec
|
||||||
final def kMeans(points: ParSeq[Point], means: ParSeq[Point], eta: Double): ParSeq[Point] = {
|
final def kMeans(points: ParSeq[Point], means: ParSeq[Point], eta: Double): ParSeq[Point] = {
|
||||||
val freshMints = update(classify(points, means), means)
|
val classified = classify(points, means)
|
||||||
if (!converged(eta)(means, freshMints)) kMeans(points, freshMints, eta) else freshMints // your implementation need to be tail recursive
|
val newMeans = update(classified, means)
|
||||||
|
|
||||||
|
if (!converged(eta, means, newMeans)) kMeans(points, newMeans, eta) else newMeans
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user