update number of clusters
This commit is contained in:
parent
0dea7a2cdb
commit
bee58a107b
@ -5,7 +5,6 @@ try:
|
|||||||
from openpifpaf.network.nets import cli as openpifpaf_cli
|
from openpifpaf.network.nets import cli as openpifpaf_cli
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from openpifpaf.network.factory import cli as openpifpaf_cli
|
from openpifpaf.network.factory import cli as openpifpaf_cli
|
||||||
from openpifpaf.network import nets
|
|
||||||
from openpifpaf import decoder
|
from openpifpaf import decoder
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -59,7 +59,7 @@ class Trainer:
|
|||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.sched_step = sched_step
|
self.sched_step = sched_step
|
||||||
self.sched_gamma = sched_gamma
|
self.sched_gamma = sched_gamma
|
||||||
self.clusters = ['10', '20', '30', '50', '>50']
|
self.clusters = ['10', '20', '30', '40']
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.n_stage = n_stage
|
self.n_stage = n_stage
|
||||||
self.dir_out = dir_out
|
self.dir_out = dir_out
|
||||||
|
|||||||
@ -20,14 +20,14 @@ def append_cluster(dic_jo, phase, xx, ys, kps):
|
|||||||
dic_jo[phase]['clst']['30']['kps'].append(kps)
|
dic_jo[phase]['clst']['30']['kps'].append(kps)
|
||||||
dic_jo[phase]['clst']['30']['X'].append(xx)
|
dic_jo[phase]['clst']['30']['X'].append(xx)
|
||||||
dic_jo[phase]['clst']['30']['Y'].append(ys)
|
dic_jo[phase]['clst']['30']['Y'].append(ys)
|
||||||
elif ys[3] < 50:
|
elif ys[3] <= 40:
|
||||||
dic_jo[phase]['clst']['50']['kps'].append(kps)
|
dic_jo[phase]['clst']['40']['kps'].append(kps)
|
||||||
dic_jo[phase]['clst']['50']['X'].append(xx)
|
dic_jo[phase]['clst']['40']['X'].append(xx)
|
||||||
dic_jo[phase]['clst']['50']['Y'].append(ys)
|
dic_jo[phase]['clst']['40']['Y'].append(ys)
|
||||||
else:
|
else:
|
||||||
dic_jo[phase]['clst']['>50']['kps'].append(kps)
|
dic_jo[phase]['clst']['>40']['kps'].append(kps)
|
||||||
dic_jo[phase]['clst']['>50']['X'].append(xx)
|
dic_jo[phase]['clst']['>40']['X'].append(xx)
|
||||||
dic_jo[phase]['clst']['>50']['Y'].append(ys)
|
dic_jo[phase]['clst']['>40']['Y'].append(ys)
|
||||||
|
|
||||||
|
|
||||||
def get_task_error(dd):
|
def get_task_error(dd):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user