158 lines
5.4 KiB
Python
Executable File
158 lines
5.4 KiB
Python
Executable File
|
|
import altair
|
|
from altair import Config, Color, Shape, Column, Row, Encoding, Scale, Axis
|
|
from random import randint
|
|
import os
|
|
import logging
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from .tud import nominal_colors, full_colors
|
|
|
|
|
|
|
|
def filter_nulltime_json(data):
    """Altair data transformer: keep only rows with positive ``time``, then JSON-serialize.

    If ``'time' in data`` holds (for a pandas DataFrame this tests the column
    labels), rows whose ``time`` is not ``> 0`` are dropped. Presumably this is
    so a log-scaled time axis never sees zero/negative values, as the
    ``'json_logtime'`` registration name suggests (TODO confirm). The (possibly
    filtered) data is then handed to ``altair.to_json`` via ``altair.pipe``.
    """
    kept = data[data.time > 0] if 'time' in data else data
    return altair.pipe(kept, altair.to_json)
|
|
|
|
# --- Import-time Altair configuration -------------------------------------
# Render charts with the 'notebook' renderer.
altair.renderers.enable('notebook')
# Install filter_nulltime_json as the 'json_logtime' data transformer and make
# it the active one (it must be registered before it can be enabled).
altair.data_transformers.register('json_logtime', filter_nulltime_json)
altair.data_transformers.enable('json_logtime')
|
|
|
|
def my_theme(*args, **kwargs):
    """Altair theme that uses the 'viridis' colour scheme for ordinal and ramp ranges.

    Altair merges the dict returned by a registered theme into the top level of
    every chart spec. Range settings are therefore nested under ``'config'`` —
    Vega-Lite reads named ranges from ``config.range``, and a bare top-level
    ``range`` key is not a valid spec property.

    Any positional or keyword arguments are accepted and ignored.

    Returns:
        dict: ``{'config': {'range': {'ordinal': ..., 'ramp': ...}}}`` with both
        ranges set to ``{'scheme': 'viridis'}``.
    """
    # Plain JSON dicts (rather than schema objects such as altair.VgScheme) so
    # the merged spec stays directly serializable; both ranges use one form.
    viridis = {'scheme': 'viridis'}
    return {
        'config': {
            'range': {
                'ordinal': dict(viridis),
                'ramp': dict(viridis),
            },
        },
    }
|
|
|
|
# Make my_theme available as 'my-theme' and switch Altair to it at import time;
# registration has to precede enabling.
altair.themes.register('my-theme', my_theme)
altair.themes.enable('my-theme')
|
|
|
|
|
|
class BaseMixin(Encoding):
    """Channel mixin that supplies ``scale=altair.Scale(zero=False)`` when none is given.

    A ``scale`` keyword that is already present is passed through unchanged.
    NOTE(review): if a cooperating base ``__init__`` forwards ``scale``
    explicitly (e.g. as an "undefined" placeholder), the key is present and no
    default is applied — confirm against the Altair version in use.
    """

    def __init__(self, *args, **kwargs):
        if 'scale' not in kwargs:
            # Do not force the axis to include zero.
            kwargs['scale'] = altair.Scale(zero=False)
        super().__init__(*args, **kwargs)
|
|
|
|
|
|
class LogMixin(BaseMixin):
    """Channel mixin that always uses a logarithmic scale.

    Whatever ``scale`` the caller passes is replaced by
    ``altair.Scale(type='log')`` before delegating, so ``BaseMixin``'s
    ``zero=False`` default never takes effect for this mixin.
    """

    def __init__(self, *args, **kwargs):
        log_kwargs = dict(kwargs, scale=altair.Scale(type='log'))
        super().__init__(*args, **log_kwargs)
|
|
|
|
|
|
class X(altair.X, BaseMixin):
    """``altair.X`` channel combined with ``BaseMixin``.

    Intended to default the scale to ``Scale(zero=False)``; whether
    ``BaseMixin.__init__`` actually runs depends on ``altair.X.__init__``
    calling ``super().__init__`` cooperatively (TODO confirm).
    """
|
|
|
|
|
|
class Y(altair.Y, BaseMixin):
    """``altair.Y`` channel combined with ``BaseMixin``.

    Intended to default the scale to ``Scale(zero=False)``; whether
    ``BaseMixin.__init__`` actually runs depends on ``altair.Y.__init__``
    calling ``super().__init__`` cooperatively (TODO confirm).
    """
