rJKOtt.TensorTrainSolver module
- class rJKOtt.TensorTrainSolver.TensorTrainSolver(rho_infty, rho_start, solver_params=None, posterior_cache_size=1000000)
Bases: object
Approximate a given distribution rho_infty with the entropy-regularized JKO scheme, using a finite-difference spatial discretization and the Tensor-Train (TT) decomposition.
- Parameters:
rho_infty (Callable) – function proportional to the probability density of the posterior. It should map an array of shape (N_samples, dim) to an array of shape (N_samples,).
rho_start (TensorTrainDistribution) – starting distribution, discretized on a grid. The posterior will be approximated on the same grid.
solver_params (TensorTrainSolverParams) – solver parameters; if not given, the defaults defined in the TensorTrainSolverParams class are used
posterior_cache_size (int) – maximal size of the cache of posterior evaluations
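As an illustration of a callable with the expected input and output shapes, here is a hypothetical target: an unnormalized mixture of two Gaussians in dim dimensions. It maps an array of shape (N_samples, dim) to one non-negative value per sample. The mixture and its mode locations are arbitrary choices for the example and are not part of the package.

```python
import numpy as np


def rho_infty(x: np.ndarray) -> np.ndarray:
    """Unnormalized density of an equal-weight mixture of two isotropic Gaussians.

    x has shape (N_samples, dim); the result has shape (N_samples,).
    """
    x = np.atleast_2d(x)
    dim = x.shape[1]
    mu = np.full(dim, 1.5)  # hypothetical mode locations at +mu and -mu
    sq1 = np.sum((x - mu) ** 2, axis=1)
    sq2 = np.sum((x + mu) ** 2, axis=1)
    # The normalization constant is omitted: only proportionality to the density is needed
    return np.exp(-0.5 * sq1) + np.exp(-0.5 * sq2)


values = rho_infty(np.zeros((4, 3)))
print(values.shape)  # (4,)
```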
- n_calls
tracks the number of actual posterior evaluations during the solve
- Type:
int
- n_cache
tracks the number of posterior calls served from the cache during the solve
- Type:
int
- Ts
timesteps taken
- Type:
List[float]
- betas
regularization factors used at the steps taken
- Type:
List[float]
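The counters n_calls and n_cache suggest that posterior evaluations are memoized, with posterior_cache_size bounding the cache. The sketch below shows one way such a wrapper could work. The class name CachedPosterior, the least-recently-used eviction policy, and keying on the raw bytes of each grid point are assumptions and do not describe the package's implementation.

```python
from collections import OrderedDict
from typing import Callable

import numpy as np


class CachedPosterior:
    """Hypothetical LRU-cached density that counts real evaluations vs cache hits."""

    def __init__(self, fn: Callable[[np.ndarray], np.ndarray], max_size: int = 1_000_000):
        self.fn = fn
        self.max_size = max_size
        self.cache = OrderedDict()  # point bytes -> density value
        self.n_calls = 0  # points actually passed to fn
        self.n_cache = 0  # points answered from the cache

    def __call__(self, x: np.ndarray) -> np.ndarray:
        x = np.ascontiguousarray(x, dtype=np.float64)
        out = np.empty(x.shape[0])
        missing = []
        for i, row in enumerate(x):
            key = row.tobytes()  # grid points repeat exactly, so raw bytes work as a key
            if key in self.cache:
                self.cache.move_to_end(key)  # mark as recently used
                out[i] = self.cache[key]
                self.n_cache += 1
            else:
                missing.append(i)
        if missing:
            vals = self.fn(x[missing])  # one batched call for all uncached points
            self.n_calls += len(missing)
            for i, v in zip(missing, vals):
                out[i] = v
                self.cache[x[i].tobytes()] = float(v)
                if len(self.cache) > self.max_size:
                    self.cache.popitem(last=False)  # evict the least recently used entry
        return out


f = CachedPosterior(lambda x: np.exp(-0.5 * np.sum(x**2, axis=1)), max_size=10)
pts = np.array([[0.0, 0.0], [1.0, 0.0]])
f(pts)
f(pts)
print(f.n_calls, f.n_cache)  # 2 2
```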
- sample(sample_x0)
Starting from a sample drawn from the initial distribution, propagate it through the fitted dynamics and return a sample from the distribution at the last step.
- Parameters:
sample_x0 (numpy.ndarray) – sample from the initial distribution, should have shape (n_samples, dim)
- Returns:
sample from the distribution at the last step, with shape (n_samples, dim)
- Return type:
numpy.ndarray
- step(beta, T, save_history=False)
Perform a regularized JKO step.
- Parameters:
beta (numpy.float64) – regularization factor
T (numpy.float64) – timestep
save_history (bool) – if True, also return the intermediate ranks, errors, etc. recorded during the fixed-point iterations
- Return type:
Tuple[List[numpy.ndarray], List[numpy.ndarray]]
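For orientation, a generic entropy-regularized JKO step with timestep T and regularization factor β can be written as follows. This is the standard textbook form, assuming β weights an entropic penalty on the transport plan. The exact scaling of β and T used by this package may differ.

```latex
\begin{aligned}
\rho_{k+1} &\in \operatorname*{arg\,min}_{\rho}\;
  \frac{1}{2T}\,\mathcal{W}_{\beta}(\rho_k, \rho)
  + \mathrm{KL}\big(\rho \,\big\|\, \rho_\infty / Z\big),
  \qquad Z = \int \rho_\infty(x)\,dx, \\
\mathcal{W}_{\beta}(\mu, \nu) &= \min_{\pi \in \Pi(\mu, \nu)}
  \int \|x - y\|_2^2 \, d\pi(x, y)
  + \beta\, \mathrm{KL}\big(\pi \,\big\|\, \mu \otimes \nu\big).
\end{aligned}
```

Here Π(μ, ν) is the set of couplings with marginals μ and ν. Since KL(ρ ‖ ρ_∞/Z) = ∫ ρ log(ρ/ρ_∞) dx + log Z for any probability density ρ, the normalization constant only shifts the objective. This is why rho_infty only needs to be proportional to the density.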
- class rJKOtt.TensorTrainSolver.TensorTrainSolverParams(cross_nfev_with_posterior=400000, cross_nfev_no_posterior=1000000, cross_rel_diff=1e-06, cross_use_validation=False, cross_validation_rtol=1e-06, cross_n_validation=1000, max_rank_eta=5, max_rank_hat_eta=5, max_rank_density=20, trunc_tol_hat_eta=1e-13, trunc_tol_eta=1e-13, trunc_tol_density=1e-13, fp_method='2_anderson', fp_relaxation=0.9, fp_stopping_rtol=1e-08, fp_max_iter=100, zero_threshold=1e-16, sampling_ode_rtol=0.001, sampling_ode_atol=1e-06, sampling_sde_fraction=0.005, sampling_n_euler_maruyama_steps=50)
Bases: object
Store the solver parameters, such as tolerances and iteration limits. Default values are provided.
These parameters can be changed between, or possibly even during, the fixed-point iterations.
- Parameters:
cross_nfev_with_posterior (int)
cross_nfev_no_posterior (int)
cross_rel_diff (float)
cross_use_validation (bool)
cross_validation_rtol (float)
cross_n_validation (int)
max_rank_eta (int)
max_rank_hat_eta (int)
max_rank_density (int)
trunc_tol_hat_eta (float)
trunc_tol_eta (float)
trunc_tol_density (float)
fp_method (Literal['picard', '2_anderson', 'aitken'])
fp_relaxation (float)
fp_stopping_rtol (float)
fp_max_iter (int)
zero_threshold (float)
sampling_ode_rtol (float)
sampling_ode_atol (float)
sampling_sde_fraction (float)
sampling_n_euler_maruyama_steps (int)
- cross_n_validation: int = 1000
Number of indices in the validation subset
- cross_nfev_no_posterior: int = 1000000
Number of function evaluations in TT-cross when real posterior calls are not required (initial condition and KL estimation)
- cross_nfev_with_posterior: int = 400000
Number of function evaluations in TT-cross when real posterior calls are required (terminal condition and KL)
- cross_rel_diff: float = 1e-06
Cross stopping criterion: if the relative change of the solution between cross iterations is smaller than this value, the cross approximation stops
- cross_use_validation: bool = False
Whether to use the error on the validation subset as a stopping criterion during the cross-approximation iterations
- cross_validation_rtol: float = 1e-06
Cross stopping criterion: if the relative error on the validation subset is smaller than this value, the cross approximation stops
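The two cross stopping criteria can be sketched on dense NumPy arrays standing in for TT tensors. Several details are assumptions for illustration: the function name cross_should_stop, the use of dense arrays, and normalizing by the current iterate. In TT format the norms would be computed without forming the full tensor.

```python
from typing import Optional, Tuple

import numpy as np


def cross_should_stop(
    prev: np.ndarray,
    curr: np.ndarray,
    rel_diff: float = 1e-6,
    validation: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    validation_rtol: float = 1e-6,
) -> bool:
    """Hypothetical stopping test between two consecutive cross iterates (dense stand-ins)."""
    tiny = np.finfo(float).tiny
    # cross_rel_diff: relative change of the approximation between iterations
    change = np.linalg.norm(curr - prev) / max(np.linalg.norm(curr), tiny)
    if change < rel_diff:
        return True
    # cross_use_validation / cross_validation_rtol: relative error on a fixed subset
    # of multi-indices (cross_n_validation of them) where exact values are known
    if validation is not None:
        idx, exact = validation
        approx = curr[tuple(idx.T)]
        err = np.linalg.norm(approx - exact) / max(np.linalg.norm(exact), tiny)
        if err < validation_rtol:
            return True
    return False


rng = np.random.default_rng(0)
target = rng.random((8, 8, 8))
idx = rng.integers(0, 8, size=(100, 3))  # 100 validation multi-indices
val = (idx, target[tuple(idx.T)])
print(cross_should_stop(target + 1e-3, target, validation=val))  # True
```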
- fp_max_iter: int = 100
Maximal number of fixed-point iterations; the iteration terminates after this many iterations regardless of convergence
- fp_method: Literal['picard', '2_anderson', 'aitken'] = '2_anderson'
Method of solving the fixed-point problem at a step.
picard : the simplest method; the next iterate is a linear combination of the current iterate and the operator value at the current iterate, with a fixed relaxation factor
aitken : same as picard, but the relaxation factor is selected adaptively; supposedly more robust
2_anderson : stores the 2 previous iterates and their operator values and computes the next iterate from a small minimization subproblem (Anderson acceleration)
- fp_relaxation: float = 0.9
Relaxation factor for the fixed-point method
- fp_stopping_rtol: float = 1e-08
Relative tolerance of the fixed-point iteration; if \(\frac{\|x_k - G(x_k)\|_2}{\|x_k\|_2}\) is smaller than this value, the iteration terminates
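The three fp_method options can be sketched on a generic operator G acting on NumPy vectors, together with fp_relaxation, fp_stopping_rtol and fp_max_iter. This is an illustration under several assumptions:

- the function fixed_point is hypothetical;
- 'aitken' uses the Irons–Tuck/Küttler–Wall form of Aitken's dynamic relaxation;
- '2_anderson' is read as Anderson acceleration over the current iterate and the two previous ones.

The package itself works on TT tensors rather than dense vectors.

```python
from typing import Callable, Tuple

import numpy as np


def fixed_point(
    G: Callable[[np.ndarray], np.ndarray],
    x0: np.ndarray,
    method: str = "2_anderson",
    relaxation: float = 0.9,
    rtol: float = 1e-8,
    max_iter: int = 100,
) -> Tuple[np.ndarray, int]:
    """Solve x = G(x); return the final iterate and the number of iterations performed."""
    x = np.asarray(x0, dtype=float)
    omega = relaxation  # current relaxation factor (adapted only by 'aitken')
    r_prev = None  # previous residual, used by 'aitken'
    x_hist, r_hist = [], []  # sliding window, used by '2_anderson'
    for k in range(max_iter):
        r = G(x) - x  # residual G(x_k) - x_k
        # Stopping rule mirroring fp_stopping_rtol: ||x_k - G(x_k)|| / ||x_k|| < rtol
        if np.linalg.norm(r) < rtol * max(np.linalg.norm(x), np.finfo(float).tiny):
            return x, k
        if method == "picard":
            x = x + relaxation * r  # equals (1 - w) x_k + w G(x_k) with fixed w
        elif method == "aitken":
            if r_prev is not None:
                dr = r - r_prev
                denom = dr @ dr
                if denom > 0:
                    # Aitken dynamic relaxation: w <- -w * (r_prev . dr) / ||dr||^2
                    omega = -omega * (r_prev @ dr) / denom
            r_prev = r
            x = x + omega * r
        elif method == "2_anderson":
            x_hist.append(x)
            r_hist.append(r)
            if len(x_hist) > 3:  # keep the current iterate and 2 previous ones
                x_hist.pop(0)
                r_hist.pop(0)
            if len(x_hist) == 1:
                x = x + relaxation * r  # plain relaxed step until a history exists
            else:
                dX = np.stack([b - a for a, b in zip(x_hist, x_hist[1:])], axis=1)
                dR = np.stack([b - a for a, b in zip(r_hist, r_hist[1:])], axis=1)
                # Minimization subproblem: min_gamma ||r - dR @ gamma||_2
                gamma, *_ = np.linalg.lstsq(dR, r, rcond=None)
                x = x + relaxation * r - (dX + relaxation * dR) @ gamma
        else:
            raise ValueError(f"unknown method {method!r}")
    return x, max_iter


A = np.diag([0.5, 0.3, 0.1])
b = np.ones(3)
x, n_iter = fixed_point(lambda v: A @ v + b, np.zeros(3), method="aitken", max_iter=500)
```

For a linear contraction G(x) = A x + b as in the usage lines, all three variants converge to the solution of (I - A) x = b. They differ in how many iterations they need.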
- property max_rank: int
- max_rank_density: int = 20
Maximal TT-rank when representing the next distribution
- max_rank_eta: int = 5
Maximal TT-rank when representing the entropic potential \(\eta\)
- max_rank_hat_eta: int = 5
Maximal TT-rank when representing the entropic potential \(\hat\eta\)
- sampling_n_euler_maruyama_steps: int = 50
Number of Euler-Maruyama steps used in the SDE solution
- sampling_ode_atol: float = 1e-06
Absolute tolerance of the sampling ODE solver
- sampling_ode_rtol: float = 0.001
Relative tolerance of the sampling ODE solver
- sampling_sde_fraction: float = 0.005
Fraction of each timestep to be integrated with the SDE dynamics
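The sampling_* parameters suggest that each timestep is integrated mostly as an ODE and partly as an SDE. The ODE part uses the given rtol and atol; the SDE part uses a fixed number of Euler-Maruyama steps. The sketch below shows one such split for a single timestep. Several pieces are assumptions and do not describe what sample() does internally:

- the velocity field and the noise level sigma;
- placing the SDE portion at the end of the interval;
- the function name propagate_one_step;
- using SciPy's solve_ivp.

```python
import numpy as np
from scipy.integrate import solve_ivp


def propagate_one_step(x, velocity, T, sigma, sde_fraction=0.005, n_em_steps=50,
                       ode_rtol=1e-3, ode_atol=1e-6, rng=None):
    """Hypothetical ODE-then-SDE propagation of samples x (shape (n, dim)) over [0, T]."""
    rng = np.random.default_rng() if rng is None else rng
    n, dim = x.shape
    t_switch = (1.0 - sde_fraction) * T
    if t_switch > 0:
        # Deterministic part: dx/dt = velocity(t, x) on [0, t_switch], all samples at once
        def rhs(t, y):
            return velocity(t, y.reshape(n, dim)).ravel()

        sol = solve_ivp(rhs, (0.0, t_switch), x.ravel(), rtol=ode_rtol, atol=ode_atol)
        x = sol.y[:, -1].reshape(n, dim)
    # Stochastic part: Euler-Maruyama for dx = velocity dt + sigma dW on [t_switch, T]
    dt = (T - t_switch) / n_em_steps
    t = t_switch
    for _ in range(n_em_steps):
        x = x + velocity(t, x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
        t += dt
    return x


rng = np.random.default_rng(0)
x0 = rng.standard_normal((1000, 2))
# Ornstein-Uhlenbeck drift as a stand-in for the fitted dynamics
xT = propagate_one_step(x0, lambda t, x: -x, T=1.0, sigma=np.sqrt(2.0), rng=rng)
```

Keeping only a small fraction of the step stochastic leaves most of the work to an adaptive ODE solver. The short SDE segment then adds noise to the otherwise deterministic transport.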
- property trunc_tol: int
- trunc_tol_density: float = 1e-13
TT truncation tolerance when computing the next density
- trunc_tol_eta: float = 1e-13
TT truncation tolerance for potential \(\eta\) (i.e. in the terminal condition)
- trunc_tol_hat_eta: float = 1e-13
TT truncation tolerance for potential \(\hat\eta\) (i.e. in the initial condition)
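The max_rank_* and trunc_tol_* pairs correspond to the two usual ways of limiting TT ranks: a hard cap and a relative accuracy. The sketch below is the textbook TT-SVD algorithm with both limits applied. It is not necessarily how this package truncates its tensors, which are built by TT-cross. Splitting the tolerance evenly over the d - 1 unfoldings is the standard textbook choice. The helper names tt_svd and tt_full are hypothetical.

```python
from typing import List

import numpy as np


def tt_svd(tensor: np.ndarray, tol: float = 1e-13, max_rank: int = 20) -> List[np.ndarray]:
    """Textbook TT-SVD; returns cores of shape (r_{k-1}, n_k, r_k) with r_0 = r_d = 1."""
    shape = tensor.shape
    d = len(shape)
    # Split the relative tolerance evenly over the d - 1 unfoldings
    delta = tol / np.sqrt(max(d - 1, 1)) * np.linalg.norm(tensor)
    cores = []
    r_prev = 1
    c = tensor
    for k in range(d - 1):
        c = c.reshape(r_prev * shape[k], -1)
        u, s, vt = np.linalg.svd(c, full_matrices=False)
        # tail[i] = norm of s[i:], i.e. the truncation error if rank i is kept
        tail = np.sqrt(np.cumsum(s[::-1] ** 2))[::-1]
        rank = max(1, int(np.sum(tail > delta)))  # smallest rank meeting the tolerance
        rank = min(rank, max_rank)  # hard cap, like max_rank_*
        cores.append(u[:, :rank].reshape(r_prev, shape[k], rank))
        c = s[:rank, None] * vt[:rank]  # carry the remainder to the next unfolding
        r_prev = rank
    cores.append(c.reshape(r_prev, shape[-1], 1))
    return cores


def tt_full(cores: List[np.ndarray]) -> np.ndarray:
    """Contract TT cores back into a dense tensor (only sensible for small examples)."""
    out = cores[0]
    for g in cores[1:]:
        out = np.tensordot(out, g, axes=([-1], [0]))
    return out.reshape(out.shape[1:-1])


x = np.linspace(0.0, 1.0, 10)
A = np.sin(x[:, None, None] + x[None, :, None] + x[None, None, :])  # exact TT ranks (2, 2)
cores = tt_svd(A, tol=1e-10, max_rank=20)
print([g.shape[2] for g in cores[:-1]])  # [2, 2]
```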
- zero_threshold: float = 1e-16
Safeguard for quantities that must be positive but may fall below zero due to numerical errors
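A minimal sketch of such a safeguard follows, assuming it clips values from below before a logarithm is taken. This matters for TT-approximated densities, whose entries can come out slightly negative after truncation. The function name safe_log is hypothetical.

```python
import numpy as np


def safe_log(values: np.ndarray, zero_threshold: float = 1e-16) -> np.ndarray:
    # Clip tiny or slightly negative values (numerical noise) up to the threshold,
    # so that the logarithm stays finite
    return np.log(np.maximum(values, zero_threshold))


print(safe_log(np.array([-1e-18, 1.0])))
```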