324 lines
8.7 KiB
Python
324 lines
8.7 KiB
Python
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# ==============================================================================
|
|
|
|
import time
|
|
import visdom
|
|
import sys
|
|
import numpy as np
|
|
import os
|
|
import socket
|
|
from . import Process
|
|
|
|
vis = None
|
|
port = None
|
|
visdom_fail_count = 0
|
|
|
|
|
|
def port_used(port):
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
result = sock.connect_ex(('127.0.0.1', port))
|
|
if result == 0:
|
|
sock.close()
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
|
|
def alloc_port(start_from=7000):
|
|
while True:
|
|
if port_used(start_from):
|
|
print("Port already used: %d" % start_from)
|
|
start_from += 1
|
|
else:
|
|
return start_from
|
|
|
|
|
|
def wait_for_port(port, timeout=5):
|
|
star_time = time.time()
|
|
while not port_used(port):
|
|
if time.time() - star_time > timeout:
|
|
return False
|
|
|
|
time.sleep(0.1)
|
|
return True
|
|
|
|
|
|
def start(on_port=None):
|
|
global vis
|
|
global port
|
|
global visdom_fail_count
|
|
assert vis is None, "Cannot start more than 1 visdom servers."
|
|
|
|
if visdom_fail_count>=3:
|
|
return
|
|
|
|
port = alloc_port() if on_port is None else on_port
|
|
|
|
print("Starting Visdom server on %d" % port)
|
|
Process.run("%s -m visdom.server -p %d" % (sys.executable, port))
|
|
if not wait_for_port(port):
|
|
print("ERROR: failed to start Visdom server. Server not responding.")
|
|
visdom_fail_count += 1
|
|
return
|
|
print("Done.")
|
|
|
|
vis = visdom.Visdom(port=port)
|
|
|
|
|
|
def _start_if_not_running():
|
|
if vis is None:
|
|
start()
|
|
|
|
|
|
def save_heatmap(dir, title, img):
|
|
if dir is not None:
|
|
fname = os.path.join(dir, title.replace(" ", "_") + ".npy")
|
|
d = os.path.dirname(fname)
|
|
os.makedirs(d, exist_ok=True)
|
|
np.save(fname, img)
|
|
|
|
|
|
class Plot2D:
|
|
TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]
|
|
|
|
def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
|
|
_start_if_not_running()
|
|
|
|
self.x = []
|
|
self.y = []
|
|
self.store_interval = store_interval
|
|
self.curr_accu = None
|
|
self.curr_cnt = 0
|
|
self.name = name
|
|
self.legend = legend
|
|
self.xlabel = xlabel
|
|
self.ylabel = ylabel
|
|
|
|
self.replot = False
|
|
self.visplot = None
|
|
self.last_vis_update_pos = 0
|
|
|
|
def set_legend(self, legend):
|
|
if self.legend != legend:
|
|
self.legend = legend
|
|
self.replot = True
|
|
|
|
def _send_update(self):
|
|
if not self.x or vis is None:
|
|
return
|
|
|
|
if self.visplot is None or self.replot:
|
|
opts = {
|
|
"title": self.name
|
|
}
|
|
if self.xlabel:
|
|
opts["xlabel"] = self.xlabel
|
|
if self.ylabel:
|
|
opts["ylabel"] = self.ylabel
|
|
|
|
if self.legend is not None:
|
|
opts["legend"] = self.legend
|
|
|
|
self.visplot = vis.line(X=np.asfarray(self.x), Y=np.asfarray(self.y), opts=opts, win=self.visplot)
|
|
self.replot = False
|
|
else:
|
|
vis.line(
|
|
X=np.asfarray(self.x[self.last_vis_update_pos:]),
|
|
Y=np.asfarray(self.y[self.last_vis_update_pos:]),
|
|
win=self.visplot,
|
|
update='append'
|
|
)
|
|
|
|
self.last_vis_update_pos = len(self.x) - 1
|
|
|
|
def add_point(self, x, y):
|
|
if not isinstance(y, list):
|
|
y = [y]
|
|
|
|
|
|
|
|
if self.curr_accu is None:
|
|
self.curr_accu = [0.0] * len(y)
|
|
|
|
if len(self.curr_accu) < len(y):
|
|
# The number of curves increased.
|
|
need_to_add = (len(y) - len(self.curr_accu))
|
|
|
|
self.curr_accu += [0.0] * need_to_add
|
|
count = len(self.x)
|
|
if count>0:
|
|
self.replot = True
|
|
if not isinstance(self.x[0], list):
|
|
self.x = [[x] for x in self.x]
|
|
self.y = [[y] for y in self.y]
|
|
|
|
nan = float("nan")
|
|
for a in self.x:
|
|
a += [nan] * need_to_add
|
|
for a in self.y:
|
|
a += [nan] * need_to_add
|
|
elif len(self.curr_accu) > len(y):
|
|
y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))
|
|
|
|
self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
|
|
self.curr_cnt += 1
|
|
if self.curr_cnt == self.store_interval:
|
|
if len(y) > 1:
|
|
self.x.append([x] * len(y))
|
|
self.y.append([a / self.curr_cnt for a in self.curr_accu])
|
|
else:
|
|
self.x.append(x)
|
|
self.y.append(self.curr_accu[0] / self.curr_cnt)
|
|
|
|
self.curr_accu = [0.0] * len(y)
|
|
self.curr_cnt = 0
|
|
|
|
self._send_update()
|
|
|
|
def state_dict(self):
|
|
s = {k: self.__dict__[k] for k in self.TO_SAVE}
|
|
return s
|
|
|
|
def load_state_dict(self, state):
|
|
if self.legend is not None:
|
|
# Load legend only if not given in the constructor.
|
|
state["legend"] = self.legend
|
|
self.__dict__.update(state)
|
|
self.last_vis_update_pos = 0
|
|
|
|
# Read old format
|
|
if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
|
|
self.curr_accu = [self.curr_accu]
|
|
|
|
self._send_update()
|
|
|
|
|
|
class Image:
|
|
def __init__(self, title, dumpdir=None):
|
|
_start_if_not_running()
|
|
|
|
self.win = None
|
|
self.opts = dict(title=title)
|
|
self.dumpdir = dumpdir
|
|
|
|
def set_dump_dir(self, dumpdir):
|
|
self.dumpdir = dumpdir
|
|
|
|
def draw(self, img):
|
|
if vis is None:
|
|
return
|
|
|
|
if isinstance(img, list):
|
|
if self.win is None:
|
|
self.win = vis.images(img, opts=self.opts)
|
|
else:
|
|
vis.images(img, win=self.win, opts=self.opts)
|
|
else:
|
|
if len(img.shape)==2:
|
|
img = np.expand_dims(img, 0)
|
|
elif img.shape[-1] in [1,3] and img.shape[0] not in [1,3]:
|
|
# cv2 image
|
|
img = img.transpose(2,0,1)
|
|
img = img[::-1]
|
|
|
|
if img.dtype==np.uint8:
|
|
img = img.astype(np.float32)/255
|
|
|
|
self.opts["width"] = img.shape[2]
|
|
self.opts["height"] = img.shape[1]
|
|
|
|
save_heatmap(self.dumpdir, self.opts["title"], img)
|
|
if self.win is None:
|
|
self.win = vis.image(img, opts=self.opts)
|
|
else:
|
|
vis.image(img, win=self.win, opts=self.opts)
|
|
|
|
def __call__(self, img):
|
|
self.draw(img)
|
|
|
|
|
|
class Text:
|
|
def __init__(self, title):
|
|
_start_if_not_running()
|
|
|
|
self.win = None
|
|
self.title = title
|
|
self.curr_text = ""
|
|
|
|
def set(self, text):
|
|
self.curr_text = text
|
|
|
|
if vis is None:
|
|
return
|
|
|
|
if self.win is None:
|
|
self.win = vis.text(text, opts=dict(
|
|
title=self.title
|
|
))
|
|
else:
|
|
vis.text(text, win=self.win)
|
|
|
|
def state_dict(self):
|
|
return {"text": self.curr_text}
|
|
|
|
def load_state_dict(self, state):
|
|
self.set(state["text"])
|
|
|
|
def __call__(self, text):
|
|
self.set(text)
|
|
|
|
|
|
class Heatmap:
|
|
def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
|
|
_start_if_not_running()
|
|
|
|
self.win = None
|
|
self.opt = dict(title=title, colormap=colormap)
|
|
self.dumpdir = dumpdir
|
|
if min is not None:
|
|
self.opt["xmin"] = min
|
|
if max is not None:
|
|
self.opt["xmax"] = max
|
|
|
|
if xlabel:
|
|
self.opt["xlabel"] = xlabel
|
|
if ylabel:
|
|
self.opt["ylabel"] = ylabel
|
|
|
|
def set_dump_dir(self, dumpdir):
|
|
self.dumpdir = dumpdir
|
|
|
|
def draw(self, img):
|
|
if vis is None:
|
|
return
|
|
|
|
o = self.opt.copy()
|
|
if "xmin" not in o:
|
|
o["xmin"] = float(img.min())
|
|
|
|
if "xmax" not in o:
|
|
o["xmax"] = float(img.max())
|
|
|
|
save_heatmap(self.dumpdir, o["title"], img)
|
|
|
|
if self.win is None:
|
|
self.win = vis.heatmap(img, opts=o)
|
|
else:
|
|
vis.heatmap(img, win=self.win, opts=o)
|
|
|
|
def __call__(self, img):
|
|
self.draw(img)
|