Source code for memtorch.bh.StochasticParameter

import copy
import inspect
import math

import torch

import memtorch


def StochasticParameter(
    distribution=torch.distributions.normal.Normal,
    min=0,
    max=float("Inf"),
    function=True,
    **kwargs
):
    """Method to model a stochastic parameter.

    Parameters
    ----------
    distribution : torch.distributions
        torch distribution.
    min : float
        Minimum value to sample.
    max : float
        Maximum value to sample.
    function : bool
        If True, a function that returns a sampled value or the mean is returned.
        If False, a single sampled value is returned.

    Returns
    -------
    float or function
        A sampled value of the stochastic parameter, or a sample-value generator.
    """
    assert issubclass(
        distribution, torch.distributions.distribution.Distribution
    ), "Distribution is not in torch.distributions."
    for arg in inspect.signature(distribution).parameters.values():
        if arg.name not in kwargs and arg.name != "validate_args":
            raise Exception(
                "Argument %s is required for %s" % (arg.name, distribution)
            )

    m = distribution(**kwargs)

    def f(return_mean=False):
        """Method to return a sampled value or the mean of the stochastic parameter.

        Parameters
        ----------
        return_mean : bool
            If True, return the mean value of the stochastic parameter.
            If False, return a sampled value of the stochastic parameter.

        Returns
        -------
        float
            The mean value, or a sampled value of the stochastic parameter.
        """
        if return_mean:
            return m.mean
        else:
            return m.sample().clamp(min, max).item()

    if function:
        return f
    else:
        return f()
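
# Illustrative usage sketch (not part of the original module): modelling a
# stochastic r_on resistance drawn from a Normal distribution. The helper name
# _example_stochastic_parameter and the distribution parameters (loc=1.0e2,
# scale=1.0e1) are arbitrary assumptions chosen for demonstration only.
def _example_stochastic_parameter():
    r_on = StochasticParameter(
        torch.distributions.normal.Normal, loc=1.0e2, scale=1.0e1, min=1.0
    )
    sampled = r_on()  # a float drawn from N(1e2, 1e1), clamped to [1.0, inf)
    mean = r_on(return_mean=True)  # the distribution mean, as a torch tensor
    return sampled, mean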
def unpack_parameters(
    local_args, r_rel_tol=None, r_abs_tol=None, resample_threshold=5
):
    """Method to sample from stochastic sample-value generators.

    Parameters
    ----------
    local_args : locals()
        Local arguments containing stochastic sample-value generators to sample from.
    r_rel_tol : float
        Relative threshold tolerance.
    r_abs_tol : float
        Absolute threshold tolerance.
    resample_threshold : int
        Number of times to resample r_off and r_on when their proximity is within
        the threshold tolerance before raising an exception.

    Returns
    -------
    locals()
        locals() with sampled stochastic parameters.
    """
    assert (
        r_rel_tol is None or r_abs_tol is None
    ), "r_rel_tol or r_abs_tol must be None."
    assert (
        type(resample_threshold) == int and resample_threshold >= 0
    ), "resample_threshold must be of type int and >= 0."
    if "reference" in local_args:
        return_mean = True
    else:
        return_mean = False

    local_args_copy = copy.deepcopy(local_args)
    for arg in local_args:
        if callable(local_args[arg]) and "__" not in str(arg):
            local_args[arg] = local_args[arg](return_mean=return_mean)

    args = Dict2Obj(local_args)
    if hasattr(args, "r_off") and hasattr(args, "r_on"):
        resample_idx = 0
        r_off_generator = local_args_copy["r_off"]
        r_on_generator = local_args_copy["r_on"]
        while True:
            if r_abs_tol is None and r_rel_tol is not None:
                if not math.isclose(args.r_off, args.r_on, rel_tol=r_rel_tol):
                    break
            elif r_rel_tol is None and r_abs_tol is not None:
                if not math.isclose(args.r_off, args.r_on, abs_tol=r_abs_tol):
                    break
            else:
                if not math.isclose(args.r_off, args.r_on):
                    break

            if callable(r_off_generator) and callable(r_on_generator):
                args.r_off = copy.deepcopy(r_off_generator)(return_mean=return_mean)
                args.r_on = copy.deepcopy(r_on_generator)(return_mean=return_mean)
            else:
                raise Exception(
                    "Resample threshold exceeded (deterministic values used)."
                )

            resample_idx += 1
            if resample_idx > resample_threshold:
                raise Exception("Resample threshold exceeded.")

    return args
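
# Illustrative usage sketch (not part of the original module): unpack_parameters
# is typically called as unpack_parameters(locals()) so that keyword arguments
# passed as StochasticParameter generators are replaced by concrete samples.
# The helper name _example_unpack_parameters and the distribution parameters
# below are arbitrary assumptions chosen for demonstration only.
def _example_unpack_parameters():
    r_off = StochasticParameter(
        torch.distributions.normal.Normal, loc=1.0e4, scale=1.0e3, min=1.0
    )
    r_on = StochasticParameter(
        torch.distributions.normal.Normal, loc=1.0e2, scale=1.0e1, min=1.0
    )
    # Sample both generators; if the sampled r_off and r_on fall within an
    # absolute tolerance of 100 of each other, they are resampled up to
    # resample_threshold times before an exception is raised.
    args = unpack_parameters(locals(), r_abs_tol=100, resample_threshold=5)
    return args.r_off, args.r_on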
class Dict2Obj(object):
    """Class used to instantiate an object given a dictionary."""

    def __init__(self, dictionary):
        for key in dictionary:
            if key == "__class__":
                continue

            setattr(self, key, dictionary[key])
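
# Illustrative usage sketch (not part of the original module): Dict2Obj exposes
# dictionary entries as attributes, which is how unpack_parameters returns
# sampled arguments for attribute-style access. The helper name and values are
# arbitrary assumptions for demonstration only.
def _example_dict2obj():
    args = Dict2Obj({"r_off": 1.0e4, "r_on": 1.0e2})
    return args.r_off / args.r_on  # 100.0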