# -*- coding: utf-8 -*-
# Copyright (c) 2023, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""Definition of some common methods that might be useful.
Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
big_fat_comparer (:py:func:`.big_fat_comparer`)
Method to compare two nested lists/dictionaries.
compare_objects (:py:func:`.compare_objects`)
Method to compare two PyRETIS objects.
counter (:py:func:`.counter`)
Function to count the number of times it has been called.
crossing_counter (:py:func:`.crossing_counter`)
Function to count the crossings of a path over an interface.
crossing_finder (:py:func:`.crossing_finder`)
Function to pick the phase points around a crossing of a path
over an interface.
import_from (:py:func:`.import_from`)
A method to dynamically import method/classes etc. from user
specified modules.
inspect_function (:py:func:`.inspect_function`)
A method to obtain information about arguments, keyword arguments
for functions.
initiate_instance (:py:func:`.initiate_instance`)
Method to initiate a class with optional arguments.
generic_factory (:py:func:`.generic_factory`)
Create instances of classes based on settings.
null_move (:py:func:`.null_move`)
Method to perform a null move (re-accept the last accepted path).
priority_checker (:py:func:`.priority_checker`)
Method to check which ensembles to prioritize.
relative_shoots_select (:py:func:`.relative_shoots_select`)
Method to select the shooting ensemble.
segments_counter (:py:func:`.segments_counter`)
Function that counts the number of segments between two interfaces.
trim_path_between_interfaces (:py:func:`.trim_path_between_interfaces`)
Function to trim a path between interfaces.
select_and_trim_a_segment (:py:func:`.select_and_trim_a_segment`)
Function to trim a path between interfaces plus the two external points.
compute_weight (:py:func:`.compute_weight`)
A method to compute the statistical weight of a path generated by a
Stone Skipping and Wire Fencing move.
soft_partial_exit (:py:func:`.soft_partial_exit`)
Function that checks for the presence of the EXIT file,
and signals that the simulation should stop between steps.
"""
import logging
import inspect
import importlib
import os
import sys
import numpy as np
from pyretis.inout import print_to_screen
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.addHandler(logging.NullHandler())

# Names exported by ``from <this module> import *``.
__all__ = [
    'import_from',
    'inspect_function',
    'initiate_instance',
    'generic_factory',
    'crossing_counter',
    'crossing_finder',
    'segments_counter',
    'select_and_trim_a_segment',
    'counter',
    'trim_path_between_interfaces',
    'big_fat_comparer',
    'soft_partial_exit',
    'null_move',
    'relative_shoots_select',
    'compute_weight',
    'priority_checker',
]
def counter():
    """Return a call counter that starts at 0 and grows by one per call.

    The running value is stored on the function object itself, in the
    ``counter.count`` attribute.

    Returns
    -------
    out : integer
        0 on the first call, 1 on the second, and so on.
    """
    try:
        counter.count += 1
    except AttributeError:
        # First call: the attribute has not been created yet.
        counter.count = 0
    return counter.count
def import_from(module_path, function_name):
    """Import a method/class from a module.

    This method will dynamically import a specified method/object
    from a module and return it. If the module can not be imported or
    if we can't find the method/class in the module, a ValueError is
    raised.

    Parameters
    ----------
    module_path : string
        The path/filename to load from.
    function_name : string
        The name of the method/class to load.

    Returns
    -------
    out : object
        The thing we managed to import.

    Raises
    ------
    ValueError
        If the module can not be loaded, or if it does not define
        ``function_name``.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly here.
    import importlib.util  # pylint: disable=import-outside-toplevel
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    try:
        spec = importlib.util.spec_from_file_location(module_name,
                                                      module_path)
        if spec is None or spec.loader is None:
            # E.g. a file extension importlib does not recognise.
            raise ImportError(f'No loader available for {module_path}')
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except (ImportError, OSError, AttributeError) as err:
        # AttributeError here comes from the module's own code, not from
        # a missing ``function_name``, so report it as a module failure.
        msg = f'Could not import module: {module_path}'
        logger.critical(msg)
        raise ValueError(msg) from err
    sys.modules[module_name] = module
    logger.debug('Imported module: %s', module)
    try:
        return getattr(module, function_name)
    except AttributeError as err:
        msg = f'Could not import "{function_name}" from "{module_path}"'
        logger.critical(msg)
        raise ValueError(msg) from err
