Source code for pyretis.core.pathensemble

# -*- coding: utf-8 -*-
# Copyright (c) 2023, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""Classes and functions for path ensembles.

The classes and functions defined in this module are useful for
representing path ensembles.

Important classes defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PathEnsemble (:py:class:`.PathEnsemble`)
    Class for defining path ensembles.

PathEnsembleExt (:py:class:`.PathEnsembleExt`)
    Class for defining path ensembles when we are working with
    paths stored on disk and not in memory only.
"""
import collections
import logging
import os
import shutil
from pyretis.core.path import Path
from pyretis.core.common import big_fat_comparer
from pyretis.core.random_gen import create_random_generator


# Module-level logger named after this module.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# Attach a do-nothing handler so that this module does not emit log
# output on its own; the application decides where log records end up.
logger.addHandler(logging.NullHandler())


# Names exported by ``from pyretis.core.pathensemble import *``.
__all__ = ['PathEnsemble', 'PathEnsembleExt', 'get_path_ensemble_class',
           'generate_ensemble_name']


def generate_ensemble_name(ensemble_number, zero_pad=3):
    """Create the zero-padded, simple name of an ensemble.

    The result looks like ``001``, ``0001``, etc. and is used both as
    the simple name of a path ensemble and as the name of its output
    directory.

    Parameters
    ----------
    ensemble_number : int
        The number identifying the ensemble.
    zero_pad : int, optional
        The minimum width of the name. Values smaller than 3 are
        replaced by 3 (a warning is logged).

    Returns
    -------
    out : string
        The ensemble number, left-padded with zeros.

    """
    width = zero_pad
    if width < 3:
        logger.warning('zero_pad must be >= 3. Setting it to 3.')
        width = 3
    # The nested replacement field builds the "0<width>d" format spec.
    return f'{ensemble_number:0{width}d}'
[docs]def _generate_file_names(path, target_dir, prefix=None): """Generate new file names for moving copying paths. Parameters ---------- path : object like :py:class:`.PathBase` This is the path object we are going to store. target_dir : string The location where we are moving the path to. prefix : string, optional The prefix can be used to prefix the name of the files. Returns ------- out[0] : list A list with new file names. out[1] : dict A dict which defines the unique "source -> destination" for copy/move operations. """ source = {} new_pos = [] for phasepoint in path.phasepoints: pos_file, idx = phasepoint.particles.get_pos() if pos_file not in source: localfile = os.path.basename(pos_file) if prefix is not None: localfile = f'{prefix}{localfile}' dest = os.path.join(target_dir, localfile) source[pos_file] = dest dest = source[pos_file] new_pos.append((dest, idx)) return new_pos, source
class PathEnsemble:
    """Representation of a path ensemble.

    This class represents a collection of paths in a path ensemble.
    In general, paths may be "long and complicated" so here, we really
    just store a simplified abstraction of the path, which is obtained
    by the `Path.get_path_data()` function for a given `Path` object.
    The returned dictionary is stored in the list `PathEnsemble.paths`.
    The only full path we store is the last accepted path. This is
    convenient for the RETIS method where paths may be swapped between
    path ensembles.

    Attributes
    ----------
    ensemble_number : integer
        This integer is used to represent the path ensemble, for RETIS
        simulations it's useful to identify the path ensemble. The path
        ensembles are numbered sequentially 0, 1, 2, etc. This
        corresponds to ``[0^-]``, ``[0^+]``, ``[1^+]``, etc.
    ensemble_name : string
        A string which can be used for printing the ensemble name.
        This is of form ``[0^-]``, ``[0^+]``, ``[1^+]``, etc.
    ensemble_name_simple : string
        A string with a simpler representation of the ensemble name,
        can be used for creating output files etc.
    start_condition : string
        ``'R'`` for ensemble number 0 and ``'L'`` for all the others.
    interfaces : tuple of floats
        Interfaces, specified with the values for the
        order parameters: `[left, middle, right]`.
    rgen : object like :py:class:`.RandomGenerator`
        The random generator associated with the path ensemble.
    paths : list
        This list contains the stored information for the paths. Here
        we only store the data returned by calling the `get_path_data()`
        function of the `Path` object.
    nstats : dict of ints
        This dict stores some statistics for the path ensemble. The
        keys are:

        * npath : The number of paths stored.
        * nshoot : The number of accepted paths generated by shooting.
        * ACC, BWI, ... : Number of paths with the given status.
    maxpath : int
        The maximum number of paths to store.
    last_path : object like :py:class:`.PathBase`
        This is the last **accepted** path.
    directory : OrderedDict
        The directories used by the path ensemble. The keys are
        ``'path_ensemble'``, ``'accepted'``, ``'generate'`` and
        ``'traj'``; the values are ``None`` until
        :py:meth:`update_directories` has been called.

    """

    def __init__(self, ensemble_number, interfaces, rgen=None,
                 maxpath=10000, exe_dir=None):
        """Initialise the PathEnsemble object.

        Parameters
        ----------
        ensemble_number : integer
            An integer used to identify the ensemble.
        interfaces : list of floats
            These are the interfaces specified with the values
            for the order parameters: ``[left, middle, right]``.
        rgen : object like :py:class:`.RandomGenerator`, optional
            The random generator that will be used for the
            paths that required random numbers. If not given, a new
            one is created with `create_random_generator()`.
        maxpath : integer, optional
            The maximum number of paths to store information for in
            memory. Note, that this will not influence the analysis as
            long as you are using the output files when running the
            analysis.
        exe_dir : string, optional
            The base folder where the simulation was executed from.
            This is used to set up output directories for the path
            ensemble.

        """
        if rgen is None:
            rgen = create_random_generator()
        self.rgen = rgen
        self.ensemble_number = ensemble_number
        self.interfaces = tuple(interfaces)  # Should not change interfaces.
        self.last_path = None
        self.nstats = {'npath': 0, 'nshoot': 0, 'ACC': 0}
        self.paths = []
        self.maxpath = maxpath
        if self.ensemble_number == 0:
            self.ensemble_name = '[0^-]'
            self.start_condition = 'R'
        else:
            # Ensemble number i > 0 is named [(i-1)^+].
            ensemble_number = self.ensemble_number - 1
            self.ensemble_name = f'[{ensemble_number}^+]'
            self.start_condition = 'L'
        self.ensemble_name_simple = generate_ensemble_name(
            self.ensemble_number
        )
        self.directory = collections.OrderedDict()
        self.directory['path_ensemble'] = None
        self.directory['accepted'] = None
        self.directory['generate'] = None
        self.directory['traj'] = None
        if exe_dir is not None:
            path_dir = os.path.join(exe_dir, self.ensemble_name_simple)
            self.update_directories(path_dir)

    def __eq__(self, other):
        """Check if two path ensembles are equal.

        The differences found are logged at the debug level.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        out : boolean
            True if `other` is of the same class, has the same set of
            attributes and all the compared attributes are equal.

        """
        if self.__class__ != other.__class__:
            logger.debug('%s and %s.__class__ differ', self, other)
            return False
        if set(self.__dict__) != set(other.__dict__):
            logger.debug('%s and %s.__dict__ differ', self, other)
            # We can not compare attributes that one of the objects is
            # missing, so stop here instead of failing on a getattr.
            return False

        equal = True
        for name in ('directory', 'interfaces', 'nstats', 'paths'):
            if not hasattr(self, name):
                continue
            mine = getattr(self, name)
            theirs = getattr(other, name)
            if len(mine) != len(theirs):
                # A plain zip would silently ignore the extra items.
                logger.debug('%s for %s and %s have different lengths: '
                             '%s and %s', name, self, other,
                             len(mine), len(theirs))
                equal = False
                continue
            if isinstance(mine, dict) and isinstance(theirs, dict):
                # Compare keys *and* values; iterating over a dict
                # would only give us the keys.
                if dict(mine) != dict(theirs):
                    logger.debug('%s for %s and %s attributes are %s and '
                                 '%s', name, self, other, mine, theirs)
                    equal = False
                continue
            for j, k in zip(mine, theirs):
                if j != k:
                    logger.debug('%s for %s and %s attributes are %s and '
                                 '%s', name, self, other, j, k)
                    equal = False

        for name in ('ensemble_name', 'ensemble_name_simple',
                     'ensemble_number', 'maxpath', 'start_condition'):
            if hasattr(self, name):
                if getattr(self, name) != getattr(other, name):
                    logger.debug('%s for %s and %s, attributes are %s and %s',
                                 name, self, other,
                                 getattr(self, name), getattr(other, name))
                    equal = False

        if hasattr(self, 'last_path'):
            if self.last_path != other.last_path:
                logger.debug('last paths differs')
                equal = False

        if hasattr(self, 'rgen'):
            if self.rgen.__class__ != other.rgen.__class__:
                logger.debug('self.rgen.__class__ differs')
                return False
            if self.rgen.__dict__['seed'] != other.rgen.__dict__['seed']:
                logger.debug('rgen seed differs')
                equal = False
            if not big_fat_comparer(self.rgen.__dict__['rgen'].get_state(),
                                    other.rgen.__dict__['rgen'].get_state()):
                logger.debug('rgen differs')
                equal = False
        return equal

    def __ne__(self, other):
        """Check if two path ensembles are not equal."""
        return not self == other

    def directories(self):
        """Yield the directories PyRETIS should make."""
        for key in self.directory:
            yield self.directory[key]

    def update_directories(self, path):
        """Update directory names.

        This method will not create new directories, but it will
        update the directory names.

        Parameters
        ----------
        path : string
            The base path to set. The ``'path_ensemble'`` directory is
            set to this path, the other directories are set to
            sub-directories (named by their key) of this path.

        """
        for key, val in self.directory.items():
            if key == 'path_ensemble':
                self.directory[key] = path
            else:
                self.directory[key] = os.path.join(path, key)
            if val is None:
                logger.debug('Setting directory "%s" to %s', key,
                             self.directory[key])
            else:
                logger.debug('Updating directory "%s": %s -> %s',
                             key, val, self.directory[key])

    def reset_data(self):
        """Erase the stored data in the path ensemble.

        It can be used in combination with flushing the data to a
        file in order to periodically write and empty the amount of data
        stored in memory.

        Notes
        -----
        We do not reset `self.last_path` as this might be used in the
        RETIS function.

        """
        self.paths = []
        for key in self.nstats:
            self.nstats[key] = 0

    def store_path(self, path):
        """Store a new accepted path in the path ensemble.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            The path we are going to store.

        Returns
        -------
        None, but we update `self.last_path`.

        """
        self.last_path = path

    def add_path_data(self, path, status, cycle=0):
        """Append data from the given path to `self.paths`.

        This will add the data from a given `path` to the list of path
        data for this ensemble. It will also update `self.last_path`
        (via `store_path`) if the given `path` is accepted.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            This is the object to store data from. It can be None, in
            which case a minimal dummy entry is stored.
        status : string
            This is the status of the path. Note that the path object
            also has a status property. However, this one might not be
            set, for instance when the path is just None. We therefore
            use `status` here as a parameter.
        cycle : int, optional
            The current cycle number.

        """
        if len(self.paths) >= self.maxpath:
            # This is just to limit the data we keep in memory in
            # case of really long simulations.
            logger.debug(('Path-data memory storage reset for ensemble %s.\n'
                          'This is just to limit the amount of data we store '
                          'in memory.\nThis will *NOT* influence the '
                          'simulation'), self.ensemble_name)
            self.paths = []
        if path is None:
            # Here we add a dummy path with minimal info. This is because
            # we could not generate a path for some reason which should be
            # specified by the status.
            path_data = {'status': status, 'generated': ('', 0, 0, 0),
                         'weight': 1.}
        else:
            path_data = path.get_path_data(status, self.interfaces)
            if 'EXP' in status:
                path_data['status'] = 'EXP'
            if path_data['status'] in {'ACC', 'EXP'}:
                # Store the path:
                self.store_path(path)
                if path_data['generated'][0] in {'sh', 'ss', 'wt', 'wf'}:
                    self.nstats['nshoot'] += 1
        path_data['cycle'] = cycle  # Also store cycle number.
        self.paths.append(path_data)  # Store the new data.
        # Update some statistics. The .get() is to also count the first
        # occurrence of a status:
        self.nstats[status] = self.nstats.get(status, 0) + 1
        self.nstats['npath'] += 1

    def get_accepted(self):
        """Yield accepted paths from the path ensemble.

        This function will return an iterator useful for iterating over
        accepted paths only. In the path ensemble we store both
        accepted and rejected paths. This function will loop over all
        paths stored and, for each of them, yield the most recent
        path with status ``'ACC'``. In this way, the accepted paths
        are yielded the correct number of times.

        Yields
        ------
        out : dict or None
            The data for the most recent accepted path. This is None
            for entries that come before the first accepted path in
            `self.paths`.

        """
        last_path = None
        for path in self.paths:
            if path['status'] == 'ACC':
                last_path = path
            yield last_path

    def get_acceptance_rate(self):
        """Return acceptance rate for the path ensemble.

        The acceptance rate is obtained as the fraction of accepted
        paths to the total number of paths in the path ensemble. This
        will only consider the paths that are currently stored in
        `self.paths`.

        Returns
        -------
        out : float
            The acceptance rate. This is 0.0 when no paths are stored
            (for instance right after `reset_data`).

        """
        if not self.paths:
            # Avoid dividing by zero when there is nothing stored.
            return 0.0
        accepted = sum(1 for path in self.paths if path['status'] == 'ACC')
        return float(accepted) / float(len(self.paths))

    def get_paths(self):
        """Yield the different paths stored in the path ensemble.

        It is included here in order to have a simple compatibility
        between the :py:class:`.PathEnsemble` object and the
        :py:class:`.PathEnsembleFile` object. This is useful for the
        analysis.

        Yields
        ------
        out : dict
            This is the dictionary representing the path data.

        """
        for path in self.paths:
            yield path

    def move_path_to_generate(self, _path, _prefix=None):
        """Move a path for temporary storing.

        Nothing is done here, the arguments are ignored.
        """
        return

    def copy_path_to_generate(self, path, _prefix=None):
        """Copy a path for temporary storing.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            The path to copy.

        Returns
        -------
        out : object like :py:class:`.PathBase`
            The copy of the path.

        """
        path_copy = path.copy()
        return path_copy

    def __str__(self):
        """Return a string with some info about the path ensemble."""
        msg = [f'Path ensemble: {self.ensemble_name}']
        msg += [f'\tInterfaces: {self.interfaces}']
        if self.nstats['npath'] > 0:
            npath = self.nstats['npath']
            nacc = self.nstats.get('ACC', 0)
            msg += [f'\tNumber of paths stored: {npath}']
            msg += [f'\tNumber of accepted paths: {nacc}']
            ratio = float(nacc) / float(npath)
            msg += [f'\tRatio accepted/total paths: {ratio}']
        return '\n'.join(msg)

    def restart_info(self):
        """Return a dictionary with restart information.

        Returns
        -------
        out : dict
            The statistics, interfaces and ensemble number. The state
            of the random generator and the restart information for the
            last accepted path are included when they are available.

        """
        restart = {
            'nstats': self.nstats,
            'interfaces': self.interfaces,
            'ensemble_number': self.ensemble_number,
        }
        if hasattr(self, 'rgen'):
            restart['rgen'] = self.rgen.get_state()
        if self.last_path:
            restart['last_path'] = self.last_path.restart_info()
        return restart

    def load_restart_info(self, info, cycle=0):
        """Load restart information.

        Parameters
        ----------
        info : dict
            A dictionary with the restart information.
        cycle : integer, optional
            The current simulation cycle.

        """
        self.nstats = info['nstats']
        for attr in ('interfaces', 'ensemble_number'):
            if info[attr] != getattr(self, attr):
                logger.warning(
                    'Inconsistent path ensemble restart info for %s', attr)
        for key in info:
            if key == 'rgen':
                self.rgen = create_random_generator(info[key])
            elif key == 'last_path':
                # Recreate the last accepted path and add its data to
                # the list of stored paths.
                rgen = create_random_generator(info[key]['rgen'])
                path = Path(rgen=rgen)
                path.load_restart_info(info['last_path'])
                path_data = path.get_path_data('ACC', self.interfaces)
                path_data['cycle'] = cycle
                self.last_path = path
                self.paths.append(path_data)
            elif hasattr(self, key):
                setattr(self, key, info[key])

    def clear_generate(self):
        """Remove all the files in an ensemble/generate/ directory.

        This is toggled on by adding 'remove_generate = True' to the
        retis.rst input file, under the [simulation] section.

        """
        gendir = self.directory['generate']
        logger.debug("Removing generate files from %s", gendir)
        if gendir is not None and os.path.exists(gendir):
            for file in os.listdir(gendir):
                path = os.path.join(gendir, file)
                logger.debug("Removing generate file %s", path)
                # Safety check: the path must end with generate/file.
                assert path.endswith(os.path.join('generate', file))
                os.remove(path)
class PathEnsembleExt(PathEnsemble):
    """Representation of a path ensemble working with external paths.

    This class is similar to :py:class:`.PathEnsemble` but it is made
    to work with external paths. That is, some extra file handling is
    done when accepting a path.

    """

    @staticmethod
    def _move_path(path, target_dir, prefix=None):
        """Move a path to a given target directory.

        The positions of the phase points in the path are updated to
        refer to the new file names, and the files are then moved.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            This is the path object we are going to move.
        target_dir : string
            The location where we are moving the path to.
        prefix : string, optional
            To give a prefix to the name of moved files.

        """
        logger.debug('Moving path to %s', target_dir)
        new_pos, source = _generate_file_names(path, target_dir,
                                               prefix=prefix)
        for pos, phasepoint in zip(new_pos, path.phasepoints):
            phasepoint.particles.set_pos(pos)
        for src, dest in source.items():
            if src == dest:
                logger.debug('Skipping move %s -> %s', src, dest)
                continue
            if os.path.exists(dest) and os.path.isfile(dest):
                logger.debug('Removing %s as it exists', dest)
                os.remove(dest)
            logger.debug('Moving %s -> %s', src, dest)
            os.rename(src, dest)

    @staticmethod
    def _copy_path(path, target_dir, prefix=None):
        """Copy a path to a given target directory.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            This is the path object we are going to copy.
        target_dir : string
            The location where we are copying the path to.
        prefix : string, optional
            To give a prefix to the name of copied files.

        Returns
        -------
        out : object like :py:class:`.PathBase`
            A copy of the input path, where the phase points refer to
            the copied files.

        """
        path_copy = path.copy()
        new_pos, source = _generate_file_names(path_copy, target_dir,
                                               prefix=prefix)
        # Update positions:
        for pos, phasepoint in zip(new_pos, path_copy.phasepoints):
            phasepoint.particles.set_pos(pos)
        for src, dest in source.items():
            if src == dest:
                continue
            if os.path.exists(dest) and os.path.isfile(dest):
                logger.debug('Removing %s as it exists', dest)
                os.remove(dest)
            logger.debug('Copy %s -> %s', src, dest)
            shutil.copy(src, dest)
        return path_copy

    def store_path(self, path):
        """Store a path by explicitly moving it.

        The files of the path are moved to the "accepted" directory
        and files in that directory which are not used by the path are
        removed. Failing to remove such a file is not considered to be
        an error.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            This is the path object we are going to store.

        """
        self._move_path(path, self.directory['accepted'])
        self.last_path = path
        # Collect the file names first, so that we do not remove files
        # while the directory is still being scanned.
        for entry in list(self.list_superfluous()):
            try:
                os.remove(entry)
            except OSError as err:  # pragma: no cover
                # Best effort only: leftover files are harmless.
                logger.debug('Could not remove %s: %s', entry, err)

    def list_superfluous(self):
        """List files in accepted directory that we do not need.

        Yields
        ------
        out : string
            The path to a file in the "accepted" directory which is
            not referred to by the last accepted path. Nothing is
            yielded if the "accepted" directory is not set.

        """
        accepted = self.directory['accepted']
        if accepted is None:
            # os.scandir(None) would scan the current working directory
            # and we must never flag files there as superfluous.
            return
        last = set()
        if self.last_path:
            for phasepoint in self.last_path.phasepoints:
                pos_file, _ = phasepoint.particles.get_pos()
                last.add(pos_file)
        # The context manager closes the directory iterator when done.
        with os.scandir(accepted) as entries:
            for entry in entries:
                if entry.is_file() and entry.path not in last:
                    yield entry.path

    def move_path_to_generate(self, path, prefix=None):
        """Move a path to the "generate" directory for temporary storing.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            The path to move.
        prefix : string, optional
            To give a prefix to the name of moved files.

        """
        self._move_path(path, self.directory['generate'], prefix=prefix)

    def copy_path_to_generate(self, path, pref=None):
        """Copy a path to the "generate" directory for temporary storing.

        Parameters
        ----------
        path : object like :py:class:`.PathBase`
            The path to copy.
        pref : string, optional
            To give a prefix to the name of copied files.

        Returns
        -------
        out : object like :py:class:`.PathBase`
            The copy of the path.

        """
        return self._copy_path(path, self.directory['generate'],
                               prefix=pref)

    def load_restart_info(self, info, cycle=0):
        """Load restart for external path.

        After loading the restart information, the file names for the
        last accepted path are updated so that they refer to files in
        the "accepted" directory of this path ensemble. A critical
        message is logged for files that do not exist.

        Parameters
        ----------
        info : dict
            A dictionary with the restart information.
        cycle : integer, optional
            The current simulation cycle.

        """
        super().load_restart_info(info, cycle=cycle)
        if self.last_path is None:
            # The restart information did not contain a last path, so
            # there are no file names to update.
            return
        # Update file names:
        directory = self.directory['accepted']
        for phasepoint in self.last_path.phasepoints:
            old_file_name, idx = phasepoint.particles.get_pos()
            filename = os.path.basename(old_file_name)
            new_file_name = os.path.join(directory, filename)
            if not os.path.isfile(new_file_name):
                logger.critical('The restart path "%s" does not exist',
                                new_file_name)
            phasepoint.particles.set_pos((new_file_name, idx))
def get_path_ensemble_class(ensemble_type):
    """Look up the path ensemble class for the given ensemble type.

    Parameters
    ----------
    ensemble_type : string
        The type of ensemble we are requesting, ``'internal'`` or
        ``'external'``.

    Returns
    -------
    out : class
        :py:class:`.PathEnsemble` for ``'internal'`` and
        :py:class:`.PathEnsembleExt` for ``'external'``.

    Raises
    ------
    ValueError
        If the requested ensemble type is not known.

    """
    registry = dict(internal=PathEnsemble, external=PathEnsembleExt)
    try:
        ensemble_class = registry[ensemble_type]
    except KeyError as missing:
        message = f'Unknown ensemble type "{ensemble_type}" requested.'
        logger.critical(message)
        raise ValueError(message) from missing
    return ensemble_class