rJKOtt.TensorTrainSolver module

class rJKOtt.TensorTrainSolver.TensorTrainSolver(rho_infty, rho_start, solver_params=None, posterior_cache_size=1000000)[source]

Bases: object

Approximate a given distribution rho_infty using an entropy-regularized JKO scheme with finite-difference spatial discretization and Tensor-Train decomposition.

Parameters:
  • rho_infty (Callable) – function proportional to the probability density of the posterior. Should have signature (N_samples, dim) -> (N_samples, )

  • rho_start (TensorTrainDistribution) – starting distribution, discretized on a grid. The posterior will be approximated on the same grid.

  • posterior_cache_size (int) – maximal size of the cache

  • solver_params (TensorTrainSolverParams) – solver parameters; if not given, the defaults defined in the TensorTrainSolverParams class are used
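
As an illustration of the kind of callable expected for rho_infty, the sketch below evaluates an unnormalized two-mode Gaussian mixture on a batch of points. The mixture, the mode location `shift`, and the assumption that one density value is returned per sample row are illustrative choices, not taken from the library.

```python
import numpy as np


def rho_infty(x: np.ndarray) -> np.ndarray:
    """Unnormalized two-mode Gaussian mixture, evaluated row by row.

    x has shape (N_samples, dim); one value is returned per row (assumption).
    """
    x = np.atleast_2d(x)
    shift = np.full(x.shape[1], 1.5)  # hypothetical mode location
    mode_a = np.exp(-0.5 * np.sum((x - shift) ** 2, axis=1))
    mode_b = np.exp(-0.5 * np.sum((x + shift) ** 2, axis=1))
    return mode_a + mode_b


points = np.zeros((4, 3))  # 4 sample points in dimension 3
values = rho_infty(points)
```

The function only needs to be proportional to the density, so no normalizing constant is computed.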

grid

grid on which the distributions are discretized

Type:

Grid

n_calls

tracks the number of actual calls to the posterior during the solve

Type:

int

n_cache

tracks the number of posterior calls served from the cache during the solve

Type:

int
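
The relationship between n_calls and n_cache can be sketched with a small caching wrapper. `CachedPosterior` is a hypothetical class written for this illustration; it is not the library's implementation, and keying the cache on the raw bytes of each point is an assumption.

```python
import numpy as np


class CachedPosterior:
    """Hypothetical wrapper that counts real and cached posterior evaluations."""

    def __init__(self, func, cache_size=1_000_000):
        self.func = func
        self.cache_size = cache_size
        self.cache = {}
        self.n_calls = 0  # evaluations that actually ran func
        self.n_cache = 0  # evaluations answered from the cache

    def __call__(self, x):
        x = np.atleast_2d(np.asarray(x, dtype=float))
        out = np.empty(x.shape[0])
        for i, row in enumerate(x):
            key = row.tobytes()
            if key in self.cache:
                self.n_cache += 1
                out[i] = self.cache[key]
            else:
                self.n_calls += 1
                out[i] = self.func(row[None, :])[0]
                if len(self.cache) < self.cache_size:
                    self.cache[key] = out[i]
        return out


posterior = CachedPosterior(lambda x: np.exp(-0.5 * np.sum(x**2, axis=1)))
grid_points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])
posterior(grid_points)  # two distinct points are evaluated, one repeat is cached
posterior(grid_points)  # all three points are now served from the cache
```

Because grid-based methods revisit the same points often, the cached count can far exceed the real one when the posterior is expensive.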

Ts

timesteps taken

Type:

List[float]

betas

regularization factors at steps taken

Type:

List[float]

get_current_distribution()[source]
Return type:

TensorTrainDistribution

sample(sample_x0)[source]

Starting from a sample from the initial distribution, propagate it through the fitted dynamics and return a sample from the distribution at the last step.

Parameters:

sample_x0 (numpy.ndarray) – sample from the initial distribution, should have shape (n_samples, dim)

Returns:

sample from the distribution of the last step

Return type:

np.ndarray

step(beta, T, save_history=False)[source]

Perform a regularized JKO step.


Parameters:
  • beta (numpy.float64) – regularization factor

  • T (numpy.float64) – timestep

  • save_history (bool) – if True, returns the intermediate ranks, errors, etc. recorded during the fixed-point iterations

Return type:

Tuple[List[numpy.ndarray], List[numpy.ndarray]]
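
For orientation on step(beta, T), recall the classical (unregularized) JKO step with timestep \(T\). This is standard background rather than something stated in this reference; the notation \(\rho_k\) for the current distribution is an assumption, and the way beta enters the regularized transport cost is not specified here.

\[
\rho_{k+1} = \operatorname*{arg\,min}_{\rho} \left\{ \mathrm{KL}(\rho \,\|\, \rho_\infty) + \frac{1}{2T} W_2^2(\rho, \rho_k) \right\}
\]

Here \(W_2\) is the 2-Wasserstein distance and \(\rho_\infty\) is the target density. The regularized scheme replaces the transport term by an entropy-regularized counterpart controlled by beta.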

class rJKOtt.TensorTrainSolver.TensorTrainSolverParams(cross_nfev_with_posterior=400000, cross_nfev_no_posterior=1000000, cross_rel_diff=1e-06, cross_use_validation=False, cross_validation_rtol=1e-06, cross_n_validation=1000, max_rank_eta=5, max_rank_hat_eta=5, max_rank_density=20, trunc_tol_hat_eta=1e-13, trunc_tol_eta=1e-13, trunc_tol_density=1e-13, fp_method='2_anderson', fp_relaxation=0.9, fp_stopping_rtol=1e-08, fp_max_iter=100, zero_threshold=1e-16, sampling_ode_rtol=0.001, sampling_ode_atol=1e-06, sampling_sde_fraction=0.005, sampling_n_euler_maruyama_steps=50)[source]

Bases: object

Store the solver parameters, such as tolerances and iteration limits. Default values are provided.

These parameters can be changed between or possibly even during the fixed-point iterations.

Parameters:
  • cross_nfev_with_posterior (int)

  • cross_nfev_no_posterior (int)

  • cross_rel_diff (float)

  • cross_use_validation (bool)

  • cross_validation_rtol (float)

  • cross_n_validation (int)

  • max_rank_eta (int)

  • max_rank_hat_eta (int)

  • max_rank_density (int)

  • trunc_tol_hat_eta (float)

  • trunc_tol_eta (float)

  • trunc_tol_density (float)

  • fp_method (Literal['picard', '2_anderson', 'aitken'])

  • fp_relaxation (float)

  • fp_stopping_rtol (float)

  • fp_max_iter (int)

  • zero_threshold (float)

  • sampling_ode_rtol (float)

  • sampling_ode_atol (float)

  • sampling_sde_fraction (float)

  • sampling_n_euler_maruyama_steps (int)

cross_n_validation: int = 1000

Number of indices in the validation subset

cross_nfev_no_posterior: int = 1000000

Number of function evaluations in TT-cross when real posterior calls are not required (initial condition and KL estimation)

cross_nfev_with_posterior: int = 400000

Number of function evaluations in TT-cross when real posterior calls are required (terminal condition and KL)

cross_rel_diff: float = 1e-06

Cross stopping criterion: if the relative change of the solution is less than this value, cross stops

cross_use_validation: bool = False

Whether to use the error on a validation subset as a stopping criterion during the cross approximation iterations

cross_validation_rtol: float = 1e-06

Cross stopping criterion: if the error on the validation subset is less than this value, cross stops
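
A validation-based stopping test of this kind can be sketched as follows: draw a random subset of grid indices, compare the approximation with reference values there, and stop once the relative error is small. The helper `validation_rel_error`, the toy `reference` and `approximation` functions, and the use of the 2-norm are assumptions made for illustration; how the library measures the error is not stated here.

```python
import numpy as np


def validation_rel_error(reference, approximation, grid_shape, n_validation, rng):
    """Relative 2-norm error on n_validation randomly drawn multi-indices."""
    idx = np.stack(
        [rng.integers(0, n, size=n_validation) for n in grid_shape], axis=1
    )
    ref = reference(idx)
    return np.linalg.norm(ref - approximation(idx)) / np.linalg.norm(ref)


def reference(idx):
    # Toy function of the multi-index, standing in for exact evaluations
    return np.exp(-0.1 * np.sum(idx, axis=1))


def approximation(idx):
    # Toy approximation with a small relative perturbation
    return reference(idx) * (1.0 + 1e-8)


rng = np.random.default_rng(0)
err = validation_rel_error(
    reference, approximation, grid_shape=(32, 32, 32), n_validation=1000, rng=rng
)
stop = err < 1e-6  # compare with cross_validation_rtol
```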

fp_max_iter: int = 100

The fixed-point iteration terminates after this many iterations regardless of convergence

fp_method: Literal['picard', '2_anderson', 'aitken'] = '2_anderson'

Method of solving the fixed-point problem at a step.

  • picard : The simplest method; the next iterate is a linear combination of the current iterate and the operator evaluated at it, with a fixed relaxation factor.

  • aitken : Same as picard, but the relaxation factor is selected adaptively. Intended to be more robust.

  • 2_anderson : Keeps the history of the 2 previous iterates and function values and generates the next iterate by solving a minimization subproblem.

fp_relaxation: float = 0.9

Relaxation factor for the fixed-point method

fp_stopping_rtol: float = 1e-08

Relative tolerance of the fixed-point iteration; if \(\frac{\|x_k - G(x_k)\|_2}{\|x_k\|_2}\) is smaller than this value, the iteration terminates
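
The fixed-point options can be illustrated on a toy problem. The sketch below applies relaxed Picard iteration and a textbook Anderson acceleration with a history of 2 to \(x = G(x)\) with \(G = \cos\), stopping on the relative residual \(\|x_k - G(x_k)\|_2 / \|x_k\|_2\). The operator, the starting point, and the particular Anderson variant are assumptions for illustration and need not match the library's implementation; the aitken variant is omitted.

```python
import numpy as np


def G(x):
    """Toy contraction standing in for the fixed-point operator."""
    return np.cos(x)


def rel_residual(x):
    return np.linalg.norm(x - G(x)) / np.linalg.norm(x)


def picard(x, relaxation=0.9, rtol=1e-8, max_iter=100):
    for k in range(max_iter):
        if rel_residual(x) < rtol:
            return x, k
        # Fixed blend of the current iterate and the operator value
        x = (1.0 - relaxation) * x + relaxation * G(x)
    return x, max_iter


def anderson(x, depth=2, rtol=1e-8, max_iter=100):
    xs, gs = [], []
    for k in range(max_iter):
        if rel_residual(x) < rtol:
            return x, k
        xs = (xs + [x])[-(depth + 1):]
        gs = (gs + [G(x)])[-(depth + 1):]
        res = [g - y for g, y in zip(gs, xs)]
        if len(xs) == 1:
            x = gs[-1]  # plain step until some history is available
            continue
        # Differences of residuals and operator values over the stored history
        dF = np.stack([res[i + 1] - res[i] for i in range(len(res) - 1)], axis=1)
        dG = np.stack([gs[i + 1] - gs[i] for i in range(len(gs) - 1)], axis=1)
        # Least-squares subproblem: minimize ||res[-1] - dF @ gamma||
        gamma = np.linalg.lstsq(dF, res[-1], rcond=None)[0]
        x = gs[-1] - dG @ gamma
    return x, max_iter


x0 = np.array([0.6, 0.7, 0.8])
x_picard, it_picard = picard(x0)
x_anderson, it_anderson = anderson(x0)
```

Both routines return the iterate together with the number of iterations used, so the effect of the iteration cap can be inspected directly.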

property max_rank: int
max_rank_density: int = 20

Maximal TT-rank when representing the next distribution

max_rank_eta: int = 5

Maximal TT-rank when representing the entropic potential \(\eta\)

max_rank_hat_eta: int = 5

Maximal TT-rank when representing the entropic potential \(\hat\eta\)

sampling_n_euler_maruyama_steps: int = 50

Number of Euler-Maruyama steps used in the SDE solution

sampling_ode_atol: float = 1e-06

Absolute tolerance for solving the sampling ODE

sampling_ode_rtol: float = 0.001

Relative tolerance for solving the sampling ODE

sampling_sde_fraction: float = 0.005

Fraction of each timestep to be solved with the SDE dynamics
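
The interplay of the SDE fraction and the number of Euler-Maruyama steps can be sketched as below. The SDE form \(dX = b(X)\,dt + \sqrt{2\beta}\,dW\), the drift `-x`, and the values of `T` and `beta` are assumptions chosen for illustration; the actual drift used by the library and the ordering of the ODE and SDE parts within a timestep are not stated here.

```python
import numpy as np


def euler_maruyama(x0, drift, beta, t_span, n_steps, rng):
    """Integrate dX = drift(X) dt + sqrt(2 * beta) dW over t_span with fixed steps."""
    x = np.array(x0, dtype=float)
    dt = t_span / n_steps
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + drift(x) * dt + np.sqrt(2.0 * beta * dt) * noise
    return x


T = 0.1                # hypothetical timestep of one JKO step
sde_fraction = 0.005   # sampling_sde_fraction
n_em_steps = 50        # sampling_n_euler_maruyama_steps
t_sde = sde_fraction * T   # part of the timestep handled by the SDE
t_ode = T - t_sde          # remainder, handled by the ODE solver

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1000, 2))
x1 = euler_maruyama(
    x0, drift=lambda x: -x, beta=1.0, t_span=t_sde, n_steps=n_em_steps, rng=rng
)
```

With the default values, only 0.5% of each timestep is integrated stochastically, so the Euler-Maruyama step size is very small.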

property trunc_tol: int
trunc_tol_density: float = 1e-13

TT truncation tolerance when computing the next density

trunc_tol_eta: float = 1e-13

TT truncation tolerance for potential \(\eta\) (i.e. in the terminal condition)

trunc_tol_hat_eta: float = 1e-13

TT truncation tolerance for potential \(\hat\eta\) (i.e. in the initial condition)
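
The roles of a truncation tolerance and a maximal rank can be illustrated in the matrix (two-dimensional) case, where truncation reduces to a truncated SVD. The helper `truncate` and the relative Frobenius-norm criterion are assumptions for this sketch; TT rounding applies a similar idea core by core, and how the library distributes the tolerance over the cores is not stated here.

```python
import numpy as np


def truncate(matrix, tol=1e-13, max_rank=20):
    """Low-rank factors with relative Frobenius error <= tol, capped at max_rank."""
    u, s, vt = np.linalg.svd(matrix, full_matrices=False)
    # tail[r] is the norm of the singular values discarded when keeping rank r
    tail = np.sqrt(np.cumsum(s[::-1] ** 2))[::-1]
    rank = int(np.sum(tail > tol * np.linalg.norm(s)))
    rank = max(1, min(rank, max_rank))  # the rank cap overrides the tolerance
    return u[:, :rank] * s[:rank], vt[:rank]


rng = np.random.default_rng(0)
a, b = rng.standard_normal((6, 2)), rng.standard_normal((2, 5))
matrix = a @ b  # exactly rank 2 up to round-off

left, right = truncate(matrix, tol=1e-10, max_rank=20)
left_capped, right_capped = truncate(matrix, tol=1e-10, max_rank=1)
```

When the cap is reached before the tolerance is met, the approximation error is no longer controlled by the tolerance.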

zero_threshold: float = 1e-16

Safeguard when computing small values that must be positive but could fall below zero due to numerical errors
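
One common way to apply such a safeguard is to clamp values from below before taking a logarithm, as sketched here. Whether the library clamps in this way or handles small values differently is an assumption.

```python
import numpy as np

zero_threshold = 1e-16

# Density values where round-off produced a tiny negative entry
density = np.array([0.3, 1e-20, -2e-17, 0.7])

safe = np.maximum(density, zero_threshold)  # enforce strict positivity
log_density = np.log(safe)                  # finite for every entry
```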