def _arg_kind(arg):
    """Classify a function parameter for :py:func:`.inspect_function`.

    Parameters
    ----------
    arg : object like :py:class:`inspect.Parameter`
        The parameter to classify.

    Returns
    -------
    out : string or None
        One of ``'args'``, ``'kwargs'``, ``'varargs'`` or
        ``'keywords'``; None if the parameter kind is not recognised.
    """
    if arg.kind == arg.POSITIONAL_OR_KEYWORD:
        # A default value turns a positional-or-keyword parameter into
        # what we call a keyword argument.
        return 'args' if arg.default is arg.empty else 'kwargs'
    # Checked in this order; keyword-only parameters count as kwargs.
    kind_labels = (
        (arg.POSITIONAL_ONLY, 'args'),
        (arg.VAR_POSITIONAL, 'varargs'),
        (arg.VAR_KEYWORD, 'keywords'),
        (arg.KEYWORD_ONLY, 'kwargs'),
    )
    for candidate, label in kind_labels:
        if arg.kind == candidate:
            return label
    return None
def big_fat_comparer(any1, any2, hard=False):
    """Check if two objects are the same, regardless of their nesting.

    Lists, tuples, dicts and numpy arrays are compared recursively;
    any other object is compared with ``!=``. At the top level the two
    objects must have exactly the same type.

    Parameters
    ----------
    any1 : anything
        The first object to compare.
    any2 : anything
        The second object to compare.
    hard : boolean, optional
        If True, raise a ValueError describing the first difference
        found instead of returning False.

    Returns
    -------
    out : boolean
        True if any1 = any2, False otherwise.

    Raises
    ------
    ValueError
        If ``hard`` is True and the two objects differ.
    """
    def _fail(*info):
        """Raise (hard mode) or return False for a detected difference."""
        if hard:
            raise ValueError(*info)
        return False

    if type(any1) is not type(any2):
        return _fail('Fail type', any1, any2)
    if isinstance(any1, (list, tuple)):
        if len(any1) != len(any2):
            return _fail('Fail list length', any1, any2)
        for item1, item2 in zip(any1, any2):
            if not big_fat_comparer(item1, item2, hard):
                # In hard mode the recursive call has already raised.
                return _fail('Fail item in list',
                             any1, any2)  # pragma: no cover
    elif isinstance(any1, np.ndarray):
        if any1.shape != any2.shape:
            return _fail('Fail np array shape', any1, any2)
        # np.array_equal compares elements by index. Pairing np.nditer
        # iterators instead walks memory order (so C- and Fortran-ordered
        # copies mismatch) and fails on empty or object arrays.
        if not np.array_equal(any1, any2):
            return _fail('Fail np array item', any1, any2)
    elif isinstance(any1, dict):
        for key in any1:
            if key not in any2:
                return _fail('Fail dict', any1, any2)
            val1, val2 = any1[key], any2[key]
            if not isinstance(val1, type(val2)):
                return _fail('Fail types', val1, val2)
            if isinstance(val1, (dict, list, tuple, np.ndarray)):
                if not big_fat_comparer(val1, val2, hard):
                    # In hard mode the recursive call has already raised.
                    return _fail('Fail item',
                                 val1, val2)  # pragma: no cover
            elif val1 != val2:
                return _fail('Fail item', val1, val2)
        # Keys present only in any2:
        for key in any2:
            if key not in any1:
                return _fail('Fail item', any1, any2)
    elif any1 != any2:
        return _fail('Fail item', any1, any2)
    return True
def inspect_function(function):
    """Return arguments/kwargs of a given function.

    This is intended for checking that a function can be called: it
    reports which arguments and keyword arguments the function expects.
    It may be fragile; no further information is gathered about
    ``*args`` and ``**kwargs``.

    Parameters
    ----------
    function : callable
        The function to inspect.

    Returns
    -------
    out : dict
        Parameter names grouped under the keys:

        * `args` : positional arguments without a default value.
        * `kwargs` : arguments with a default value, and keyword-only
          arguments.
        * `varargs` : the ``*args``-style parameter, if any.
        * `keywords` : the ``**kwargs``-style parameter, if any.
    """
    categories = ('args', 'kwargs', 'varargs', 'keywords')
    out = {category: [] for category in categories}
    signature = inspect.signature(function)  # pylint: disable=no-member
    for param in signature.parameters.values():
        category = _arg_kind(param)
        if category is None:  # pragma: no cover
            logger.critical('Unknown variable kind "%s" for "%s"',
                            param.kind, param.name)
            continue
        out[category].append(param.name)
    return out
