Skip to content

Commit

Permalink
Improve comments and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ericsuh committed Mar 15, 2020
1 parent d18cb03 commit c30ae98
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 59 deletions.
221 changes: 174 additions & 47 deletions dirichlet/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,24 @@
euler = -1 * psi(1) # Euler-Mascheroni constant


class NotConvergingError(Exception):
    """Raised when a successive approximation method fails to converge."""


def test(D1, D2, method="meanprecision", maxiter=None):
"""Test for statistical difference between observed proportions.
Parameters
----------
D1 : array
D2 : array
Both ``D1`` and ``D2`` must have the same number of columns, which are
the different levels or categorical possibilities. Each row of the
matrices must add up to 1.
D1 : (N1, K) shape array
D2 : (N2, K) shape array
Input observations. ``N1`` and ``N2`` are the number of observations,
and ``K`` is the number of parameters for the Dirichlet distribution
(i.e. the number of levels or categorical possibilities).
Each cell is the proportion seen in that category for a particular
observation. Rows of the matrices must add up to 1.
method : string
One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by
which to find MLE Dirichlet distribution. Default is
Expand All @@ -70,16 +78,16 @@ def test(D1, D2, method="meanprecision", maxiter=None):
Test statistic, which is ``-2 * log`` of likelihood ratios.
p : float
p-value of test.
a0 : array
a1 : array
a2 : array
a0 : (K,) shape array
a1 : (K,) shape array
a2 : (K,) shape array
MLE parameters for the Dirichlet distributions fit to
``D1`` and ``D2`` together, ``D1``, and ``D2``, respectively."""

N1, K1 = D1.shape
N2, K2 = D2.shape
if K1 != K2:
raise Exception("D1 and D2 must have the same number of columns")
raise ValueError("D1 and D2 must have the same number of columns")

D0 = vstack((D1, D2))
a0 = mle(D0, method=method, maxiter=maxiter)
Expand All @@ -91,29 +99,52 @@ def test(D1, D2, method="meanprecision", maxiter=None):


def pdf(alphas):
    """Return a Dirichlet PDF function.

    Parameters
    ----------
    alphas : (K,) shape array-like
        The parameters for the distribution.

    Returns
    -------
    function
        The PDF function. Takes an ``(N, K)`` shape input and gives an
        ``(N,)`` shape output. A single ``(K,)`` shape observation is also
        accepted and gives a scalar.
    """
    # Coerce so that plain lists/tuples work, not only ndarrays.
    alphas = np.asarray(alphas, dtype=float)
    alphap = alphas - 1
    # Normalizing constant Gamma(sum(a)) / prod(Gamma(a)), computed from
    # log-gamma values and exponentiated once.
    c = np.exp(gammaln(alphas.sum()) - gammaln(alphas).sum())

    def dirichlet(xs):
        """Evaluate the Dirichlet PDF.

        Parameters
        ----------
        xs : (N, K) or (K,) shape array-like
            Observations; the last axis indexes the ``K`` categories.

        Returns
        -------
        (N,) shape array or scalar
            Point value of the PDF for each observation.
        """
        xs = np.asarray(xs, dtype=float)
        # Reduce over the category axis; identical to axis=1 for (N, K) input.
        return c * (xs ** alphap).prod(axis=-1)

    return dirichlet


def meanprecision(a):
"""Mean and precision of Dirichlet distribution.
"""Mean and precision of a Dirichlet distribution.
Parameters
----------
a : array
Parameters of Dirichlet distribution.
a : (K,) shape array
Parameters of a Dirichlet distribution.
Returns
-------
mean : array
Numbers [0,1] of the means of the Dirichlet distribution.
mean : (K,) shape array
Means of the Dirichlet distribution. Values are in [0,1].
precision : float
Precision or concentration parameter of the Dirichlet distribution."""

Expand All @@ -127,10 +158,10 @@ def loglikelihood(D, a):
Parameters
----------
D : 2D array
where ``N`` is the number of observations, ``K`` is the number of
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
a : array
a : (K,) shape array
Parameters for the Dirichlet distribution.
Returns
Expand All @@ -148,10 +179,9 @@ def mle(D, tol=1e-7, method="meanprecision", maxiter=None):
Parameters
----------
D : 2D array
``N x K`` array of numbers from [0,1] where ``N`` is the number of
observations, ``K`` is the number of parameters for the Dirichlet
distribution.
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
Expand All @@ -165,7 +195,7 @@ def mle(D, tol=1e-7, method="meanprecision", maxiter=None):
Returns
-------
a : array
a : (K,) shape array
Maximum likelihood parameters for Dirichlet distribution."""

if method == "meanprecision":
Expand All @@ -175,8 +205,24 @@ def mle(D, tol=1e-7, method="meanprecision", maxiter=None):


def _fixedpoint(D, tol=1e-7, maxiter=None):
"""Simple fixed point iteration method for MLE of Dirichlet distribution"""
N, K = D.shape
"""Simple fixed point iteration method for MLE of Dirichlet distribution
Parameters
----------
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
maxiter : int
Maximum number of iterations to take calculations. Default is
``sys.maxint``.
Returns
-------
a : (K,) shape array
Fixed-point estimated parameters for Dirichlet distribution."""
logp = log(D).mean(axis=0)
a0 = _init_a(D)

