Initial project version

This commit is contained in:
sebastiankloth
2022-04-11 11:01:13 +02:00
commit f40f2badd8
18 changed files with 2012 additions and 0 deletions

157
tudplot/altair.py Executable file
View File

@ -0,0 +1,157 @@
import altair
from altair import Config, Color, Shape, Column, Row, Encoding, Scale, Axis
from random import randint
import os
import logging
import matplotlib.pyplot as plt
from .tud import nominal_colors, full_colors
def filter_nulltime_json(data):
if 'time' in data:
data = data[data.time > 0]
return altair.pipe(data, altair.to_json)
altair.renderers.enable('notebook')
altair.data_transformers.register('json_logtime', filter_nulltime_json)
altair.data_transformers.enable('json_logtime')
def my_theme(*args, **kwargs):
return {
'range': {
'ordinal': altair.VgScheme('viridis'),
'ramp': {'scheme': 'viridis'}
}
}
altair.themes.register('my-theme', my_theme)
altair.themes.enable('my-theme')
class BaseMixin(Encoding):
def __init__(self, *args, **kwargs):
kwargs.setdefault('scale', altair.Scale(zero=False))
super().__init__(*args, **kwargs)
class LogMixin(BaseMixin):
def __init__(self, *args, **kwargs):
kwargs['scale'] = altair.Scale(type='log')
super().__init__(*args, **kwargs)
class X(altair.X, BaseMixin):
pass
class Y(altair.Y, BaseMixin):
pass
class LogX(altair.X, LogMixin):
pass
class LogY(altair.Y, LogMixin):
pass
class DataHandler(altair.Data):
def __init__(self, df):
self._filename = '.altair.json'
with open(self._filename, 'w') as f:
f.write(df.to_json())
super().__init__(url=self._filename)
class Chart(altair.Chart):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def encode(self, *args, color=None, **kwargs):
if isinstance(color, str):
if color.endswith(':F'):
field = color[:-2]
color = color.replace(':F', ':N')
self.configure_scale(nominalColorRange=full_colors(len(set(self._data[field]))))
for arg in args:
if isinstance(arg, altair.X):
kwargs['x'] = arg
elif isinstance(arg, altair.Y):
kwargs['y'] = arg
return super().encode(color=color, **kwargs)
def to_mpl(self):
d = self.to_dict()
fmt = 'o' if d.get('mark', 'point') is 'point' else '-'
def encode(data, encoding, **kwargs):
logging.debug(str(kwargs))
if 'column' in encoding:
channel = encoding.pop('column')
ncols = len(data[channel.get('field')].unique())
for col, (column, df) in enumerate(data.groupby(channel.get('field'))):
ax = plt.gca() if col > 0 else None
plt.subplot(kwargs.get('nrows', 1), ncols, col + 1, sharey=ax).set_title(column)
encode(df, encoding.copy(), secondary_column=col > 0, **kwargs.copy())
elif 'color' in encoding:
channel = encoding.pop('color')
if channel.get('type') == 'quantitative':
colors = full_colors(len(data[channel.get('field')].unique()))
else:
colors = nominal_colors['b']
while len(colors) < len(data[channel.get('field')].unique()):
colors *= 2
for color, (column, df) in zip(colors, data.groupby(channel.get('field'))):
if 'label' in kwargs:
label = kwargs.pop('label') + ', {}'.format(column)
else:
label = str(column)
encode(df, encoding.copy(), color=color, label=label, **kwargs.copy())
elif 'shape' in encoding:
channel = encoding.pop('shape')
markers = ['h', 'v', 'o', 's', '^', 'D', '<', '>']
while len(markers) < len(data[channel.get('field')].unique()):
markers *= 2
logging.debug(str(data[channel.get('field')].unique()))
for marker, (column, df) in zip(markers, data.groupby(channel.get('field'))):
if 'label' in kwargs:
label = kwargs.pop('label') + ', {}'.format(column)
else:
label = str(column)
encode(df, encoding.copy(), marker=marker, label=label, **kwargs.copy())
else:
x_field = encoding.get('x').get('field')
y_field = encoding.get('y').get('field')
plt.xlabel(x_field)
if not kwargs.pop('secondary_column', False):
plt.ylabel(y_field)
else:
plt.tick_params(axis='y', which='both', labelleft='off', labelright='off')
if 'scale' in encoding.get('x'):
plt.xscale(encoding['x']['scale'].get('type', 'linear'))
if 'scale' in encoding.get('y'):
plt.yscale(encoding['y']['scale'].get('type', 'linear'))
plt.plot(data[x_field], data[y_field], fmt, **kwargs)
plt.legend(loc='best', fontsize='small')
encode(self._data, d.get('encoding'))
plt.tight_layout(pad=0.5)
class Arrhenius(Chart):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# self.transform_data(calculate=[Formula(field='1000K/T', expr='1000/datum.T')])
self.data['1000 K / T'] = 1000 / self.data['T']
self.encode(x=X('1000 K / T'))