def _pick_out_arg_kwargs(klass, settings):
    """Pick out arguments for a class from settings.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out[0] : list
        A list of the positional arguments, in signature order.
    out[1] : dict
        The keyword arguments found in ``settings``.

    Raises
    ------
    ValueError
        If a required (positional) argument is missing from ``settings``.
    """
    info = inspect_function(klass.__init__)
    args, kwargs = [], {}
    for arg in info['args']:
        if arg == 'self':
            continue
        try:
            args.append(settings[arg])
        except KeyError:
            msg = f'Required argument "{arg}" for "{klass}" not found!'
            # The message already names the missing key; drop the
            # KeyError context so the traceback stays readable.
            raise ValueError(msg) from None
    for arg in info['kwargs']:
        if arg == 'self':
            continue
        if arg in settings:
            kwargs[arg] = settings[arg]
    return args, kwargs
def initiate_instance(klass, settings):
    """Initialise a class with optional arguments.

    The arguments are selected from ``settings`` according to the
    signature of ``klass.__init__``, and a debug message describes
    which kinds of arguments were passed.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out : instance of `klass`
        The initiated instance of the given class.
    """
    args, kwargs = _pick_out_arg_kwargs(klass, settings)
    if args and kwargs:
        how = 'with positional and keyword arguments.'
    elif args:
        how = 'with positional arguments.'
    elif kwargs:
        how = 'with keyword arguments.'
    else:
        how = 'without arguments.'
    logger.debug('Initiated "%s" from "%s" %s',
                 klass.__name__, klass.__module__, how)
    # Unpacking empty containers is the same as passing nothing.
    return klass(*args, **kwargs)
def generic_factory(settings, object_map, name='generic'):
    """Create instances of classes based on settings.

    A semi-generic factory: ``settings['class']`` (case-insensitive)
    selects an entry of ``object_map``. That entry's ``'cls'`` value is
    the class instantiated with :py:func:`.initiate_instance`.

    Parameters
    ----------
    settings : dict
        Settings for the object; must contain the key ``'class'``. The
        same dict supplies the constructor arguments.
    object_map : dict
        Maps lower-case class names to dicts with a ``'cls'`` entry.
    name : string, optional
        Short name for the object type. Only used for error messages.

    Returns
    -------
    out : instance of a class or None
        The created object, or None if no class was given or the
        class is unknown.
    """
    try:
        requested = settings['class']
        key = requested.lower()
    except KeyError:
        logger.critical('No class given for %s -- could not create object!',
                        name)
        return None
    if key in object_map:
        return initiate_instance(object_map[key]['cls'], settings)
    logger.critical('Could not create unknown class "%s" for %s',
                    requested, name)
    return None
def numpy_allclose(val1, val2):
    """Compare two values with allclose from numpy.

    Here, we allow for one, or both, of the values to be None.
    Note that if val1 == val2 but are not of a type known to
    numpy, the returned value will be False. Values that numpy can not
    compare (e.g. shapes that can not be broadcast together) are also
    reported as different instead of raising.

    Parameters
    ----------
    val1 : np.array
        The first variable in the comparison.
    val2 : np.array
        The second variable in the comparison.

    Returns
    -------
    out : boolean
        True if the values are equal, False otherwise.
    """
    if val1 is None or val2 is None:
        # Equal only when both are None.
        return val1 is None and val2 is None
    try:
        return np.allclose(val1, val2)
    except (TypeError, ValueError):
        # TypeError: types numpy can not compare numerically.
        # ValueError: non-broadcastable shapes or unconvertible values.
        return False