Expand All @@ -185,19 +231,37 @@ def _fixedpoint(D, tol=1e-7, maxiter=None):
maxiter = MAXINT
for i in range(maxiter):
a1 = _ipsi(psi(a0.sum()) + logp)
# if norm(a1-a0) < tol:
if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol: # much faster
# Much faster convergence than with the more obvious condition
# `norm(a1-a0) < tol`
if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol:
return a1
a0 = a1
raise Exception(
raise NotConvergingError(
"Failed to converge after {} iterations, values are {}.".format(maxiter, a1)
)


def _meanprecision(D, tol=1e-7, maxiter=None):
"""Mean and precision alternating method for MLE of Dirichlet
distribution"""
N, K = D.shape
"""Mean/precision method for MLE of Dirichlet distribution
Uses alternating estimations of mean and precision.
Parameters
----------
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
maxiter : int
Maximum number of iterations to take calculations. Default is
``sys.maxint``.
Returns
-------
a : (K,) shape array
Estimated parameters for Dirichlet distribution."""
logp = log(D).mean(axis=0)
a0 = _init_a(D)
s0 = a0.sum()
Expand All @@ -217,19 +281,38 @@ def _meanprecision(D, tol=1e-7, maxiter=None):
s1 = sum(a1)
a1 = _fit_m(D, a1, logp, tol=tol)
m = a1 / s1
# if norm(a1-a0) < tol:
if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol: # much faster
# Much faster convergence than with the more obvious condition
# `norm(a1-a0) < tol`
if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol:
return a1
a0 = a1
raise Exception(
raise NotConvergingError(
f"Failed to converge after {maxiter} iterations, " f"values are {a1}."
)


def _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):
"""Assuming a fixed mean for Dirichlet distribution, maximize likelihood
for preicision a.k.a. s"""
N, K = D.shape
"""Update parameters via MLE of precision with fixed mean
Parameters
----------
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
a0 : (K,) shape array
Current parameters for Dirichlet distribution
logp : (K,) shape array
Mean of log-transformed D across N observations
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
maxiter : int
Maximum number of iterations to take calculations. Default is 1000.
Returns
-------
(K,) shape array
Updated parameters for Dirichlet distribution."""
s1 = a0.sum()
m = a0 / s1
mlogp = (m * logp).sum()
Expand All @@ -247,20 +330,38 @@ def _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):
if s1 <= 0:
s1 = s0 - g / h # Newton
if s1 <= 0:
raise Exception(f"Unable to update s from {s0}")
raise NotConvergingError(f"Unable to update s from {s0}")

a = s1 * m
if abs(s1 - s0) < tol:
return a

raise Exception(f"Failed to converge after {maxiter} iterations, " f"s is {s1}")
raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"s is {s1}")


def _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):
"""With fixed precision s, maximize mean m"""
N, K = D.shape
s = a0.sum()
"""Update parameters via MLE of mean with fixed precision s
Parameters
----------
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
a0 : (K,) shape array
Current parameters for Dirichlet distribution
logp : (K,) shape array
Mean of log-transformed D across N observations
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
maxiter : int
Maximum number of iterations to take calculations. Default is 1000.
Returns
-------
(K,) shape array
Updated parameters for Dirichlet distribution."""
s = a0.sum()
for i in range(maxiter):
m = a0 / s
a1 = _ipsi(logp + (m * (psi(a0) - logp)).sum())
Expand All @@ -270,11 +371,22 @@ def _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):
return a1
a0 = a1

raise Exception(f"Failed to converge after {maxiter} iterations, " f"s is {s}")
raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"s is {s}")


def _init_a(D):
    """Initial guess for Dirichlet alpha parameters given data D.

    The guess is the per-category sample mean scaled by a single factor
    computed from the first and second moments of the first column only.

    Parameters
    ----------
    D : (N, K) shape array-like
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.

    Returns
    -------
    (K,) shape array
        Crude guess for parameters of Dirichlet distribution.
    """
    # Coerce so that nested lists work as well as ndarrays.
    D = np.asarray(D, dtype=float)
    E = D.mean(axis=0)
    E2 = (D ** 2).mean(axis=0)
    # Scale factor from column 0: (E[x] - E[x^2]) / (E[x^2] - E[x]^2).
    # NOTE(review): the denominator is zero when column 0 is constant across
    # observations; the result is then non-finite.
    return ((E[0] - E2[0]) / (E2[0] - E[0] ** 2)) * E
Expand All @@ -283,7 +395,22 @@ def _init_a(D):
def _ipsi(y, tol=1.48e-9, maxiter=10):
"""Inverse of psi (digamma) using Newton's method. For the purposes
of Dirichlet MLE, since the parameters a[i] must always
satisfy a > 0, we define ipsi :: R -> (0,inf)."""
satisfy a > 0, we define ipsi :: R -> (0,inf).
Parameters
----------
y : (K,) shape array
y-values of psi(x)
tol : float
If Euclidean distance between successive parameter arrays is less than
``tol``, calculation is taken to have converged.
maxiter : int
Maximum number of iterations to take calculations. Default is 10.
Returns
-------
(K,) shape array
Approximate x for psi(x)."""
y = asanyarray(y, dtype="float")
x0 = np.piecewise(
y,
Expand All @@ -295,7 +422,7 @@ def _ipsi(y, tol=1.48e-9, maxiter=10):
if norm(x1 - x0) < tol:
return x1
x0 = x1
raise Exception(f"Failed to converge after {maxiter} iterations, " f"value is {x1}")
raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"value is {x1}")


def _trigamma(x):
Expand Down
Loading

0 comments on commit c30ae98

Please sign in to comment.