# Source code for pyretis.core.common

# -*- coding: utf-8 -*-
# Copyright (c) 2023, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""Definition of some common methods that might be useful.

Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

big_fat_comparer (:py:func:`.big_fat_comparer`)
    Method to compare two nested list/dictionaries.

compare_objects (:py:func:`.compare_objects`)
    Method to compare two PyRETIS objects.

counter (:py:func:`.counter`)
    Function to count the number of iterations.

crossing_counter (:py:func:`.crossing_counter`)
    Function to count the crossing of a path on an interface.

crossing_finder (:py:func:`.crossing_finder`)
    Function to get the shooting points of the crossing of a path
    on an interface.

import_from (:py:func:`.import_from`)
    A method to dynamically import method/classes etc. from user
    specified modules.

inspect_function (:py:func:`.inspect_function`)
    A method to obtain information about arguments, keyword arguments
    for functions.

initiate_instance (:py:func:`.initiate_instance`)
    Method to initiate a class with optional arguments.

generic_factory (:py:func:`.generic_factory`)
    Create instances of classes based on settings.

null_move (:py:func:`.null_move`)
    Method that performs a null move, i.e. re-accepts the last path.

priority_checker (:py:func:`.priority_checker`)
    Method to check ensemble to prioritize.

relative_shoots_select (:py:func:`.relative_shoots_select`)
    Method to select the shooting ensemble.

segments_counter (:py:func:`.segments_counter`)
    Function that counts the number of segments between two interfaces.

trim_path_between_interfaces (:py:func:`.trim_path_between_interfaces`)
    Function to trim a path between interfaces.

select_and_trim_a_segment (:py:func:`.select_and_trim_a_segment`)
    Function to trim a path between interfaces plus the two external points.

compute_weight (:py:func:`.compute_weight`)
    A method to compute the statistical weight of a path generated by a
    Stone Skipping and Wire Fencing move.

soft_partial_exit (:py:func:`.soft_partial_exit`)
    Function that checks for the presence of the EXIT file,
    and kindly stops the iterator.