def compare_objects(obj1, obj2, attrs, numpy_attrs=None):
    """Compare two PyRETIS objects.

    The objects are considered equal when they have the same class,
    the same number of entries in ``__dict__``, and equal values for
    every attribute named in ``attrs``. Attributes that are also named
    in ``numpy_attrs`` are compared with :py:func:`.numpy_allclose`;
    all others are compared with ``==``.

    Parameters
    ----------
    obj1 : object
        The first object for the comparison.
    obj2 : object
        The second object for the comparison.
    attrs : iterable of strings
        The attributes to check.
    numpy_attrs : iterable of strings, optional
        The subset of attributes which are numpy arrays.

    Returns
    -------
    out : boolean
        True if the objects are equal, False otherwise.
    """
    class1, class2 = obj1.__class__, obj2.__class__
    if not class1 == class2:
        logger.debug('The classes are different %s != %s', class1, class2)
        return False
    if len(obj1.__dict__) != len(obj2.__dict__):
        logger.debug('Number of attributes differ.')
        return False
    for key in attrs:
        try:
            val1, val2 = getattr(obj1, key), getattr(obj2, key)
        except AttributeError:
            logger.debug('Failed to compare attribute "%s"', key)
            return False
        use_numpy = numpy_attrs and key in numpy_attrs
        equal = numpy_allclose(val1, val2) if use_numpy else val1 == val2
        if not equal:
            logger.debug('Attribute "%s" differ.', key)
            return False
    return True
def null_move(ensemble, cycle):
    """Perform a null move for a path ensemble.

    The null move re-accepts the last accepted path of the ensemble and
    records it again for the given cycle.

    Parameters
    ----------
    ensemble : dict
        Must contain:

        * path_ensemble : object like :py:class:`.PathEnsemble`
          The path ensemble to update with the null move.
    cycle : integer
        The current cycle number.

    Returns
    -------
    out[0] : boolean
        Always True: the null move is always accepted.
    out[1] : object like :py:class:`.PathBase`
        The (unchanged) last accepted path.
    out[2] : string
        Always 'ACC'.
    """
    path_ensemble = ensemble['path_ensemble']
    logger.info('Null move for: %s', path_ensemble.ensemble_name)
    path = path_ensemble.last_path
    # Keep an 'ld' move label; any other label is replaced by '00'.
    if path.get_move() != 'ld':
        path.set_move('00')
    status = 'ACC'
    path_ensemble.add_path_data(path, status, cycle=cycle)
    return True, path, status
def relative_shoots_select(ensembles, rgen, relative):
    """Randomly select the ensemble for 'relative' shooting moves.

    Here we select the ensemble to do the shooting in based on relative
    probabilities. We draw a random number between 0 and 1 and pick the
    first ensemble whose cumulative probability exceeds it.

    Parameters
    ----------
    ensembles : list of objects like :py:class:`.PathEnsemble`
        This is a list of the ensembles we are using in the RETIS method.
    rgen : object like :py:class:`.RandomGenerator`
        This is a random generator. Here we assume that we can call
        `rgen.rand()` to draw random uniform numbers.
    relative : list of floats
        These are the relative probabilities for the ensembles. We
        assume here that these numbers are normalised.

    Returns
    -------
    out[0] : integer
        The index of the path ensemble to shoot in.
    out[1] : object like :py:class:`.PathEnsemble`
        The selected path ensemble for shooting.

    Raises
    ------
    ValueError
        If the drawn number is not below the cumulative sum of
        ``relative`` (i.e. the frequencies do not cover it).
    """
    freq = rgen.rand()
    cumulative = 0.0
    for idx, path_freq in enumerate(relative):
        cumulative += path_freq
        if freq < cumulative:
            return idx, ensembles[idx]
    # Explicit check: relying on ``ensembles[None]`` raising TypeError
    # hid unrelated TypeErrors and does not raise for numpy arrays.
    raise ValueError('Error in relative shoot frequencies! Aborting!')
def segments_counter(path, interface_l, interface_r, reverse=False):
    """Count the directional segments between two interfaces.

    A segment is counted each time the path crosses its starting
    interface and later crosses its ending interface. With
    ``reverse=False`` segments go upwards FROM ``interface_l`` TO
    ``interface_r``; with ``reverse=True`` they go downwards FROM
    ``interface_r`` TO ``interface_l``. Only the first component of
    each phase point's order parameter is used.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which segments will be counted.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.
    reverse : boolean, optional
        Count downward (right-to-left) segments instead.

    Returns
    -------
    n_segments : integer
        The number of segments found.
    """
    if reverse:
        start_intf, end_intf = interface_r, interface_l

        def _crosses(intf, before, after):
            # Downward crossing: before >= intf > after.
            return before >= intf > after
    else:
        start_intf, end_intf = interface_l, interface_r

        def _crosses(intf, before, after):
            # Upward crossing: before <= intf < after.
            return after > intf >= before

    points = path.phasepoints
    in_segment = False
    n_segments = 0
    for idx in range(path.length - 1):
        before = points[idx].order[0]
        after = points[idx + 1].order[0]
        if _crosses(start_intf, before, after):
            in_segment = True
        if _crosses(end_intf, before, after) and in_segment:
            in_segment = False
            n_segments += 1
    return n_segments
