"""
Read Athena vtk output files
"""
import os
import os.path as osp
from packaging.version import Version
import glob, struct
import numpy as np
import xarray as xr
import astropy.constants as ac
import astropy.units as au
from .athena_read import vtk
from ..util.units import Units
def read_vtk_athenapp(filenames):
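    """Read a list of Athena++ vtk files and stack the data.

    A thin wrapper around athena_read.vtk; fields from the individual
    files (e.g., one per meshblock) are stacked along a new leading axis.

    Parameters
    ----------
    filenames : list of str
        Names of the vtk files to read.

    Returns
    -------
    dat : dict
        Dictionary of stacked field arrays, including the face coordinates
        'x1f', 'x2f', and 'x3f'.
    """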
x1f = []
x2f = []
x3f = []
dat = dict()
for i,filename in enumerate(filenames):
x1f_, x2f_, x3f_, dat_ = vtk(filename)
x1f.append(x1f_)
x2f.append(x2f_)
x3f.append(x3f_)
if i == 0:
for k in dat_.keys():
dat[k] = []
for k in dat_.keys():
dat[k].append(dat_[k])
for k in dat.keys():
dat[k] = np.array(dat[k])
dat['x1f'] = np.array(x1f)
dat['x2f'] = np.array(x2f)
dat['x3f'] = np.array(x3f)
return dat
def read_vtk(filename, id0_only=False):
"""Convenience wrapper function to read Athena vtk output file
using AthenaDataSet class.
Parameters
----------
filename : str
Name of the file to open, including extension
id0_only : bool
        Flag to force reading only the vtk file in the id0 directory.
Default value is False.
Returns
-------
ds : AthenaDataSet
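    Examples
    --------
    A minimal sketch with a hypothetical file path:
    >>> ds = read_vtk('/basedir/id0/problem_id.0000.vtk')
    >>> dat = ds.get_field('density')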
"""
return AthenaDataSet(filename, id0_only=id0_only)
class AthenaDataSet(object):
def __init__(self, filename, id0_only=False, units=Units(), dfi=None):
"""Class to read athena vtk file.
Parameters
----------
filename : string
Name of the file to open, including extension
id0_only : bool
            Flag to force reading only the vtk file in the id0 directory.
Default value is False.
units : Units
pyathena Units object (used for reading derived fields)
dfi : dict
Dictionary containing derived fields info
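        Examples
        --------
        A minimal sketch with a hypothetical file path:
        >>> ds = AthenaDataSet('/basedir/id0/problem_id.0000.vtk')
        >>> fields = ds.field_list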
"""
if not osp.exists(filename):
            raise IOError('File does not exist: {0:s}'.format(filename))
dirname, problem_id, num, suffix, ext, mpi_mode, nonzero_id = \
_parse_filename(filename)
if id0_only:
mpi_mode = False
self.dirname = dirname
self.problem_id = problem_id
self.num = int(num)
self.suffix = suffix
self.ext = ext
self.mpi_mode = mpi_mode
self.fnames = [filename]
self.u = units
self.dfi = dfi
if dfi is not None:
self.derived_field_list = list(dfi.keys())
else:
self.derived_field_list = None
        # Find all vtk file names and add them to fnames
if mpi_mode:
if self.suffix is None:
fname_pattern = osp.join(dirname, 'id*/{0:s}-id*.{1:s}.{2:s}'.\
format(problem_id, num, ext))
else:
fname_pattern = osp.join(dirname, 'id*/{0:s}-id*.{1:s}.{2:s}.{3:s}'.\
format(problem_id, num, suffix, ext))
fnames = glob.glob(fname_pattern)
self.fnames += fnames
if nonzero_id:
from collections import OrderedDict
self.fnames = list(OrderedDict.fromkeys(self.fnames))
self.grid = self._set_grid()
self.domain = self._set_domain()
self.set_region()
        # Need separate field_map for different grids
if self.domain['all_grid_equal']:
self._field_map = _set_field_map(self.grid[0])
for g in self.grid:
g['field_map'] = self._field_map
else:
for g in self.grid:
g['field_map'] = _set_field_map(g)
self._field_map = self.grid[0]['field_map']
self.field_list = list(self._field_map.keys())
    def get_cc_pos(self):
"""Compute cell center positions
Returns
-------
        xc : dict
            Dictionary of unique cell-centered coordinate arrays, keyed by
            axis name ('x', 'y', 'z')
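        Examples
        --------
        >>> # A minimal sketch; ds is an existing AthenaDataSet instance
        >>> xc = ds.get_cc_pos()
        >>> x = xc['x']  # unique cell-centered x coordinates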
"""
xc = dict()
for axis, le, re, dx in zip(('x', 'y', 'z'), \
self.region['gle'], self.region['gre'], self.domain['dx']):
            # Use re + 0.25*dx as the stop value so that truncation error
            # cannot add a spurious extra element (cf. _get_field)
            xc[axis] = np.arange(le + 0.5*dx, re + 0.25*dx, dx)
return xc
    def get_cc_ijk(s, x1, x2, x3):
"""Compute closest cell-center integer indices in which a particle resides
Parameters
----------
        x1, x2, x3 : arrays of floats
            Particle positions
        Returns
        -------
        i1, i2, i3 : arrays of ints
            Cell indices along each axis
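        Examples
        --------
        >>> # A minimal sketch; x1, x2, x3 are particle position arrays
        >>> i1, i2, i3 = ds.get_cc_ijk(x1, x2, x3)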
"""
domain = s.domain
le1,le2,le3 = domain['le']
dx1,dx2,dx3 = domain['dx']
return (np.floor((x1 - le1)/dx1).astype(int),
np.floor((x2 - le2)/dx2).astype(int),
np.floor((x3 - le3)/dx3).astype(int))
    def set_region(self, le=None, re=None):
"""Set region and find overlapping grids.
"""
if le is None:
le = self.domain['le']
if re is None:
re = self.domain['re']
le = np.array(le)
re = np.array(re)
if (re < le).any():
raise ValueError('Check left/right edge.')
# Find all overlapping grids and their edges
gle_all = [] # grid left edge
gre_all = [] # grid right edge
        gidx = []    # grid indices that belong to this region
for i, g in enumerate(self.grid):
if (g['re'] >= le).all() and (g['le'] <= re).all():
gidx.append(i)
gle_all.append(g['le'])
gre_all.append(g['re'])
gidx = np.array(gidx)
        if len(gidx) == 0:
            raise ValueError(
                'Region edges {} {} do not overlap with any grid; '
                'domain left/right edges are {} {}.'.format(
                    le, re, self.domain['le'], self.domain['re']))
gle_all = np.array(gle_all)
gre_all = np.array(gre_all)
# Find unique grid left/right edge coordinates
gleu = [np.unique(gle_all[:, i]) for i in range(3)]
greu = [np.unique(gre_all[:, i]) for i in range(3)]
# Min/Max of gleu/greu
gle = np.array([gle.min() for gle in gleu])
gre = np.array([gre.max() for gre in greu])
# Number of grids in each direction
NGrid = np.array([len(gleu_) for gleu_ in gleu])
# Number of cells per grid
Nxg = np.concatenate([(greu_ - gleu_)/dx for greu_, gleu_, dx in \
zip(greu, gleu, self.domain['dx'])])
Nxg = np.array(np.array_split(Nxg, NGrid.cumsum()[:-1]),
dtype=object).tolist()
# Since floating point arithmetic may result in incorrect results, need
# to round to the nearest integer
for i, Nxg_ in enumerate(Nxg):
Nxg[i] = np.rint(Nxg_).astype(int)
# Number of cells in region
Nxr = np.empty(len(Nxg), dtype=int)
for i, Nxg_ in enumerate(Nxg):
Nxr[i] = np.sum(Nxg_)
        assert len(gidx) == NGrid.prod(), \
            'Unexpected error: number of grids {0:d} != '.format(len(gidx)) + \
            'product of the numbers of unique edges {0:d}.'.format(NGrid.prod())
self.region = dict(le=le, re=re, gidx=gidx,
gleu=gleu, greu=greu,\
gle=gle, gre=gre,
NGrid=NGrid, Nxg=Nxg, Nxr=Nxr)
    def get_slice(self, axis, field='density', pos='c', method='nearest'):
"""Read slice of fields.
Parameters
----------
axis : str
            Axis to slice along ('x', 'y', or 'z').
field : (list of) str
The name of the field(s) to be read.
        pos : float or str
            Position along the axis at which to slice. If 'c' or 'center',
            slice through the domain center. Default value is 'c'.
        method : str
            Method passed to xarray's sel when selecting the slice position.
            Default value is 'nearest'.
Returns
-------
slc : xarray dataset
An xarray dataset containing slices.
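        Examples
        --------
        A minimal sketch with a hypothetical field name:
        >>> slc = ds.get_slice('z', field='density', pos=0.0)
        >>> img = slc['density'].plot()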
"""
axis_idx = dict(x=0, y=1, z=2)
if pos is None:
pos = 'c'
field = np.atleast_1d(field)
axis = np.atleast_1d(axis)
for ax in axis:
le = np.copy(self.domain['le'])
re = np.copy(self.domain['re'])
if pos in ['c', 'center']:
pos = self.domain['center'][axis_idx[ax]]
            # Make sure that le < re always and that truncation error does
            # not cause problems, although performance may be slightly degraded
le[axis_idx[ax]] = pos - 0.5*self.domain['dx'][axis_idx[ax]]
re[axis_idx[ax]] = pos + 0.5*self.domain['dx'][axis_idx[ax]]
dat = self.get_field(field, le, re, as_xarray=True)
            slc = dat.sel(method=method, **{ax: pos})
return slc
    def get_field(self, field='density', le=None, re=None, as_xarray=True):
"""Read 3d fields data.
Parameters
----------
field : (list of) string
The name of the field(s) to be read.
le : sequence of floats
Left edge. Default value is the domain left edge.
re : sequence of floats
Right edge. Default value is the domain right edge.
as_xarray : bool
If True, returns results as an xarray Dataset. If False, returns a
dictionary containing numpy arrays. Default value is True.
Returns
-------
dat : xarray dataset
An xarray dataset containing fields.
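        Examples
        --------
        A minimal sketch with hypothetical field names and region edges:
        >>> dat = ds.get_field('density')  # full domain
        >>> sub = ds.get_field(['density', 'velocity'],
        ...                    le=(-10., -10., -10.), re=(10., 10., 10.))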
"""
field = np.atleast_1d(field)
# Derived field list
dflist = set(field) - set(self.field_list)
        if not dflist:
            # dflist is an empty set; all fields can be read directly from
            # the vtk file
            return self._get_field(field, le, re, as_xarray)
        # Otherwise, read the union of the input fields and the fields
        # required to calculate the derived fields.
        # First make sure that we have info for everything in dflist.
        if not dflist.issubset(set(self.dfi.keys())):
            tmp = [f for f in dflist if f not in self.dfi]
            raise KeyError('Unrecognized field name(s): {}'.format(tmp))
# Field names that are in the vtk file
flist = set(field) - dflist
# Fields that need to be read to calculate derived field
flist_dep = set()
for f in dflist:
flist_dep = flist_dep | set(self.dfi[f]['field_dep'])
        # Field names to be dropped later
fdrop_list = flist_dep - flist
# Need to adjust names for vector fields
for f in list(fdrop_list):
if self._field_map[f]['nvar'] > 1:
for i in range(self._field_map[f]['nvar']):
fdrop_list.add(f+str(i+1))
fdrop_list.remove(f)
field = list(flist_dep | flist)
dat = self._get_field(field, le, re, as_xarray)
# Calculate derived fields
for f in dflist:
dat[f] = self.dfi[f]['func'](dat, self.u)
# Drop fields that are not requested
if as_xarray:
            dat = dat.drop_vars(list(fdrop_list))
dat.attrs['dfi'] = self.dfi
else:
for f in fdrop_list:
del dat[f]
return dat.squeeze()
def _get_field(self, field='density', le=None, re=None, as_xarray=True):
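        # Read raw (non-derived) fields over the requested region and return
        # an xarray Dataset (or a dict of / a single numpy array)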
field = np.atleast_1d(field)
# Create region
self.set_region(le=le, re=re)
arr = self._get_array(field)
# Works only for 3d data
if as_xarray:
# Cell center positions
coords = dict()
for axis, le, re, dx in zip(('x', 'y', 'z'), \
self.region['gle'], self.region['gre'], self.domain['dx']):
                # Use re + 0.25*dx (not re + 0.5*dx) as the stop value so
                # that truncation error cannot add a spurious extra element
                coords[axis] = np.arange(le + 0.5*dx, re + 0.25*dx, dx)
dat = dict()
for k, v in arr.items():
if len(v.shape) > self.domain['ndim']:
for i in range(v.shape[-1]):
dat[k + str(i+1)] = (('z','y','x'), v[..., i])
else:
dat[k] = (('z','y','x'), v)
attrs = dict()
for k, v in self.domain.items():
attrs[k] = v
attrs['num'] = self.num
return xr.Dataset(dat, coords=coords, attrs=attrs)
else:
if len(field) == 1:
return arr[field[0]]
else:
# Return a dictionary of numpy arrays
return arr
def _get_array(self, field):
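        # Allocate one region-sized array per field, then fill it grid by
        # grid using each grid's integer offset within the region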
arr = dict()
for f in field:
arr[f] = self._set_array(f)
# Read from individual grids and copy to data
le = self.region['gle']
dx = self.domain['dx']
for i in self.region['gidx']:
g = self.grid[i]
il = (np.rint((g['le'] - le)/dx)).astype(int)
iu = il + g['Nx']
slc = tuple([slice(l, u) for l, u in zip(il[::-1], iu[::-1])])
for f in field:
arr[f][slc] = self._read_array(g, f)
return arr
def _read_array(self, grid, field):
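        # Return cached data if available; otherwise seek to the field's
        # byte offset in the vtk file and read the raw binary block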
if field in grid['data']:
return grid['data'][field]
elif field in self.field_list:
fm = grid['field_map'][field]
if 'tarinfo' in grid:
fp = self.tarfile.extractfile(grid['tarinfo'])
else:
fp = open(grid['filename'], 'rb')
fp.seek(fm['offset'])
fp.readline() # skip header
if fm['read_table']:
fp.readline()
            if Version(np.__version__) >= Version('2.0.0'):
                # numpy >= 2.0 removed ndarray.newbyteorder();
                # byte-swap through a dtype view instead
                arr = np.frombuffer(buffer=fp.read(fm['dsize']),
                                    dtype=fm['dtype'])
                grid['data'][field] = arr.view(arr.dtype.newbyteorder())
            else:
                grid['data'][field] = np.frombuffer(
                    buffer=fp.read(fm['dsize']),
                    dtype=fm['dtype']).newbyteorder()
fp.close()
if fm['nvar'] == 1:
shape = np.flipud(grid['Nx'])
else:
shape = (*np.flipud(grid['Nx']), fm['nvar'])
grid['data'][field].shape = shape
return grid['data'][field]
def _set_array(self, field):
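        # Allocate an empty array covering the region; face-centered fields
        # get one extra element along their own direction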
dtype = self._field_map[field]['dtype']
nvar = self._field_map[field]['nvar']
        # Copy so that repeated calls do not mutate the region's Nxr
        Nxr = np.copy(self.region['Nxr'])
        if 'face_centered_B' in field:
            Nxr[int(field[-1])-1] += 1
if nvar == 1:
shape = np.flipud(Nxr)
else:
shape = (*np.flipud(Nxr), nvar)
return np.empty(shape, dtype=dtype)
def _set_domain(self):
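        # Assemble domain-level metadata (edges, cell size, number of cells,
        # time) from the per-grid information collected by _set_grid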
domain = dict()
grid = self.grid
ngrid = len(grid)
# Grid left/right edges
gle = np.empty((ngrid, 3), dtype='float32')
gre = np.empty((ngrid, 3), dtype='float32')
dx = np.empty((ngrid, 3), dtype='float32')
Nx = np.ones_like(dx, dtype='int')
for i, g in enumerate(grid):
gle[i, :] = g['le']
gre[i, :] = g['re']
Nx[i, :] = g['Nx']
dx[i, :] = g['dx']
        # Check if all grids have equal size
if (Nx[0] == Nx).all():
domain['all_grid_equal'] = True
else:
domain['all_grid_equal'] = False
# Set domain
le = gle.min(axis=0)
re = gre.max(axis=0)
domain['ngrid'] = ngrid
domain['le'] = le
domain['re'] = re
domain['dx'] = dx[0, :]
domain['Lx'] = re - le
domain['center'] = 0.5*(le + re)
domain['Nx'] = np.round(domain['Lx']/domain['dx']).astype('int')
domain['ndim'] = 3 # always 3d
domain['time'] = grid[0]['time']
return domain
def _set_grid(self):
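        # Parse the header of each vtk file to record grid metadata
        # (left edge, cell size, number of cells) and the data offset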
grid = []
# Record filename and data_offset
for i, fname in enumerate(self.fnames):
file = open(fname, 'rb')
g = dict()
g['data'] = dict()
g['filename'] = fname
g['read_field'] = None
g['read_type'] = None
while g['read_field'] is None:
g['data_offset'] = file.tell()
line = file.readline()
_vtk_parse_line(line, g)
file.close()
g['Nx'] -= 1
g['Nx'][g['Nx'] == 0] = 1
g['dx'][g['Nx'] == 1] = 1.0
# Right edge
g['re'] = g['le'] + g['Nx']*g['dx']
grid.append(g)
return grid
def _parse_filename(filename):
"""Break up a filename into its component
to check the extension and extract the output number.
Parameters
----------
filename : string
Name of the file, including extension
Returns
-------
    tuple containing dirname, problem_id, output number, suffix, extension,
    mpi flag, and nonzero_id flag
Examples
--------
    >>> _parse_filename('/basedir/id0/problem_id.0000.vtk')
    ('/basedir', 'problem_id', '0000', None, 'vtk', True, False)
    >>> _parse_filename('/basedir/id10/problem_id-id10.0000.d1.vtk')
    ('/basedir', 'problem_id', '0000', 'd1', 'vtk', True, True)
    >>> _parse_filename('/basedir/problem_id.0000.vtk')
    ('/basedir', 'problem_id', '0000', None, 'vtk', False, False)
"""
sep = osp.sep
dirname = osp.dirname(filename)
nonzero_id = False
# Check if dirname ends with id0
dirname_last = dirname.split(sep)[-1]
if dirname_last.startswith('id') and \
dirname_last[2:].isdigit():
dirname = sep.join(dirname.split(sep)[:-1])
mpi_mode = True
else:
mpi_mode = False
base = os.path.basename(filename)
base_split = base.split('.')
if len(base_split) == 3:
problem_id = '.'.join(base_split[:-2])
num = base_split[-2]
suffix = None
ext = base_split[-1]
else:
        try:
            inum = -3
            int(base_split[inum])  # a suffix exists if this field is a number
            suffix = base_split[-2]
        except ValueError:
            inum = -2
            suffix = None
        # If dirname is idXX with XX > 0 (e.g., 2d vtk slices), the idXX
        # string needs to be removed from the problem_id
        if mpi_mode and int(dirname_last[2:]) != 0:
            problem_id = '.'.join(base_split[:inum])
            problem_id = problem_id.replace('-' + dirname_last, '')
            nonzero_id = True
else:
problem_id = '.'.join(base_split[:inum])
num = base_split[inum]
ext = base_split[-1]
return dirname, problem_id, num, suffix, ext, mpi_mode, nonzero_id
def _set_field_map(grid):
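    # Scan the vtk file once, recording each field's byte offset, dtype,
    # number of variables, and data size for later direct reads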
fp = open(grid['filename'], 'rb')
fp.seek(0, 2)
eof = fp.tell()
offset = grid['data_offset']
fp.seek(offset)
field_map = dict()
if 'Nx' in grid:
Nx = grid['Nx']
while offset < eof:
line = fp.readline()
sp = line.strip().split()
field = sp[1].decode('utf-8')
field_map[field] = dict()
field_map[field]['read_table'] = False
field_map[field]['title'] = line
if b"SCALARS" in line:
tmp = fp.readline()
field_map[field]['read_table'] = True
field_map[field]['nvar'] = 1
elif b"VECTORS" in line:
field_map[field]['nvar'] = 3
else:
            raise TypeError('{} is an unknown type.'.format(sp[0].decode('utf-8')))
field_map[field]['offset'] = offset
field_map[field]['ndata'] = field_map[field]['nvar']*grid['ncells']
if field == 'face_centered_B1':
field_map[field]['ndata'] = (Nx[0]+1)*Nx[1]*Nx[2]
elif field == 'face_centered_B2':
field_map[field]['ndata'] = Nx[0]*(Nx[1]+1)*Nx[2]
elif field == 'face_centered_B3':
field_map[field]['ndata'] = Nx[0]*Nx[1]*(Nx[2]+1)
        if sp[2] == b'int':
            dtype = 'i'
        elif sp[2] == b'float':
            dtype = 'f'
        elif sp[2] == b'double':
            dtype = 'd'
        else:
            raise TypeError('Unknown data type: {}'.format(sp[2].decode('utf-8')))
field_map[field]['dtype'] = dtype
field_map[field]['dsize'] = field_map[field]['ndata']*struct.calcsize(dtype)
fp.seek(field_map[field]['dsize'], 1)
offset = fp.tell()
tmp = fp.readline()
if len(tmp) > 1:
fp.seek(offset)
else:
offset = fp.tell()
return field_map
def _vtk_parse_line(line, grid):
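    # Parse a single vtk header line and store the extracted metadata
    # (time, dimensions, origin, spacing, ...) in the grid dict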
sp = line.strip().split()
if b"vtk" in sp:
grid['vtk_version'] = sp[-1]
elif b"time=" in sp:
time_index = sp.index(b"time=")
grid['time'] = float(sp[time_index + 1].rstrip(b','))
if b'level' in sp:
grid['level'] = int(sp[time_index + 3].rstrip(b','))
if b'domain' in sp:
grid['domain'] = int(sp[time_index + 5].rstrip(b','))
if sp[0] == b"PRIMITIVE":
grid['prim_var_type'] = True
elif b"DIMENSIONS" in sp:
grid['Nx'] = np.array(sp[-3:]).astype('int')
elif b"ORIGIN" in sp: # left_edge
grid['le'] = np.array(sp[-3:]).astype('float64')
elif b"SPACING" in sp:
grid['dx'] = np.array(sp[-3:]).astype('float64')
elif b"CELL_DATA" in sp:
grid['ncells'] = int(sp[-1])
elif b"SCALARS" in sp:
grid['read_field'] = sp[1]
grid['read_type'] = 'scalar'
elif b"VECTORS" in sp:
grid['read_field'] = sp[1]
grid['read_type'] = 'vector'
elif b"POINTS" in sp:
        grid['npoint'] = int(sp[1])