arviz_stats.kl_divergence
arviz_stats.kl_divergence(data1, data2, group='posterior', var_names=None, sample_dims=None, num_samples=500, round_to=None, random_seed=212480)
Compute the Kullback-Leibler (KL) divergence.
The KL-divergence is a measure of how different two probability distributions are. It quantifies how much extra uncertainty we introduce when we use one distribution to approximate another. The KL-divergence is not symmetric, so changing the order of the data1 and data2 arguments will change the result.
For details of the approximation used to compute the KL-divergence, see [1].
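For intuition, the following is a minimal sketch of the k = 1 nearest-neighbour estimator described in [1], written with NumPy and SciPy. The helper name knn_kl_estimate is hypothetical and this is not the arviz_stats implementation; it only illustrates the idea of comparing nearest-neighbour distances within and across the two sample sets.

import numpy as np
from scipy.spatial import cKDTree

def knn_kl_estimate(x, y):
    # Hypothetical illustration of the 1-NN estimator from [1],
    # not the arviz_stats implementation.
    # x ~ p and y ~ q, as (n, d) and (m, d) arrays of samples.
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    n, d = x.shape
    m = len(y)
    # rho: distance from each x_i to its nearest *other* point in x
    rho = cKDTree(x).query(x, k=2)[0][:, 1]
    # nu: distance from each x_i to its nearest point in y
    nu = cKDTree(y).query(x, k=1)[0]
    return d * np.mean(np.log(nu / rho)) + np.log(m / (n - 1))

For two unit-variance Gaussians whose means differ by one, the analytic KL-divergence is 0.5, and the estimate should land near that value:

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(2000, 1))  # samples from p
y = rng.normal(1.0, 1.0, size=(2000, 1))  # samples from q
knn_kl_estimate(x, y)  # roughly 0.5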
- Parameters:
- data1, data2 : xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData
- group : hashable, default "posterior"
  Group on which to compute the KL-divergence.
- var_names : str or list of str, optional
  Names of the variables for which the KL-divergence should be computed.
- sample_dims : iterable of hashable, optional
  Dimensions to be treated as sample dimensions and reduced. Defaults to rcParams["data.sample_dims"].
- num_samples : int
  Number of samples to use for the distance calculation. Defaults to 500.
- round_to : int or str or None, optional
  If an integer, the number of decimal places to round the result; integers can be negative. If a string of the form '2g', the number of significant digits. Defaults to rcParams["stats.round_to"] if None. Use the string "None" or "none" to return raw numbers.
- random_seed : int
  Random seed for reproducibility. Use None for no seed.
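As a plain-Python illustration of the round_to conventions above (the exact formatting applied internally is an assumption), an integer rounds to decimal places while an 'Ng' string keeps N significant digits:

value = 1.13782
round(value, 2)   # 1.14, two decimal places
f"{value:.2g}"    # '1.1', two significant digits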
- Returns:
- KL-divergence : float
References
[1] F. Perez-Cruz, "Kullback-Leibler divergence estimation of continuous distributions", IEEE International Symposium on Information Theory (2008). https://doi.org/10.1109/ISIT.2008.4595271. Preprint: https://www.tsc.uc3m.es/~fernando/bare_conf3.pdf
Examples
Calculate the KL-divergence between the posterior distributions of the variable mu in the centered and non-centered eight schools models:

In [1]: from arviz_stats import kl_divergence
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: kl_divergence(data1, data2, var_names="mu")
Out[1]: 1.1
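Because the KL-divergence is not symmetric, reversing the arguments generally gives a different number. The call below is a hypothetical illustration and its output is omitted:

In [2]: kl_divergence(data2, data1, var_names="mu")  # generally differs from kl_divergence(data1, data2, ...)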