def crossing_counter(path, interface):
    """Count how many times a path crosses an interface.

    Every pair of consecutive phase points whose order parameters
    (first component) lie on different sides of ``interface`` counts
    as one crossing, in either direction. A point exactly on the
    interface counts as being on its right side.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        The path to analyse.
    interface : float
        The position of the interface.

    Returns
    -------
    cnt : integer
        Number of crossings of the given interface.
    """
    points = path.phasepoints
    return sum(
        1
        for before, after in zip(points[:-1], points[1:])
        if after.order[0] >= interface > before.order[0]
        or before.order[0] >= interface > after.order[0]
    )
def crossing_finder(path, interface, last_frame=False):
    """Find a crossing of an interface along a path.

    All steps where consecutive phase points lie on different sides of
    ``interface`` (either direction) are collected, and one of them is
    selected: the last one, or one picked at random.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        Input path to search.
    interface : float
        Interface position.
    last_frame : boolean, optional
        If True, select the last crossing; otherwise pick one at random
        with ``path.rgen.random_integers``.

    Returns
    -------
    ph1, ph2 : snapshots
        The phase points right before and right after the selected
        crossing, or ``(None, None)`` if the path never crosses.
    """
    points = path.phasepoints
    before, after = [], []
    for point1, point2 in zip(points[:-1], points[1:]):
        op1 = point1.order[0]
        op2 = point2.order[0]
        if op2 >= interface > op1 or op1 >= interface > op2:
            before.append(point1)
            after.append(point2)
    if not before:
        return None, None
    idx = -1 if last_frame else path.rgen.random_integers(0, len(before) - 1)
    return before[idx], after[idx]
def trim_path_between_interfaces(path, interface_l, interface_r):
    """Cut a path down to the part strictly between two interfaces.

    Only phase points with ``interface_l < orderp < interface_r``
    (first order-parameter component) are kept, in their original
    order. The result may therefore join several discontinuous pieces
    of the input path. Note that both bounds are exclusive; consider
    whether a left-inclusive check is needed for your ensemble.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.

    Returns
    -------
    new_path : object like :py:class:`.PathBase`
        The trimmed path. It copies ``maxlen``, ``status``,
        ``time_origin`` and ``rgen`` from ``path`` and has
        ``generated`` set to ``'ct'``.
    """
    trimmed = path.empty_path()
    kept = (point for point in path.phasepoints
            if interface_l < point.order[0] < interface_r)
    for point in kept:
        trimmed.append(point)
    for attr in ('maxlen', 'status', 'time_origin'):
        setattr(trimmed, attr, getattr(path, attr))
    trimmed.generated = 'ct'
    trimmed.rgen = path.rgen
    return trimmed
