rJKOtt.TensorTrainSolver module

class rJKOtt.TensorTrainSolver.TensorTrainSolver(rho_infty, rho_start, solver_params=None, posterior_cache_size=1000000)[source]

Bases: object

Approximate a given distribution rho_infty by the entropy-regularized JKO scheme with finite-difference spatial discretization and Tensor-Train decomposition.

Parameters:

  • rho_infty (Callable) – function proportional to the probability density of the posterior. Should have signature (N_samples, dim) -> (N_samples, )

  • rho_start (TensorTrainDistribution) – starting distribution, discretized on a grid. The posterior will be approximated on the same grid.

  • posterior_cache_size (int) – maximal size of the cache of posterior evaluations

  • solver_params (TensorTrainSolverParams) – solver parameters; if not given, the defaults defined in the TensorTrainSolverParams class are used
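A minimal sketch of a target density with the expected vectorized signature, assuming that rho_infty returns one unnormalized value per sample, i.e. shape (N_samples,). The two-component Gaussian mixture is a made-up example; since only proportionality matters, no normalization constant is computed.

```python
import numpy as np


def rho_infty(x: np.ndarray) -> np.ndarray:
    """Unnormalized density of a two-component Gaussian mixture (example target).

    x has shape (N_samples, dim); the result has shape (N_samples,).
    """
    x = np.atleast_2d(np.asarray(x, dtype=float))
    dim = x.shape[1]
    # two hypothetical mixture centers at (-1, ..., -1) and (1, ..., 1), shape (2, dim)
    centers = np.stack([np.full(dim, -1.0), np.full(dim, 1.0)])
    # squared distance of every sample to every center, shape (N_samples, 2)
    sq_dist = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    sigma2 = 0.25  # common component variance
    return np.exp(-sq_dist / (2.0 * sigma2)).sum(axis=1)
```

Evaluating all samples in one vectorized call keeps the number of Python-level calls low, which matters when the posterior is queried on many grid points at once.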






tracks the number of real calls to the posterior during the solve




tracks the number of posterior calls served from the cache during the solve
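The two counters can be illustrated with a hypothetical caching wrapper. CachedPosterior, n_real_calls and n_cached_calls are invented names (the actual attribute names are not shown here), and the least-recently-used eviction beyond the cache size is an assumed policy.

```python
from collections import OrderedDict

import numpy as np


class CachedPosterior:
    """Hypothetical memoizing wrapper that counts real evaluations and cache hits."""

    def __init__(self, rho, cache_size=1_000_000):
        self.rho = rho
        self.cache_size = cache_size
        self.cache = OrderedDict()  # point bytes -> density value, in LRU order
        self.n_real_calls = 0       # points actually passed to rho
        self.n_cached_calls = 0     # points answered from the cache

    def __call__(self, x):
        x = np.atleast_2d(np.asarray(x, dtype=float))
        keys = [row.tobytes() for row in x]
        # points not cached yet, deduplicated while keeping their order
        new_keys = list(dict.fromkeys(k for k in keys if k not in self.cache))
        if new_keys:
            row_of = {k: i for i, k in enumerate(keys)}
            values = self.rho(x[[row_of[k] for k in new_keys]])
            for k, v in zip(new_keys, values):
                self.cache[k] = v
        self.n_real_calls += len(new_keys)
        self.n_cached_calls += len(keys) - len(new_keys)
        out = np.empty(len(keys))
        for i, k in enumerate(keys):
            out[i] = self.cache[k]
            self.cache.move_to_end(k)  # mark as recently used
        while len(self.cache) > self.cache_size:
            self.cache.popitem(last=False)  # evict the least recently used entry
        return out
```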




timesteps taken




regularization factors of the steps taken



Return type:



Starting from a sample of the initial distribution, propagate it through the fitted dynamics and return a sample of the distribution at the last step.


Parameters:

sample_x0 (numpy.ndarray) – sample from the initial distribution, of shape (n_samples, dim)


Returns:

sample from the distribution at the last step

Return type:


step(beta, T, save_history=False)[source]

Perform a regularized JKO step.

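For orientation, the classical (unregularized) JKO step of Jordan, Kinderlehrer and Otto with timestep T is written below. How exactly beta enters the regularized step implemented here is not stated in this text; the last comment is an assumption, not the solver's documented objective.

```latex
% Classical JKO step with timestep T, target rho_infty and previous density rho_k
\rho_{k+1} = \operatorname*{arg\,min}_{\rho}
  \left\{ \mathrm{KL}(\rho \,\|\, \rho_\infty) + \frac{1}{2T}\, W_2^2(\rho_k, \rho) \right\}
% Assumed regularized variant: W_2^2 is replaced by an entropy-regularized
% transport cost W_{2,\beta}^2 with regularization factor \beta.
```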

Parameters:

  • beta (numpy.float64) – regularization factor

  • T (numpy.float64) – timestep

  • save_history (bool) – if True, return the intermediate ranks, errors, etc. recorded during the fixed-point iterations

Return type:

Tuple[List[numpy.ndarray], List[numpy.ndarray]]

class rJKOtt.TensorTrainSolver.TensorTrainSolverParams(cross_nfev_with_posterior=400000, cross_nfev_no_posterior=1000000, cross_rel_diff=1e-06, cross_use_validation=False, cross_validation_rtol=1e-06, cross_n_validation=1000, max_rank_eta=5, max_rank_hat_eta=5, max_rank_density=20, trunc_tol_hat_eta=1e-13, trunc_tol_eta=1e-13, trunc_tol_density=1e-13, fp_method='2_anderson', fp_relaxation=0.9, fp_stopping_rtol=1e-08, fp_max_iter=100, zero_threshold=1e-16, sampling_ode_rtol=0.001, sampling_ode_atol=1e-06, sampling_sde_fraction=0.005, sampling_n_euler_maruyama_steps=50)[source]

Bases: object

Store the solver parameters, such as tolerances and iteration limits. Default values are provided.

These parameters can be changed between or possibly even during the fixed-point iterations.

Parameters:

  • cross_nfev_with_posterior (int)

  • cross_nfev_no_posterior (int)

  • cross_rel_diff (float)

  • cross_use_validation (bool)

  • cross_validation_rtol (float)

  • cross_n_validation (int)

  • max_rank_eta (int)

  • max_rank_hat_eta (int)

  • max_rank_density (int)

  • trunc_tol_hat_eta (float)

  • trunc_tol_eta (float)

  • trunc_tol_density (float)

  • fp_method (Literal['picard', '2_anderson', 'aitken'])

  • fp_relaxation (float)

  • fp_stopping_rtol (float)

  • fp_max_iter (int)

  • zero_threshold (float)

  • sampling_ode_rtol (float)

  • sampling_ode_atol (float)

  • sampling_sde_fraction (float)

  • sampling_n_euler_maruyama_steps (int)

cross_n_validation: int = 1000

Number of indices in the validation subset

cross_nfev_no_posterior: int = 1000000

Number of function evaluations in TT-cross when real posterior calls are not required (initial condition and KL estimation)

cross_nfev_with_posterior: int = 400000

Number of function evaluations in TT-cross when real posterior calls are required (terminal condition and KL)

cross_rel_diff: float = 1e-06

Cross stopping criterion: if the relative change of the solution is smaller than this value, the cross iteration stops

cross_use_validation: bool = False

Whether to use the error on the validation subset as a stopping criterion during the cross-approximation iterations

cross_validation_rtol: float = 1e-06

Cross stopping criterion: if the error on the validation subset is smaller than this value, the cross iteration stops
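A sketch of the validation-based stopping test, assuming that the validation subset consists of cross_n_validation random multi-indices on the grid and that the error is the relative 2-norm error of the approximation at those indices. Both helpers are hypothetical and not part of rJKOtt.

```python
import numpy as np


def sample_validation_indices(grid_shape, n_validation, rng):
    """Draw n_validation random multi-indices on a tensor grid, shape (n_validation, d)."""
    return np.stack([rng.integers(0, n, size=n_validation) for n in grid_shape], axis=1)


def validation_rel_error(exact_values, approx_values):
    """Relative 2-norm error of the approximation on the validation subset (assumed measure)."""
    exact_values = np.asarray(exact_values, dtype=float)
    diff = np.asarray(approx_values, dtype=float) - exact_values
    return np.linalg.norm(diff) / np.linalg.norm(exact_values)
```

With cross_use_validation enabled, the cross iteration would stop once validation_rel_error(...) falls below cross_validation_rtol.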

fp_max_iter: int = 100

The fixed-point iteration terminates after this number of iterations regardless of convergence

fp_method: Literal['picard', '2_anderson', 'aitken'] = '2_anderson'

Method of solving the fixed-point problem at a step.

  • picard : The simplest method; the next iterate is a linear combination of the current iterate and the operator applied to it, with a fixed relaxation factor

  • aitken : Same as picard, but the relaxation factor is selected adaptively; intended to be more robust

  • 2_anderson : Stores the 2 previous iterates and operator values and computes the next iterate from a minimization subproblem.
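The three options can be sketched on a plain NumPy vector, using the stopping rule described for fp_stopping_rtol. This is a minimal illustration: fixed_point is a hypothetical helper (not part of rJKOtt), the Aitken update uses the Irons-Tuck form of dynamic relaxation, and reading '2_anderson' as damped Anderson acceleration over the current plus 2 previous iterates is an assumption.

```python
import numpy as np


def fixed_point(G, x0, method="2_anderson", relaxation=0.9, rtol=1e-8, max_iter=100):
    """Solve x = G(x); returns (last iterate, number of iterations used)."""
    x = np.asarray(x0, dtype=float)
    omega = relaxation  # relaxation factor, adapted only by "aitken"
    r_prev = None       # previous residual ("aitken")
    xs, gs = [], []     # recent iterates and operator values ("2_anderson")
    for k in range(max_iter):
        gx = G(x)
        r = gx - x  # residual G(x_k) - x_k
        # stop when ||x_k - G(x_k)||_2 / ||x_k||_2 < rtol
        if np.linalg.norm(r) < rtol * np.linalg.norm(x):
            return x, k
        if method == "picard":
            x = x + relaxation * r  # (1 - w) x_k + w G(x_k), fixed w
        elif method == "aitken":
            if r_prev is not None:
                dr = r - r_prev
                omega = -omega * np.dot(r_prev, dr) / np.dot(dr, dr)
            r_prev = r
            x = x + omega * r
        elif method == "2_anderson":
            xs.append(x)
            gs.append(gx)
            del xs[:-3]  # keep the current pair and the 2 previous ones
            del gs[:-3]
            if len(xs) == 1:
                x = x + relaxation * r  # no history yet: relaxed Picard step
            else:
                X = np.array(xs)
                F = np.array(gs) - X         # residuals, one per row
                dX = (X[1:] - X[:-1]).T      # iterate differences, shape (n, m)
                dF = (F[1:] - F[:-1]).T      # residual differences, shape (n, m)
                # minimization subproblem: min_gamma ||r - dF gamma||_2
                gamma = np.linalg.lstsq(dF, r, rcond=None)[0]
                x = (x - dX @ gamma) + relaxation * (r - dF @ gamma)
        else:
            raise ValueError(f"unknown method: {method}")
    return x, max_iter
```

For example, fixed_point(np.cos, np.array([0.0, 1.0, 2.0]), method="picard") converges to the fixed point of the cosine in every component.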

fp_relaxation: float = 0.9

Relaxation for the fixed-point method

fp_stopping_rtol: float = 1e-08

Relative tolerance of the fixed-point iteration; if \(\frac{\|x_k - G(x_k)\|_2}{\|x_k\|_2}\) is smaller than this value, the iteration terminates

property max_rank: int

max_rank_density: int = 20

Maximal TT-rank when representing the next distribution

max_rank_eta: int = 5

Maximal TT-rank when representing the entropic potential \(\eta\)

max_rank_hat_eta: int = 5

Maximal TT-rank when representing the entropic potential \(\hat\eta\)

sampling_n_euler_maruyama_steps: int = 50

Number of Euler-Maruyama steps used in the SDE solution

sampling_ode_atol: float = 1e-06

Absolute tolerance of the sampling ODE solver

sampling_ode_rtol: float = 0.001

Relative tolerance of the sampling ODE solver

sampling_sde_fraction: float = 0.005

Fraction of each timestep to be solved with the SDE dynamics
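A sketch of how the sampling parameters could combine within one timestep T, assuming the ODE is integrated over the first (1 - sampling_sde_fraction) part of the step with SciPy's adaptive solver and the remainder with Euler-Maruyama. The drift function and the scalar noise level sigma are hypothetical inputs; the dynamics actually fitted by the solver are not described here.

```python
import numpy as np
from scipy.integrate import solve_ivp


def propagate(x0, drift, sigma, T, sde_fraction=0.005, n_em_steps=50,
              ode_rtol=1e-3, ode_atol=1e-6, rng=None):
    """Hypothetical split integrator: ODE on [0, (1 - f) T], SDE on [(1 - f) T, T].

    drift(t, x) maps an array of shape (n_samples, dim) to the same shape.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x0, dtype=float)
    n, dim = x.shape
    t_switch = (1.0 - sde_fraction) * T

    if t_switch > 0:
        # deterministic part: one flattened ODE system for all samples
        def rhs(t, y):
            return drift(t, y.reshape(n, dim)).ravel()

        sol = solve_ivp(rhs, (0.0, t_switch), x.ravel(), rtol=ode_rtol, atol=ode_atol)
        x = sol.y[:, -1].reshape(n, dim)

    # stochastic part: Euler-Maruyama for dX = drift dt + sigma dW
    dt = (T - t_switch) / n_em_steps
    t = t_switch
    for _ in range(n_em_steps):
        x = x + drift(t, x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
        t += dt
    return x
```

Keeping the stochastic part to a small fraction of the step means most of the transport is done by the cheaper, adaptive ODE solver.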

property trunc_tol: int

trunc_tol_density: float = 1e-13

TT truncation tolerance when computing the next density

trunc_tol_eta: float = 1e-13

TT truncation tolerance for the potential \(\eta\) (i.e. in the terminal condition)

trunc_tol_hat_eta: float = 1e-13

TT truncation tolerance for the potential \(\hat\eta\) (i.e. in the initial condition)
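To illustrate how a rank cap (max_rank_*) and a truncation tolerance (trunc_tol_*) jointly limit TT-ranks, here is a minimal TT-SVD sketch on a full tensor. It is not how the solver builds its tensors (those come from TT-cross and TT arithmetic), and treating the tolerance as a relative Frobenius-norm bound split evenly over the d - 1 unfoldings is an assumption.

```python
import numpy as np


def tt_svd(tensor, max_rank, trunc_tol):
    """TT-SVD with a rank cap and a relative truncation tolerance (illustrative)."""
    shape = tensor.shape
    d = len(shape)
    # per-unfolding threshold so that the total relative error stays below trunc_tol
    delta = trunc_tol * np.linalg.norm(tensor) / np.sqrt(max(d - 1, 1))
    cores, rank, mat = [], 1, tensor
    for k in range(d - 1):
        mat = mat.reshape(rank * shape[k], -1)
        u, s, vt = np.linalg.svd(mat, full_matrices=False)
        # tail[i] = Frobenius norm of the singular values dropped when keeping i of them
        tail = np.sqrt(np.cumsum(s[::-1] ** 2))[::-1]
        new_rank = max(1, min(max_rank, int(np.count_nonzero(tail > delta))))
        cores.append(u[:, :new_rank].reshape(rank, shape[k], new_rank))
        mat = s[:new_rank, None] * vt[:new_rank]
        rank = new_rank
    cores.append(mat.reshape(rank, shape[-1], 1))
    return cores


def tt_full(cores):
    """Contract TT cores back into a full tensor."""
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=1)  # last axis of res with first of core
    return res[0, ..., 0]
```

Whichever limit is hit first decides the rank: a tight tolerance with a small cap yields ranks equal to the cap, while a loose tolerance can stop below it.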

zero_threshold: float = 1e-16

Safeguard for small values that must be positive but could drop below zero due to numerical errors
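A typical use of such a threshold is clipping before a logarithm or a division. Where exactly the solver applies zero_threshold is not stated here, so safe_log is a hypothetical helper.

```python
import numpy as np

zero_threshold = 1e-16


def safe_log(values, threshold=zero_threshold):
    """Clip values that should be positive (but may be slightly negative) before log."""
    return np.log(np.maximum(values, threshold))
```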