Source code for gizio.core

"""Core interface."""
from collections import OrderedDict
from copy import copy, deepcopy
import os
from pathlib import Path

import h5py
import numpy as np
import unyt

from .spec import SPEC_REGISTRY


[docs]def load(prefix, suffix=".hdf5", spec="gizmo"): """Load snapshot. Parameters ---------- prefix : str or pathlib.Path Snapshot file(s) prefix. suffix : str, optional Snapshot file(s) suffix. (default: ".hdf5") spec : str or SpecBase, optional Snapshot format specification. If given as str, will use a built-in one. (default: "gizmo") Returns ------- Snapshot The loaded snapshot. """ # Determine snapshot paths prefix = Path(prefix).expanduser().resolve() if prefix.is_dir(): # Directory case parent = prefix glob_pattern = "*" + suffix else: parent = prefix.parent # Single file case glob_pattern = prefix.name if not prefix.is_file(): # Glob prefix case glob_pattern += "*" + suffix paths = sorted(parent.glob(glob_pattern)) if isinstance(spec, str): spec = SPEC_REGISTRY[spec]() return Snapshot(paths, spec)
[docs]class Snapshot: """Simulation snapshot. .. describe:: snap[key] .. describe:: snap[ptype, field] Load the field from file and cache in memory. .. describe:: del snap[key] .. describe:: del snap[ptype, field] Delete the cache. Parameters ---------- paths : typing.Iterable Snapshot file paths in correct order. spec : SpecBase Snapshot format specification. Attributes ---------- paths : list Snapshot file paths. prefix : str Common prefix of paths without trailing dot. spec : SpecBase Snapshot format specification. header : dict Snapshot header. shape : collections.OrderedDict Data shape composed of {ptype_name: n_part} entries, in the specification ptypes order. cosmology : astropy.cosmology.LambdaCDM An astropy cosmology calculator. unit_registry : unyt.unit_registry.UnitRegistry Simulation unit registry. pt : dict Dictionary of particle type selectors. """ def __init__(self, paths, spec): self.paths = [Path(path).resolve() for path in paths] self.prefix = os.path.commonprefix(self.paths).rstrip(".") # Apply spec to extract meta info header, shape, cosmology, unit_registry = spec.apply_to(self) self.spec = spec self.header = header self.shape = shape self.cosmology = cosmology self.unit_registry = unit_registry # Set up default particle selectors according to particle types self.pt = {} for ptype, abbr in self.spec.ptype_abbrs.items(): if self.shape[ptype] > 0: self.pt[abbr] = ParticleSelector.from_ptypes(self, [ptype]) self.spec.register_derived_fields(self.pt[abbr], abbr) self.pt["all"] = ParticleSelector.from_ptypes(self, self.spec.ptypes) self.spec.register_derived_fields(self.pt["all"], "all") # Initialize field cache self._field_cache = {} # dictionay interface
[docs] def keys(self): """A list of available keys.""" keys = [] # It suffices to check the first file with h5py.File(self.paths[0], "r") as f: for ptype in self.spec.ptypes: if ptype in f: for field in f[ptype].keys(): keys += [(ptype, field)] return keys
[docs] def cached_keys(self): """A list of cached keys.""" return list(self._field_cache.keys())
[docs] def clear_cache(self): """Clear field cache.""" self._field_cache = {}
def __getitem__(self, key): # Create cache if not existing if key not in self._field_cache: # Load from file value = [] for path in self.paths: with h5py.File(path, "r") as h5f: value += [h5f["/".join(key)][()]] value = np.concatenate(value) # Determine unit _, field = key if field in self.spec.field_units: # Use spec unit if defined unit = self.spec.field_units[field] else: # Assume dimentionless otherwise unit = "dimensionless" # Create cache self._field_cache[key] = self.array(value, unit) # Retrieve cache return self._field_cache[key] def __delitem__(self, key): # Delete cache del self._field_cache[key] # unyt helpers
[docs] def array(self, value, unit): """Helper method to create unyt array with snapshot unit registry. Parameters ---------- value : typing.Iterable The value. unit : str The unit. Returns ------- unyt.array.unyt_array A unyt array. """ return unyt.unyt_array(value, unit, registry=self.unit_registry)
[docs] def quantity(self, value, unit): """Helper method to create unyt quantity with snapshot unit registry. Parameters ---------- value : float The value. unit : str The unit. Returns ------- unyt.array.unyt_quantity A unyt quantity. """ return unyt.unyt_quantity(value, unit, registry=self.unit_registry)
[docs]class ParticleSelector: """High level snapshot field access for selected particles. .. describe:: len(ps) Return the number of selected particles. .. describe:: ps[key] Retrieve the field. Compute and create cache if not existing already. .. describe:: del ps[key] Delete the cache. Parameters ---------- snap : Snapshot The snapshot to access. Attributes ---------- snap : Snapshot The snapshot to access. """ @property def pmask(self): """collections.OrderedDict: Particle mask.""" return OrderedDict(zip(self.snap.spec.ptypes, self._masks)) @property def shape(self): """collections.OrderedDict: Shape.""" return OrderedDict( [ (ptype, mask.sum()) if mask is not None else (ptype, 0) for ptype, mask in self.pmask.items() ] )
[docs] @classmethod def from_ptypes(cls, snap, ptypes): """Create particle selector for specified particle types. Parameters ---------- snap : Snapshot The snapshot to access. ptypes : list A list of particle types to select. Returns ------- ParticleSelector The corresponding particle selector. """ spec_ptypes = deepcopy(snap.spec.ptypes) masks = [] for ptype in spec_ptypes: if ptype in ptypes: n_part = snap.shape[ptype] mask = np.ones(n_part, dtype=bool) else: mask = None masks.append(mask) return cls(snap, masks)
def __init__(self, snap, masks): # Initialize field system self.snap = snap self._masks = masks self.normalize_mask() self._field_registry = {} self._field_cache = {} # Register direct fields for key, field in self.direct_fields().items(): self.register_direct_field(key, field) def __copy__(self): ps = ParticleSelector(self.snap, deepcopy(self._masks)) ps._field_registry = deepcopy(self._field_registry) return ps def __len__(self): return sum(self.shape.values()) # field system
[docs] def keys(self): """All registered fields. Returns ------- dict_keys A view on keys. """ return self._field_registry.keys()
[docs] def direct_fields(self): """Known direct fields. Returns ------- dict A dictionary mapping shorhand keys to raw field names. """ ptype_fields = {} for ptype, field in self.snap.keys(): if self.pmask[ptype] is not None: if ptype not in ptype_fields: ptype_fields[ptype] = {field} else: ptype_fields[ptype].add(field) common_fields = set.intersection(*ptype_fields.values()) direct_fields = {} for field in common_fields: if field in self.snap.spec.field_abbrs: key = self.snap.spec.field_abbrs[field] else: key = field direct_fields[key] = field return direct_fields
[docs] def register_field(self, key, func): """Register a field. Parameters ---------- key : str The key to retrieve the field. func : typing.Callable The function to compute the field. """ self._field_registry[key] = func
[docs] def register_direct_field(self, key, field): """Register a direct field. Parameters ---------- key : str The key to retrieve the field. field : str The raw field name. """ def load_direct_field(ps): data = [] for ptype, mask in ps.pmask.items(): if mask is not None: data += [ps.snap[ptype, field][mask]] return unyt.array.uconcatenate(data) self.register_field(key, load_direct_field)
[docs] def unregister_field(self, key): """Unregister a field. Parameters ---------- key : str The key of the field. """ del self._field_registry[key] del self[key]
[docs] def clear_cache(self): """Clear all field caches.""" self._field_cache = {}
def __contains__(self, key): return key in self._field_registry def __getitem__(self, key): if isinstance(key, np.ndarray) and key.dtype == bool: # Boolean masking return self._where(key) if isinstance(key, str): # Field access if key not in self._field_cache: self._field_cache[key] = self._field_registry[key](self) return self._field_cache[key] raise KeyError def __delitem__(self, key): if key in self._field_cache: del self._field_cache[key] # mask operation
[docs] def normalize_mask(self): """Normalize mask arrays.""" # Consistently substitute all-false arrays by None self._masks = [ mask if mask is not None and mask.any() else None for mask in self._masks ]
def _where(self, cond): assert len(cond) == len(self) # Create new ParticleMask ps = copy(self) left = 0 for mask in ps._masks: if mask is not None: # Update one mask array n_part = mask.sum() mask[mask] = cond[left : left + n_part] left += n_part ps.normalize_mask() return ps def _update_mask(self, operator, other): # Initialize object self.clear_cache() keys_to_unregister = [key for key in self.keys() if key not in other] for key in keys_to_unregister: self.unregister_field(key) # Evaluate operation from itertools import starmap def apply_op(mask1, mask2): if mask1 is None and mask2 is None: return None if mask1 is None: mask1 = np.zeros_like(mask2) if mask2 is None: mask2 = np.zeros_like(mask1) return operator(mask1, mask2) self._masks = tuple(starmap(apply_op, zip(self._masks, other._masks))) self.normalize_mask() return self ## | union def __or__(self, other): return copy(self).__ior__(other) def __ior__(self, other): return self._update_mask(np.logical_or, other) ## & intersection def __and__(self, other): return copy(self).__iand__(other) def __iand__(self, other): return self._update_mask(np.logical_and, other) ## - difference def __sub__(self, other): return copy(self).__isub__(other) def __isub__(self, other): return self._update_mask(lambda a, b: a ^ (a & b), other) ## ^ symmetric difference def __xor__(self, other): return copy(self).__ixor__(other) def __ixor__(self, other): return self._update_mask(np.logical_xor, other)