def select_and_trim_a_segment(path, interface_l, interface_r,
                              segment_to_pick=None):
    """Cut a directional segment from interface_l to interface_r.

    A segment starts at a step that crosses ``interface_l`` upwards
    (``op2 >= interface_l > op1``). It ends at the next step that
    crosses ``interface_r`` upwards (``op2 >= interface_r > op1``).
    Segments are numbered from 0 in the order they appear along the
    path, and only the one numbered ``segment_to_pick`` is kept. The
    kept points are:

    * the point just before the ``interface_l`` crossing;
    * every following point up to the ``interface_r`` crossing,
      including any that dip back below ``interface_l``;
    * the point just after the ``interface_r`` crossing.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.
    segment_to_pick : integer, optional
        Index (starting from 0) of the segment to select. If None,
        an index is drawn with ``path.rgen.random_integers``.

    Returns
    -------
    segment : object like :py:class:`.PathBase`
        The selected segment. It copies ``maxlen``, ``status``,
        ``time_origin`` and ``rgen`` from ``path`` and has
        ``generated`` set to ``'sg'``. It is empty if no segment with
        the requested index exists.
    """
    # True while inside a segment: after an upward crossing of
    # interface_l and before the next upward crossing of interface_r.
    key = False
    segment = path.empty_path()
    # Index of the current segment (-1: no segment started yet).
    segment_i = -1
    if segment_to_pick is None:
        segment_number = segments_counter(path, interface_l, interface_r)
        # NOTE(review): crossing_finder calls random_integers(0, n - 1) to
        # index n items, which suggests the upper bound is inclusive. If
        # so, index ``segment_number`` can be drawn here. That index only
        # exists when the path ends inside an unfinished segment;
        # otherwise an empty segment is returned. Confirm this is intended.
        segment_to_pick = path.rgen.random_integers(0, segment_number)
    for i, phasepoint in enumerate(path.phasepoints[:-1]):
        op1 = path.phasepoints[i].order[0]
        op2 = path.phasepoints[i+1].order[0]
        # NB: these are directional crossings (upwards only).
        if op2 >= interface_l > op1:
            # Entering a segment; count it only if not already inside one.
            if not key:
                segment_i += 1
            key = True
        if key:
            if segment_i == segment_to_pick:
                # No position check: every point of the selected segment
                # is stored, including dips below interface_l.
                segment.append(phasepoint)
                isave = i
        if op2 >= interface_r > op1:
            # Segment ends: also store the point just past interface_r.
            if key and segment_i == segment_to_pick:
                segment.append(path.phasepoints[i+1])
            key = False
    if segment.length == 1:
        # Only the starting point was stored. This means the selected
        # segment started at the last step of the loop without crossing
        # interface_r, so add the final point of the path as well.
        segment.append(path.phasepoints[isave + 1])
    segment.maxlen = path.maxlen
    segment.status = path.status
    segment.time_origin = path.time_origin
    segment.generated = 'sg'
    segment.rgen = path.rgen
    return segment
def wirefence_weight_and_pick(path, intf_l, intf_r, return_seg=False):
    """Calculate the weight of a path generated by the Wire Fence move.

    The WF path weight is determined by the total sum of valid sub-path
    phasepoints. Valid WF sub-paths are intf_l-intf_l, intf_l-intf_r
    and intf_r-intf_l sub-paths; intf_r-intf_r sub-paths are discarded.
    A single step that jumps over both interfaces is ignored.
    If return_seg = True, a random valid WF sub-path is also returned.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be analysed.
    intf_l : float
        This is the position of the LEFT interface.
    intf_r : float
        This is the position of the RIGHT interface.
    return_seg : boolean, optional
        Determines if a random valid WF sub-path is returned or not.

    Returns
    -------
    n_frames : int
        The weight of the path. Each valid sub-path contributes the
        number of phase points strictly between its starting and ending
        crossing steps.
    segment : object like :py:class:`.PathBase` or boolean
        False if ``return_seg`` is False or no valid sub-path exists.
        Otherwise a randomly picked valid sub-path, including the phase
        points just outside the interfaces at both of its ends.
    """
    # key_l / key_r: True while tracking a sub-path that started with an
    # upward crossing of intf_l / a downward crossing of intf_r.
    key_l, key_r = False, False
    # Valid sub-paths stored as (start index, end index, weight).
    path_arr = []
    segment = False
    for i in range(len(path.phasepoints[:-1])):
        op1 = path.phasepoints[i].order[0]
        op2 = path.phasepoints[i+1].order[0]
        if (op1 < intf_l and op2 >= intf_r) or \
                (op2 < intf_l and op1 >= intf_r):
            # One step jumping over both interfaces: ignored.
            pass
        elif op2 >= intf_l > op1 and not key_l:
            # Upward crossing of intf_l: start tracking from intf_l.
            isave, key_l = i, True
        elif op2 < intf_r <= op1 and not key_r:
            # Downward crossing of intf_r: start tracking from intf_r.
            isave, key_r = i, True
        elif key_r and op2 >= intf_r > op1:
            # intf_r-intf_r sub-path: not valid, drop it.
            key_l, key_r = False, False
        elif True in (key_l, key_r) and (op2 < intf_l <= op1 or
                                         op2 >= intf_r > op1):
            # Sub-path ends at a downward crossing of intf_l or an upward
            # crossing of intf_r. Store it; its weight i - isave counts
            # the points with indices isave+1 .. i.
            key_l, key_r = False, False
            path_arr.append((isave, i+1, i-isave))
    n_frames = sum([i[2] for i in path_arr]) if path_arr else 0
    if return_seg and n_frames:
        sum_frames = 0
        subpath_select = path.rgen.rand()
        # Pick a sub-path with probability proportional to its weight
        # (assuming rgen.rand() is uniform on [0, 1)).
        for i in path_arr:
            sum_frames += i[2]
            if sum_frames/n_frames >= subpath_select:
                segment = path.empty_path()
                # Copy indices start .. end inclusive.
                for j in range(i[0], i[1]+1):
                    segment.append(path.phasepoints[j])
                segment.maxlen = path.maxlen
                segment.status = path.status
                segment.time_origin = path.time_origin
                segment.generated = 'ct'
                segment.rgen = path.rgen
                break
    return n_frames, segment