"""
import logging
import inspect
import importlib
import os
import sys
import numpy as np
from pyretis.inout import print_to_screen

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# A NullHandler ignores every record it receives; any actual output is
# left to whatever logging configuration the application sets up.
logger.addHandler(logging.NullHandler())


# Names exported by ``from pyretis.core.common import *``.
# NOTE(review): ``compare_objects`` and ``numpy_allclose`` are defined in
# this module but are not listed here - confirm that this is intended.
__all__ = ['import_from', 'inspect_function', 'initiate_instance',
           'generic_factory', 'crossing_counter', 'crossing_finder',
           'segments_counter', 'select_and_trim_a_segment', 'counter',
           'trim_path_between_interfaces', 'big_fat_comparer',
           'soft_partial_exit', 'null_move', 'relative_shoots_select',
           'compute_weight', 'priority_checker']


def counter():
    """Count the calls made to this function, starting from zero.

    The running count is stored on the function object itself, as the
    attribute ``counter.count``.

    Returns
    -------
    out : integer
        0 for the first call, and one more for every subsequent call.

    """
    try:
        counter.count += 1
    except AttributeError:
        # First call: the attribute has not been created yet.
        counter.count = 0
    return counter.count
def import_from(module_path, function_name):
    """Import a method/class from a module given by its file path.

    The file is loaded as a module named after the file (base name
    without extension), executed, registered in :py:data:`sys.modules`
    under that name, and the requested attribute is returned.

    Parameters
    ----------
    module_path : string
        The path/filename to load from.
    function_name : string
        The name of the method/class to load.

    Returns
    -------
    out : object
        The thing we managed to import.

    Raises
    ------
    ValueError
        If the module can not be loaded, or if `function_name` is not
        found in the loaded module. The original exception is attached
        as the cause.

    """
    # ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so import it explicitly.
    import importlib.util

    module_name = os.path.splitext(os.path.basename(module_path))[0]
    try:
        spec = importlib.util.spec_from_file_location(module_name,
                                                      module_path)
        # No spec/loader is returned for files that no importer
        # recognises (for instance an unknown file extension).
        if spec is None or spec.loader is None:
            raise ImportError(f'No loader available for "{module_path}"')
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except (ImportError, IOError, AttributeError) as err:
        # AttributeError is included so that a failure while executing
        # the module is still reported as a ValueError, as before.
        msg = f'Could not import module: {module_path}'
        logger.critical(msg)
        raise ValueError(msg) from err
    sys.modules[module_name] = module
    logger.debug('Imported module: %s', module)
    try:
        return getattr(module, function_name)
    except AttributeError as err:
        msg = f'Could not import "{function_name}" from "{module_path}"'
        logger.critical(msg)
        raise ValueError(msg) from err
def _arg_kind(arg):
    """Translate the kind of a parameter into a short label.

    Helper for :py:func:`.inspect_function`.

    Parameters
    ----------
    arg : object like :py:class:`inspect.Parameter`
        The parameter to classify.

    Returns
    -------
    out : string or None
        One of ``'args'``, ``'kwargs'``, ``'varargs'`` or
        ``'keywords'``; None if the kind is not recognised.

    """
    if arg.kind == arg.POSITIONAL_OR_KEYWORD:
        # Without a default the parameter has to be supplied, so it is
        # counted as positional; with a default it is a keyword one.
        return 'args' if arg.default is arg.empty else 'kwargs'
    if arg.kind == arg.POSITIONAL_ONLY:
        return 'args'
    if arg.kind == arg.VAR_POSITIONAL:
        return 'varargs'
    if arg.kind == arg.VAR_KEYWORD:
        return 'keywords'
    if arg.kind == arg.KEYWORD_ONLY:
        # Keyword-only parameters are treated as keyword arguments.
        return 'kwargs'
    return None
def big_fat_comparer(any1, any2, hard=False):
    """Check if two (possibly nested) objects are the same.

    Lists, tuples, dictionaries and numpy arrays are compared
    recursively/element-wise; anything else is compared with ``!=``.

    Parameters
    ----------
    any1 : anything
        The first object in the comparison.
    any2 : anything
        The second object in the comparison.
    hard : boolean, optional
        If True, raise a ValueError instead of returning False when
        `any1` and `any2` differ.

    Returns
    -------
    out : boolean
        True if any1 = any2, False otherwise.

    Raises
    ------
    ValueError
        If `hard` is True and the two objects differ.

    """
    if type(any1) is not type(any2):
        if hard:
            raise ValueError('Fail type', any1, any2)
        return False

    if isinstance(any1, (list, tuple)):
        if len(any1) != len(any2):
            if hard:
                raise ValueError('Fail list length', any1, any2)
            return False
        for key1, key2 in zip(any1, any2):
            if not big_fat_comparer(key1, key2, hard):
                if hard:
                    raise ValueError('Fail item in list',
                                     any1, any2)  # pragma: no cover
                return False

    elif isinstance(any1, np.ndarray):
        if any1.shape != any2.shape:
            if hard:
                raise ValueError('Fail np array shape', any1, any2)
            return False
        # np.array_equal is used rather than iterating with np.nditer:
        # nditer refuses zero-sized operands (and object arrays) unless
        # special flags are given, which made the comparison of two
        # empty arrays raise a ValueError even with hard=False.
        if not np.array_equal(any1, any2):
            if hard:
                raise ValueError('Fail np array item', any1, any2)
            return False

    elif isinstance(any1, dict):
        for key in any1:
            if key not in any2:
                if hard:
                    raise ValueError('Fail dict', any1, any2)
                return False
            # NOTE: this check is one-directional - the value in any1
            # may be an instance of a subclass of the type in any2.
            if not isinstance(any1[key], type(any2[key])):
                if hard:
                    raise ValueError('Fail types', any1[key], any2[key])
                return False
            if isinstance(any1[key], (dict, list, tuple, np.ndarray)):
                if not big_fat_comparer(any1[key], any2[key], hard):
                    if hard:
                        raise ValueError('Fail item', any1[key],
                                         any2[key])  # pragma: no cover
                    return False
            else:
                if any1[key] != any2[key]:
                    if hard:
                        raise ValueError('Fail item', any1[key], any2[key])
                    return False
        # Keys that are present in any2 only.
        for key in any2:
            if key not in any1:
                if hard:
                    raise ValueError('Fail item', any1, any2)
                return False

    else:
        if any1 != any2:
            if hard:
                raise ValueError('Fail item', any1, any2)
            return False

    return True
def inspect_function(function):
    """Sort the parameters of a callable by their kind.

    Intended for checking that a certain function can be called with
    the arguments we have available. Only the names of the parameters
    are collected; no further information is gathered about
    ``*args``/``**kwargs`` style parameters.

    Parameters
    ----------
    function : callable
        The function to inspect.

    Returns
    -------
    out : dict
        A dict of lists with parameter names, with the keys:

        * `args` : the positional arguments
        * `kwargs` : the keyword arguments
        * `varargs` : the ``*args``-like parameter
        * `keywords` : the ``**kwargs``-like parameter

    """
    out = {label: [] for label in ('args', 'kwargs', 'varargs', 'keywords')}
    signature = inspect.signature(function)  # pylint: disable=no-member
    for param in signature.parameters.values():
        label = _arg_kind(param)
        if label is None:  # pragma: no cover
            logger.critical('Unknown variable kind "%s" for "%s"',
                            param.kind, param.name)
            continue
        out[label].append(param.name)
    return out
def _pick_out_arg_kwargs(klass, settings):
    """Collect the arguments ``klass.__init__`` needs from settings.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out[0] : list
        The positional arguments, in the order of the signature.
    out[1] : dict
        The keyword arguments that were found in `settings`.

    Raises
    ------
    ValueError
        If a positional argument is missing from `settings`.

    """
    signature = inspect_function(klass.__init__)
    positional = []
    for name in signature['args']:
        if name == 'self':
            continue
        try:
            value = settings[name]
        except KeyError:
            msg = f'Required argument "{name}" for "{klass}" not found!'
            raise ValueError(msg)
        positional.append(value)
    # Keyword arguments are optional: pass on only those we were given.
    keyword = {
        name: settings[name]
        for name in signature['kwargs']
        if name != 'self' and name in settings
    }
    return positional, keyword
def initiate_instance(klass, settings):
    """Create an instance of a class using arguments from settings.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out : instance of `klass`
        The newly created instance.

    """
    args, kwargs = _pick_out_arg_kwargs(klass, settings)
    name = klass.__name__
    mod = klass.__module__
    # Only the wording of the debug message depends on what was found.
    if args and kwargs:
        how = 'with positional and keyword arguments.'
    elif args:
        how = 'with positional arguments.'
    elif kwargs:
        how = 'with keyword arguments.'
    else:
        how = 'without arguments.'
    logger.debug('Initiated "%s" from "%s" %s', name, mod, how)
    # Unpacking an empty list/dict is the same as passing nothing.
    return klass(*args, **kwargs)
def generic_factory(settings, object_map, name='generic'):
    """Create an instance of the class requested in the settings.

    A semi-generic factory: ``settings['class']`` names the class to
    create (case-insensitive) and `object_map` translates that name
    into the actual class, stored under the key ``'cls'``.

    Parameters
    ----------
    settings : dict
        The settings for the object to create. The key ``'class'``
        selects the class; the full dict is handed on to
        :py:func:`.initiate_instance`.
    object_map : dict
        Maps lower-case class names to dicts with a ``'cls'`` entry.
    name : string, optional
        Short name for the object type. Only used for error messages.

    Returns
    -------
    out : instance of a class
        The created object, or None if no class was given or the
        requested class is unknown.

    """
    try:
        requested = settings['class'].lower()
    except KeyError:
        logger.critical('No class given for %s -- could not create object!',
                        name)
        return None
    if requested in object_map:
        return initiate_instance(object_map[requested]['cls'], settings)
    logger.critical('Could not create unknown class "%s" for %s',
                    settings['class'], name)
    return None
def numpy_allclose(val1, val2):
    """Compare two values with numpy's allclose, tolerating None.

    Two None values are considered equal, and None is never equal to
    anything else. If numpy raises a TypeError for the given values,
    they are reported as different - even if ``val1 == val2``.

    Parameters
    ----------
    val1 : np.array
        The first variable in the comparison.
    val2 : np.array
        The second variable in the comparison.

    Returns
    -------
    out : boolean
        True if the values are equal, False otherwise.

    """
    if val1 is None or val2 is None:
        return val1 is None and val2 is None
    try:
        return np.allclose(val1, val2)
    except TypeError:
        return False


def compare_objects(obj1, obj2, attrs, numpy_attrs=None):
    """Compare two PyRETIS objects attribute by attribute.

    The objects are equal when they are of the same class, have the
    same number of entries in ``__dict__`` and all the requested
    attributes are equal. Attributes listed in `numpy_attrs` are
    compared with :py:func:`.numpy_allclose`, the others with ``==``.

    Parameters
    ----------
    obj1 : object
        The first object for the comparison.
    obj2 : object
        The second object for the comparison.
    attrs : iterable of strings
        The attributes to check.
    numpy_attrs : iterable of strings, optional
        The subset of attributes which are numpy arrays.

    Returns
    -------
    out : boolean
        True if the objects are equal, False otherwise.

    """
    same_class = obj1.__class__ == obj2.__class__
    if not same_class:
        logger.debug(
            'The classes are different %s != %s',
            obj1.__class__, obj2.__class__
        )
        return False
    if len(obj1.__dict__) != len(obj2.__dict__):
        logger.debug('Number of attributes differ.')
        return False
    for key in attrs:
        try:
            first = getattr(obj1, key)
            second = getattr(obj2, key)
        except AttributeError:
            logger.debug('Failed to compare attribute "%s"', key)
            return False
        if numpy_attrs and key in numpy_attrs:
            equal = numpy_allclose(first, second)
        else:
            equal = first == second
        if not equal:
            logger.debug('Attribute "%s" differ.', key)
            return False
    return True
def null_move(ensemble, cycle):
    """Perform a null move for a path ensemble.

    The null move re-accepts the last accepted path: the path is
    stored once more in the path ensemble with status 'ACC'.

    Parameters
    ----------
    ensemble : dict
        It contains:

        * `path_ensemble` : object like :py:class:`.PathEnsemble`
          This is the path ensemble to update with the null move.
    cycle : integer
        The current cycle number.

    Returns
    -------
    out[0] : boolean
        Should the path be accepted or not? Always True here.
    out[1] : object like :py:class:`.PathBase`
        The last path of the ensemble.
    out[2] : string
        The status, which is always 'ACC' for this move.

    """
    path_ensemble = ensemble['path_ensemble']
    logger.info('Null move for: %s', path_ensemble.ensemble_name)
    path = path_ensemble.last_path
    status = 'ACC'
    # Paths whose move is 'ld' keep that label; every other path is
    # re-labelled with '00'.
    keep_label = path.get_move() == 'ld'
    if not keep_label:
        path.set_move('00')
    path_ensemble.add_path_data(path, status, cycle=cycle)
    return True, path, status
def relative_shoots_select(ensembles, rgen, relative):
    """Randomly select the ensemble for 'relative' shooting moves.

    A uniform random number is drawn and compared with the running sum
    of the relative probabilities; the first ensemble whose running
    sum exceeds the number is selected.

    Parameters
    ----------
    ensembles : list of objects like :py:class:`.PathEnsemble`
        This is a list of the ensembles we are using in the RETIS method.
    rgen : object like :py:class:`.RandomGenerator`
        This is a random generator. Here we assume that we can call
        `rgen.rand()` to draw random uniform numbers.
    relative : list of floats
        These are the relative probabilities for the ensembles. We
        assume here that these numbers are normalised.

    Returns
    -------
    out[0] : integer
        The index of the path ensemble to shoot in.
    out[1] : object like :py:class:`.PathEnsemble`
        The selected path ensemble for shooting.

    Raises
    ------
    ValueError
        If no ensemble could be selected from `relative`.

    """
    threshold = rgen.rand()
    running = 0.0
    for idx, weight in enumerate(relative):
        running += weight
        if threshold < running:
            break
    else:
        # The running sum never exceeded the random number.
        idx = None
    try:
        return idx, ensembles[idx]
    except TypeError:
        raise ValueError('Error in relative shoot frequencies! Aborting!')
def segments_counter(path, interface_l, interface_r, reverse=False):
    """Count the directional segments between two interfaces.

    A segment is counted every time the order parameter, after having
    crossed the first interface, crosses the second one. For the
    default direction this means FROM `interface_l` TO `interface_r`;
    with `reverse` the path is followed from `interface_r` down to
    `interface_l`.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which segments will be counted.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.
    reverse : boolean, optional
        Check on a reversed path.

    Returns
    -------
    n_segments : integer
        Segment counter.

    """
    started = False
    n_segments = 0
    for i in range(path.length - 1):
        op1 = path.phasepoints[i].order[0]
        op2 = path.phasepoints[i+1].order[0]
        if reverse:
            starts = op1 >= interface_r > op2
            ends = op1 >= interface_l > op2
        else:
            starts = op2 > interface_l >= op1
            ends = op2 > interface_r >= op1
        if starts:
            started = True
        # A single step may both start and end a segment.
        if ends and started:
            started = False
            n_segments += 1
    return n_segments
def crossing_counter(path, interface):
    """Count how many times a path crosses an interface.

    Crossings in both directions are counted. A step counts as a
    crossing when one point is below the interface and the other one
    is at or above it.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        The path to inspect.
    interface : float
        The position of the interface.

    Returns
    -------
    cnt : integer
        Number of crossing of the given interface.

    """
    points = path.phasepoints
    orders = [(first.order[0], second.order[0])
              for first, second in zip(points[:-1], points[1:])]
    return sum(
        1 for op1, op2 in orders
        if op2 >= interface > op1 or op1 >= interface > op2
    )
def crossing_finder(path, interface, last_frame=False):
    """Find a crossing of a path over an interface.

    All crossings (in both directions) are located, and one of them is
    returned: the last one if `last_frame` is True, otherwise one
    selected with the random generator of the path.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        The path to search for crossings. Its random generator,
        `path.rgen`, is used when `last_frame` is False.
    interface : float
        Interface position.
    last_frame : boolean, optional
        Determines if the last crossing will be selected or not.

    Returns
    -------
    ph1, ph2 : snapshots
        Snapshots to define the picked crossing, one right before and
        one right after the interface. ``(None, None)`` is returned if
        the path never crosses the interface.

    """
    points = path.phasepoints
    crossings = []
    for before, after in zip(points[:-1], points[1:]):
        op1 = before.order[0]
        op2 = after.order[0]
        if op2 >= interface > op1 or op1 >= interface > op2:
            crossings.append((before, after))
    if not crossings:
        return None, None
    if last_frame:
        idx = -1
    else:
        # The upper bound handed to the generator is the last valid index.
        idx = path.rgen.random_integers(0, len(crossings) - 1)
    return crossings[idx]
def trim_path_between_interfaces(path, interface_l, interface_r):
    """Cut a path between the two interfaces.

    Only the points with an order parameter strictly inside the range
    (interface_l, interface_r) are kept.

    Be careful: the result may consist of several discontinuous
    segments. Also consider whether this check should be left
    inclusive (as the ensemble should be left inclusive).

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.

    Returns
    -------
    new_path : object like :py:class:`.PathBase`
        This is the output trimmed path.

    """
    new_path = path.empty_path()
    inside = (point for point in path.phasepoints
              if interface_r > point.order[0] > interface_l)
    for point in inside:
        new_path.append(point)
    # The trimmed path inherits the bookkeeping of its parent.
    for attr in ('maxlen', 'status', 'time_origin'):
        setattr(new_path, attr, getattr(path, attr))
    new_path.generated = 'ct'
    new_path.rgen = path.rgen
    return new_path
def select_and_trim_a_segment(path, interface_l, interface_r,
                              segment_to_pick=None):
    """Cut a directional segment from interface_l to interface_r.

    The segment consists of the snapshots between an upward crossing
    of `interface_l` and the following upward crossing of
    `interface_r`, including the snapshots just before/after the
    interfaces.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.
    segment_to_pick : integer, optional
        The index (starting from 0) of the segment to select. If None,
        an index is drawn with the random generator of the path.

    Returns
    -------
    segment : object like :py:class:`.PathBase`
        The selected segment, with the snapshots right before/after
        the interfaces included.

    """
    segment = path.empty_path()
    if segment_to_pick is None:
        segment_number = segments_counter(path, interface_l, interface_r)
        # NOTE(review): the upper bound used here is the number of
        # segments, while crossing_finder hands ``len(...) - 1`` to the
        # same generator method - confirm which bound is intended.
        segment_to_pick = path.rgen.random_integers(0, segment_number)

    points = path.phasepoints
    inside = False  # True between an interface_l and interface_r crossing
    current = -1  # Index of the segment we are currently in
    for i, phasepoint in enumerate(points[:-1]):
        op1 = phasepoint.order[0]
        op2 = points[i+1].order[0]
        # NB: these are directional crossings.
        if op2 >= interface_l > op1 and not inside:
            current += 1
            inside = True
        picked = inside and current == segment_to_pick
        if picked:
            segment.append(phasepoint)
            isave = i
        if op2 >= interface_r > op1:
            if picked:
                # Include the first snapshot beyond the right interface.
                segment.append(points[i+1])
            inside = False

    if segment.length == 1:
        # A single snapshot: add its successor as well.
        segment.append(points[isave + 1])

    for attr in ('maxlen', 'status', 'time_origin'):
        setattr(segment, attr, getattr(path, attr))
    segment.generated = 'sg'
    segment.rgen = path.rgen
    return segment
def wirefence_weight_and_pick(path, intf_l, intf_r, return_seg=False): """Calculate the weight of a path generated by the Wire Fence move. The WF path weight is determined by the total sum of valid sub-path phasepoints, where valid WF subpaths are defined as intf_l-intf_l, intf_l-intf_r and intf_r-intf_l sub-paths. if return_seg = True, a random valid WF sub-path is also returned. Parameters ---------- path : object like :py:class:`.PathBase` This is the input path which will be trimmed. intf_r : float This is the position of the RIGHT interface. intf_l : float This is the position of the LEFT interface. return_seg : boolean, optional Determines if a random valid WF sub-path is returned or not. Returns ------- n_frames: int The weight of the path. segment : object like :py:class:`.PathBase` False (if return_seg=False) else a random valid WF sub-path. """ key_l, key_r = False, False path_arr = [] segment = False for i in range(len(path.phasepoints[:-1])): op1 = path.phasepoints[i].order[0] op2 = path.phasepoints[i+1].order[0] if (op1 < intf_l and op2 >= intf_r) or \ (op2 < intf_l and op1 >= intf_r): pass elif op2 >= intf_l > op1 and not key_l: isave, key_l = i, True elif op2 < intf_r <= op1 and not key_r: isave, key_r = i, True elif key_r and op2 >= intf_r > op1: key_l, key_r = False, False elif True in (key_l, key_r) and (op2 < intf_l <= op1 or op2 >= intf_r > op1): key_l, key_r = False, False path_arr.append((isave, i+1, i-isave)) n_frames = sum([i[2] for i in path_arr]) if path_arr else 0 if return_seg and n_frames: sum_frames = 0 subpath_select = path.rgen.rand() for i in path_arr: sum_frames += i[2] if sum_frames/n_frames >= subpath_select: segment = path.empty_path() for j in range(i[0], i[1]+1): segment.append(path.phasepoints[j]) segment.maxlen = path.maxlen segment.status = path.status segment.time_origin = path.time_origin segment.generated = 'ct' segment.rgen = path.rgen break return n_frames, segment
def compute_weight(path, interfaces, move):
    """Compute the High Acceptance path weight after a MC move.

    The weights are used in the computation of P cross. They make it
    possible to use the High Acceptance version of Stone Skipping (SS)
    and Wire Fencing (WF), which accepts B to A paths. The drawback is
    that swapping moves have to account for these weights as well.

    A path that was not generated by SS or WF gets the weight 1.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be checked.
    interfaces : list/tuple of floats
        These are the interface positions of the form
        ``[left, middle, right]``.
    move : string
        The MC move to compute the weights for.

    Returns
    -------
    out[0] : float
        The weight of the path.

    """
    left, middle, right = interfaces[0], interfaces[1], interfaces[2]
    if move == 'ss':
        weight = 1.*crossing_counter(path, middle)
    elif move == 'wf':
        wf_weight, _ = wirefence_weight_and_pick(path, middle, right)
        weight = 1.*wf_weight
    else:
        weight = 1.
    start = path.get_start_point(left, right)
    end = path.get_end_point(left, right)
    # SS/WF paths with different start and end points count double.
    if start != end and move in ('ss', 'wf'):
        weight *= 2
    return weight
def priority_checker(ensembles, settings):
    """Determine which ensembles to skip during a RETIS simulation.

    With priority shooting enabled (``priority_shooting`` in the
    ``simulation`` section of the settings), the ensembles that have
    the largest number of paths are flagged to be skipped, unless all
    ensembles have the same number of paths. Without priority
    shooting no ensemble is skipped.

    Parameters
    ----------
    ensembles : list of dictionaries of objects
        List of dict of ensembles we are using in a path method. The
        number of paths is read from
        ``ens['path_ensemble'].nstats['npath']``.
    settings : dict
        This dict contains the settings for the RETIS method.

    Returns
    -------
    out[0] : list
        A list of booleans, one per ensemble, dictating whether the
        ensemble is to be skipped (True) or not (False).

    """
    priority = settings.get('simulation', {}).get('priority_shooting', False)
    skip = [False] * len(ensembles)
    if not priority:
        return skip
    npaths = [ens['path_ensemble'].nstats['npath'] for ens in ensembles]
    # Only skip something when the ensembles are out of balance.
    if any(npath != npaths[0] for npath in npaths):
        most = max(npaths)
        skip = [npath == most for npath in npaths]
    return skip
def soft_partial_exit(exe_path=''):
    """Check if a file named EXIT is present.

    When the file is found, this is reported both to the logger and
    on the screen.

    Parameters
    ----------
    exe_path : string, optional
        The directory to look for the EXIT file in. If empty, the
        plain relative name ``EXIT`` is used.

    Returns
    -------
    out : boolean
        True if EXIT is present, False if it is not.

    """
    exit_file = os.path.join(exe_path, 'EXIT') if exe_path else 'EXIT'
    if not os.path.isfile(exit_file):
        return False
    logger.info('Exit file found - will exit between steps.')
    print_to_screen('Exit file found - will exit between steps.',
                    level='warning')
    return True