arviz_stats.wasserstein
- arviz_stats.wasserstein(data1, data2, group='posterior', var_names=None, sample_dims=None, joint=True, num_samples=500, round_to=None, random_seed=212480)
Compute the Wasserstein-1 distance.
The Wasserstein distance, also called the Earth mover’s distance or the optimal transport distance, measures how far apart two probability distributions are [1]. It is zero when the distributions coincide and grows as they differ.
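To make the definition concrete, here is a minimal sketch of the empirical Wasserstein-1 distance between two one-dimensional samples of equal size. It is an illustration only, not the implementation used by this function; the helper name `wasserstein_1d` is hypothetical. With equal sizes and uniform weights, the optimal transport plan pairs the sorted samples in order, so the distance reduces to the mean absolute difference between the sorted values.

```python
import random


def wasserstein_1d(xs, ys):
    """Empirical Wasserstein-1 distance for two equal-size 1-D samples.

    Hypothetical helper for illustration; assumes uniform weights.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # Pair the i-th smallest value of xs with the i-th smallest value of ys
    # and average the distance each unit of mass has to travel.
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


rng = random.Random(0)
base = [rng.gauss(0.0, 1.0) for _ in range(500)]
shifted = [x + 2.0 for x in base]  # the same sample, moved by 2.0

same = wasserstein_1d(base, base)      # identical samples
moved = wasserstein_1d(base, shifted)  # every point travels a distance of 2.0
print(same, moved)
```

Shifting a sample by a constant moves every unit of mass by that constant, so the distance between `base` and `shifted` is the size of the shift (up to floating-point error).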
- Parameters:
- data1, data2
xarray.DataArray, xarray.Dataset, xarray.DataTree, or InferenceData
- group
hashable, default "posterior" Group on which to compute the Wasserstein distance.
- var_names
str or list of str, optional Names of the variables for which the Wasserstein distance should be computed.
- sample_dims
iterable of hashable, optional Dimensions to be treated as sample dimensions and reduced. Defaults to rcParams["data.sample_dims"].
- joint
bool, default True Whether to compute the Wasserstein distance for the joint distribution (True) or over the marginals (False).
- num_samples
int, default 500 Number of samples to use for the distance calculation.
- round_to
int or str or None, optional If an integer, the number of decimal places to round the result to; integers can be negative. If a string of the form "2g", the number of significant digits to round the result to. Defaults to rcParams["stats.round_to"] if None. Use the string "None" or "none" to return raw numbers.
- random_seed
int Random seed for reproducibility. Use None for no seed.
- Returns:
- wasserstein_distance
float
Notes
The computation is faster for the marginals (joint=False). Using the marginals is equivalent to assuming that they are independent, which is usually not the case. This function uses
scipy.stats.wasserstein_distance for the marginals and scipy.stats.wasserstein_distance_nd for the joint distribution.
References
[1] “Wasserstein metric”, https://en.wikipedia.org/wiki/Wasserstein_metric
Examples
Calculate the Wasserstein distance between the posterior distributions for the variable mu in the centered and non-centered eight schools models:
In [1]: from arviz_stats import wasserstein
   ...: from arviz_base import load_arviz_data
   ...: data1 = load_arviz_data('centered_eight')
   ...: data2 = load_arviz_data('non_centered_eight')
   ...: wasserstein(data1, data2, var_names="mu")
Out[1]: 0.29