def compute_weight(path, interfaces, move):
    """Compute the High Acceptance path weight after a MC move.

    The weights are used in the computation of the P cross. They allow
    the High Acceptance versions of Stone Skipping ('ss') and Wire
    Fencing ('wf') to accept B to A paths, at the cost that swapping
    moves must also account for these weights.

    * 'ss': number of crossings of the middle interface.
    * 'wf': weight from :py:func:`.wirefence_weight_and_pick` using
      the middle and right interfaces.
    * any other move: 1.

    For 'ss' and 'wf' the weight is doubled when the path's start and
    end points (w.r.t. the left and right interfaces) differ.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be checked.
    interfaces : list/tuple of floats
        These are the interface positions of the form
        ``[left, middle, right]``.
    move : string
        The MC move to compute the weights for.

    Returns
    -------
    out : float
        The weight of the path.
    """
    if move == 'ss':
        weight = 1. * crossing_counter(path, interfaces[1])
    elif move == 'wf':
        wf_frames = wirefence_weight_and_pick(path, interfaces[1],
                                              interfaces[2])[0]
        weight = 1. * wf_frames
    else:
        weight = 1.
    left, right = interfaces[0], interfaces[2]
    start = path.get_start_point(left, right)
    end = path.get_end_point(left, right)
    if start != end and move in ('ss', 'wf'):
        weight *= 2
    return weight
def priority_checker(ensembles, settings):
    """Determine which ensembles to skip under priority shooting.

    Priority shooting is enabled by
    ``settings['simulation']['priority_shooting']``. When it is enabled
    and the ensembles do not all have the same number of paths
    (``nstats['npath']``), the ensembles with the largest number of
    paths are skipped. In every other case no ensemble is skipped.

    Parameters
    ----------
    ensembles : list of dicts
        List of dicts of ensembles we are using in a path method; each
        dict has a ``'path_ensemble'`` entry.
    settings : dict
        This dict contains the settings for the RETIS method.

    Returns
    -------
    out : list of booleans
        One entry per ensemble; True means the ensemble is skipped.
    """
    priority = settings.get('simulation', {}).get('priority_shooting', False)
    n_ensembles = len(ensembles)
    if not priority:
        return [False] * n_ensembles
    npaths = [ens['path_ensemble'].nstats['npath'] for ens in ensembles]
    if not any(npath != npaths[0] for npath in npaths):
        # All ensembles have the same number of paths.
        return [False] * n_ensembles
    most = max(npaths)
    return [npath == most for npath in npaths]
def soft_partial_exit(exe_path=''):
    """Check the presence of the EXIT file.

    The file is called ``EXIT`` and is looked for in ``exe_path``, or
    in the current working directory when ``exe_path`` is empty. When
    found, a message is logged and printed as a warning.

    Parameters
    ----------
    exe_path : string, optional
        Directory in which to look for the EXIT file.

    Returns
    -------
    out : boolean
        True if EXIT is present, False otherwise.
    """
    exit_file = os.path.join(exe_path, 'EXIT') if exe_path else 'EXIT'
    if not os.path.isfile(exit_file):
        return False
    message = 'Exit file found - will exit between steps.'
    logger.info(message)
    print_to_screen(message, level='warning')
    return True