|
|
|
|
|
|
class LogX(altair.X, LogMixin):
    """``altair.X`` channel combined with ``LogMixin``.

    Intended to force a log scale; whether ``LogMixin.__init__`` actually runs
    depends on ``altair.X.__init__`` calling ``super().__init__``
    cooperatively (TODO confirm).
    """
|
|
|
|
|
|
class LogY(altair.Y, LogMixin):
    """``altair.Y`` channel combined with ``LogMixin``.

    Intended to force a log scale; whether ``LogMixin.__init__`` actually runs
    depends on ``altair.Y.__init__`` calling ``super().__init__``
    cooperatively (TODO confirm).
    """
|
|
|
|
|
|
class DataHandler(altair.Data):
    """``altair.Data`` that points at a JSON file written from a DataFrame.

    The DataFrame is serialized to ``filename`` (relative paths resolve against
    the current working directory), and the resulting data spec references that
    file by URL. Instances that share a filename overwrite each other's file.
    """

    def __init__(self, df, filename='.altair.json'):
        """Write ``df`` to ``filename`` and initialise the data spec with ``url=filename``.

        Args:
            df: pandas DataFrame (or Series) to serialize.
            filename: Target JSON path; defaults to ``'.altair.json'``.
        """
        self._filename = filename
        # Vega/Vega-Lite expect URL-loaded JSON to be an array of row objects,
        # i.e. pandas' 'records' orient. DataFrame.to_json() defaults to the
        # column-keyed 'columns' orient, which Vega cannot read as a table.
        payload = df.to_json(orient='records')
        with open(self._filename, 'w', encoding='utf-8') as f:
            f.write(payload)
        super().__init__(url=self._filename)
|
|
|
|
|
|
class Chart(altair.Chart):
    """``altair.Chart`` with an ``:F`` colour shorthand and a matplotlib export.

    NOTE(review): ``encode`` uses ``configure_scale(nominalColorRange=...)``,
    an Altair 1.x-style option, while the module-level setup uses Altair 2.x
    registries; confirm which Altair version this class targets.
    """

    def encode(self, *args, color=None, **kwargs):
        """Set the chart encoding, with extra handling for colour and positional X/Y.

        Args:
            *args: Positional channel objects. ``altair.X`` / ``altair.Y``
                instances (including subclasses) become the ``x`` / ``y``
                keyword arguments, overriding any given by keyword. Other
                positional arguments are ignored.
            color: Colour channel. A string ending in ``':F'`` is turned into
                a nominal field (``'<field>:N'``), and the nominal colour range
                is set to ``full_colors(n)``, where ``n`` is the number of
                distinct values in that column of the chart data. ``None``
                (the default) leaves the colour channel untouched.
            **kwargs: Remaining channels, forwarded to ``altair.Chart.encode``.

        Returns:
            Whatever ``altair.Chart.encode`` returns.
        """
        if isinstance(color, str) and color.endswith(':F'):
            field = color[:-2]
            # Rewrite only the trailing type suffix; str.replace would also
            # alter any ':F' that happens to occur inside the field name.
            color = field + ':N'
            n_values = len(set(self.data[field]))
            # NOTE(review): the return value is discarded, so this relies on
            # configure_scale updating the chart in place — confirm for the
            # Altair version in use.
            self.configure_scale(nominalColorRange=full_colors(n_values))

        for arg in args:
            if isinstance(arg, altair.X):
                kwargs['x'] = arg
            elif isinstance(arg, altair.Y):
                kwargs['y'] = arg

        # Forward colour only when it was given, so encoding some other channel
        # later does not pass color=None and wipe an existing colour encoding.
        if color is not None:
            kwargs['color'] = color
        return super().encode(**kwargs)

    def to_mpl(self):
        """Draw an approximation of this chart on the current matplotlib figure.

        Encodings are handled recursively in this priority order:

        * ``column``: one subplot per distinct value, each sharing the y axis
          with the previously drawn subplot. Secondary columns hide y tick
          labels and the y label.
        * ``color``: one series per value. Colours come from ``full_colors``
          for quantitative fields and from ``nominal_colors['b']`` otherwise,
          repeating the palette when there are more values than colours.
        * ``shape``: one series per value, cycling through a fixed marker list.
        * ``x`` / ``y``: plotted from their ``field``. An optional scale
          ``type`` becomes the matplotlib axis scale.

        Nested color/shape groups get legend labels of the form
        ``"<outer>, <inner>"``. A ``'point'`` mark (the default) is drawn with
        markers only; any other mark is drawn as lines.
        """
        import itertools

        spec = self.to_dict()
        mark = spec.get('mark', 'point')
        if isinstance(mark, dict):
            # Vega-Lite also allows the mark as {'type': ..., <mark props>}.
            mark = mark.get('type', 'point')
        # Compare by value: ``is`` tests object identity and only matched
        # 'point' by accident of string interning.
        fmt = 'o' if mark == 'point' else '-'

        def series_label(prefix, value):
            """Legend label for one group, prefixed by the enclosing group's label."""
            return str(value) if prefix is None else '{}, {}'.format(prefix, value)

        def draw(data, encoding, **kwargs):
            """Consume one grouping channel from ``encoding`` and recurse, or plot."""
            logging.debug('%s', kwargs)
            if 'column' in encoding:
                field = encoding.pop('column').get('field')
                ncols = len(data[field].unique())
                for col, (value, df) in enumerate(data.groupby(field)):
                    # Share the y axis with the subplot drawn just before.
                    ax = plt.gca() if col > 0 else None
                    plt.subplot(kwargs.get('nrows', 1), ncols, col + 1, sharey=ax).set_title(value)
                    draw(df, encoding.copy(), secondary_column=col > 0, **kwargs)

            elif 'color' in encoding:
                channel = encoding.pop('color')
                field = channel.get('field')
                if channel.get('type') == 'quantitative':
                    colors = full_colors(len(data[field].unique()))
                else:
                    # Never mutate the shared palette; cycle() repeats it as
                    # needed (and cannot spin forever on an empty palette).
                    colors = nominal_colors['b']
                # Read the outer label once: popping it inside the loop gave
                # only the first group the prefixed label.
                prefix = kwargs.pop('label', None)
                for color, (value, df) in zip(itertools.cycle(colors), data.groupby(field)):
                    draw(df, encoding.copy(), color=color,
                         label=series_label(prefix, value), **kwargs)

            elif 'shape' in encoding:
                field = encoding.pop('shape').get('field')
                markers = ['h', 'v', 'o', 's', '^', 'D', '<', '>']
                logging.debug('%s', data[field].unique())
                prefix = kwargs.pop('label', None)
                for marker, (value, df) in zip(itertools.cycle(markers), data.groupby(field)):
                    draw(df, encoding.copy(), marker=marker,
                         label=series_label(prefix, value), **kwargs)

            else:
                x_enc = encoding.get('x')
                y_enc = encoding.get('y')
                x_field = x_enc.get('field')
                y_field = y_enc.get('field')
                plt.xlabel(x_field)
                if not kwargs.pop('secondary_column', False):
                    plt.ylabel(y_field)
                else:
                    # The y axis is shared with the first column, so hide the
                    # repeated tick labels. Booleans, not 'off': the string
                    # form is deprecated/removed in matplotlib.
                    plt.tick_params(axis='y', which='both', labelleft=False, labelright=False)
                if 'scale' in x_enc:
                    plt.xscale(x_enc['scale'].get('type', 'linear'))
                if 'scale' in y_enc:
                    plt.yscale(y_enc['scale'].get('type', 'linear'))
                plt.plot(data[x_field], data[y_field], fmt, **kwargs)
                plt.legend(loc='best', fontsize='small')

        draw(self.data, spec.get('encoding'))
        plt.tight_layout(pad=0.5)
|
|
|
|
|
|
class Arrhenius(Chart):
    """Chart whose x axis is the derived quantity ``1000 / T`` (column ``'1000 K / T'``).

    The chart data must provide a ``'T'`` column. It is presumably a
    temperature in kelvin, per the derived column's name — TODO confirm.
    """

    def __init__(self, *args, **kwargs):
        """Build the chart, add the ``'1000 K / T'`` column and encode it as ``x``.

        All arguments are forwarded unchanged to ``Chart``.
        """
        super().__init__(*args, **kwargs)
        # Derive the column on a copy: assigning into self.data directly added
        # '1000 K / T' to the caller's own DataFrame as a side effect.
        data = self.data.copy()
        data['1000 K / T'] = 1000 / data['T']
        self.data = data
        self.encode(x=X('1000 K / T'))
|