Welford algorithm for updating variance
I came across this little algorithm by chance and was stunned by its simple form. It took me a whole page to derive it from end to end, and it feels good that I can still do something like that!
The problem setting is pretty simple: we want to calculate the mean and variance of a numeric list, with length \(N\). Really nothing fancy, just straight-up definitions:
\[\begin{eqnarray*} \mu_N &=& \frac{1}{N}\sum_{i=1}^{N} x_i~, \\ \sigma^2_N &=& \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu_N)^2~. \end{eqnarray*}\]Let me treat the variance as the population variance, hence the denominator of \(N\) instead of \(N-1\). In terms of time complexity, each calculation is \(O(N)\), since we have to touch every element of the list.
Now a new element, \(x_{N+1}\), is added to the list, and we want to update the mean and variance of this new list with length \(N+1\). Naively, we could just repeat the two \(O(N)\) calculations. But can we do better if we already know \(N, \mu_N, \sigma^2_N\)? For large \(N\), a full recalculation seems like a lot of effort for just one new element. Is it needed?
The algorithm
Yes we can. Let’s start from the new mean, \(\mu_{N+1}\). This is rather straightforward:
\[\begin{eqnarray} (N+1)\mu_{N+1} &=& \sum_{i=1}^{N+1} x_i\\ &=& \sum_{i=1}^{N} x_i + x_{N+1}\\ &=& N\mu_{N} + x_{N+1}\\ \mu_{N+1} &=& \mu_N + \frac{1}{N+1}(x_{N+1} - \mu_N)~. \tag{1}\label{new_mu} \end{eqnarray}\]By doing so, there is no need to go through the whole list to update the mean; a few arithmetic operations are enough. This is an \(O(1)\) calculation.
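As a minimal sketch of Equation \(\eqref{new_mu}\) in Python (the function name `update_mean` and the sample data are my own choices for illustration, not part of the original derivation), the mean can be updated one element at a time and compared against the direct \(O(N)\) definition:

```python
def update_mean(n, mean, x_new):
    """Return (n + 1, updated mean) given the mean of the first n values.

    Implements Equation (1): mu_{N+1} = mu_N + (x_{N+1} - mu_N) / (N + 1).
    """
    n_new = n + 1
    return n_new, mean + (x_new - mean) / n_new


# Illustrative data; start from an empty list (n = 0, mean = 0).
data = [2.0, 4.0, 4.0, 5.0, 7.0]
n, mean = 0, 0.0
for x in data:
    n, mean = update_mean(n, mean, x)

# The running mean should agree with the direct definition up to rounding.
print(mean, sum(data) / len(data))
```

Each call does a constant amount of work, regardless of how many elements have already been seen.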
To get \(\sigma^2_{N+1}\) is slightly more involved. First, we recall the alternative form of variance as
\[\begin{eqnarray} \sigma^2_N &=& \frac{1}{N} \sum_{i=1}^{N} x_i^2 - \mu_N^2~.\tag{2}\label{var_2} \end{eqnarray}\]From here, we can write the corresponding forms for lists of length \(N\) and \(N+1\):
\[\begin{eqnarray} N\sigma^2_N &=& \sum_{i=1}^{N} x_i^2 - N\mu_N^2~,\\ (N+1)\sigma^2_{N+1} &=& \sum_{i=1}^{N+1} x_i^2 - (N+1)\mu_{N+1}^2~\\ &=& \sum_{i=1}^{N} x_i^2 + x_{N+1}^2 - (N+1)\mu_{N+1}^2~. \tag{3}\label{new_sig} \end{eqnarray}\]Note that, in Equation \(\eqref{new_sig}\), there is no need to recompute the summation term: its value can be retrieved from the known \(N, \mu_N, \sigma^2_N\) as \(N(\sigma^2_N + \mu_N^2)\). Therefore, all three terms on the right-hand side of Equation \(\eqref{new_sig}\) can be calculated in \(O(1)\) time, and hence so can \(\sigma^2_{N+1}\).
Equations \(\eqref{new_mu}\) and \(\eqref{new_sig}\) are the essence of the Welford algorithm. However, do not use Equation \(\eqref{new_sig}\) directly in practice (see below for the numerical stability issue). Instead, one should keep going and express Equation \(\eqref{new_sig}\) in a more compact, recurrence-like form. This requires some patience, so below I will just show the key steps:
\[\begin{eqnarray} (N+1)\sigma^2_{N+1} &=& \sum_{i=1}^{N} x_i^2 + x_{N+1}^2 - (N+1)\mu_{N+1}^2~,\\ &=& \color{red}{N\sigma^2_N + N \mu_N^2} + x_{N+1}^2 - (N+1)\mu_{N+1}^2~,\\ &=& N\sigma^2_N + N \mu_N^2 + x_{N+1}^2 - \color{red}{\frac{1}{N+1}(N\mu_{N} + x_{N+1})^2}\\ &=& N\sigma^2_N + N \mu_N^2 + x_{N+1}^2 - \color{red}{\frac{1}{N+1}(N^2\mu_{N}^2 + 2N\mu_{N}x_{N+1} + x_{N+1}^2)}. \end{eqnarray}\]Multiplying both sides by \(N+1\) and expanding, we then have:
\[\begin{eqnarray} (N+1)^2 \sigma_{N+1}^2 &=& N(N+1) \sigma_{N}^2 + N(\mu_N^2 - 2\mu_{N}x_{N+1} + x_{N+1}^2) \\ &=& N(N+1) \sigma_{N}^2 + N(\mu_{N} - x_{N+1})^2. \end{eqnarray}\]Now we can finally express \(\sigma_{N+1}^2\) as a cleaner recurrence relation:
\[\begin{eqnarray} \sigma_{N+1}^2 &=& \sigma_{N}^2 + \frac {N(\mu_{N} - x_{N+1})^2 - (N+1)\sigma_N^2} {(N+1)^2} \tag{4.1}\label{final_sig_1} \end{eqnarray}\]Equation \(\eqref{final_sig_1}\) is already rather compact. However, there is another equivalent expression, which we obtain by carrying out some further arithmetic on the second term on the right-hand side of Equation \(\eqref{final_sig_1}\):
\[\begin{eqnarray} &~&\frac {N(\mu_{N} - x_{N+1})^2 - (N+1)\sigma_N^2} {(N+1)^2} \\ &=& \frac {(x_{N+1} - \mu_N)N(x_{N+1} - \mu_N) - (N+1)\sigma_N^2} {(N+1)^2}\\ &=& \frac {(x_{N+1} - \mu_N)N\color{red}{(N+1)(\mu_{N+1} - \mu_N)} - (N+1)\sigma_N^2} {(N+1)^2}\\ &=& \frac {(x_{N+1} - \mu_N)N(\mu_{N+1} - \mu_N) - \sigma_N^2} {(N+1)} \\ &=& \frac {(x_{N+1} - \mu_N)\color{red}{(x_{N+1} - \mu_{N+1})} - \sigma_N^2} {(N+1)}.\\ \end{eqnarray}\]The transformations rely on two rearrangements of Equation \(\eqref{new_mu}\), namely \(x_{N+1} - \mu_N = (N+1)(\mu_{N+1} - \mu_N)\) and \(x_{N+1} - \mu_{N+1} = N(\mu_{N+1} - \mu_N)\), and we have:
\[\begin{eqnarray} \sigma_{N+1}^2 &=& \sigma_N^2 + \frac {(x_{N+1} - \mu_N)(x_{N+1} - \mu_{N+1}) - \sigma_N^2} {(N+1)}. \tag{4.2}\label{final_sig_2} \end{eqnarray}\]Both Equations \(\eqref{final_sig_1}\) and \(\eqref{final_sig_2}\) can be used equivalently.
One, of course, needs proof to be convinced. You can find the tests for correctness and speed.
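For a quick sanity check, here is a minimal sketch (not the tests mentioned above) that implements both Equation \(\eqref{final_sig_1}\) and Equation \(\eqref{final_sig_2}\) and compares them with Python's `statistics.pvariance`. The function names, the random seed, and the test data are assumptions made purely for illustration:

```python
import random
import statistics


def welford_update_41(n, mean, var, x_new):
    """One update step using Equation (4.1); returns (n + 1, new mean, new variance)."""
    n_new = n + 1
    mean_new = mean + (x_new - mean) / n_new
    var_new = var + (n * (mean - x_new) ** 2 - n_new * var) / n_new ** 2
    return n_new, mean_new, var_new


def welford_update_42(n, mean, var, x_new):
    """One update step using Equation (4.2); returns (n + 1, new mean, new variance)."""
    n_new = n + 1
    mean_new = mean + (x_new - mean) / n_new
    var_new = var + ((x_new - mean) * (x_new - mean_new) - var) / n_new
    return n_new, mean_new, var_new


random.seed(0)  # arbitrary seed so the run is repeatable
data = [random.gauss(10.0, 3.0) for _ in range(1000)]

state_41 = (0, 0.0, 0.0)  # (N, mean, population variance) for an empty list
state_42 = (0, 0.0, 0.0)
for x in data:
    state_41 = welford_update_41(*state_41, x)
    state_42 = welford_update_42(*state_42, x)

# All three values should agree up to floating-point rounding.
print(state_41[2], state_42[2], statistics.pvariance(data))
```

Both update functions only touch the three stored numbers and the new element, so each step is \(O(1)\).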
Numerical stability
Other than the \(O(N)\) to \(O(1)\) speed gain, the Welford algorithm provides an added benefit: numerical stability. Since real numbers are represented in computers using floating-point arithmetic with a finite number of bits, some precision is inevitably lost. In almost all cases this is not a concern. However, if one subtracts a very small number from a very large number, or performs subtraction between two very close numbers, bad things can happen due to this rounding issue. This is known as loss of significance, or catastrophic cancellation.
How can this catastrophic cancellation affect the variance calculation? It stems from Equation \(\eqref{var_2}\), in which one calculates the variance by subtracting the squared mean from the mean of the squares. This form is very appealing: if one knows \(\mu_{N+1}\) from an \(O(1)\) operation, i.e., Equation \(\eqref{new_mu}\), and keeps track of the sum of squares \(\sum x_i^2\), then \(\sigma_{N+1}^2\) can be easily calculated from Equation \(\eqref{new_sig}\) in \(O(1)\) time.
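That is exactly when the catastrophic cancellation hits: if \(\sum x_i^2\) and \(N \mu_N^2\) are both very large but nearly equal (that is, the variance is tiny compared with the squared mean), their difference loses most of its significant digits, weird things can happen, and one might even get a negative \(\sigma^2\)! See here for a simple illustration.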
Using the Welford algorithm (Equations \(\eqref{final_sig_1}\) and \(\eqref{final_sig_2}\)) can prevent such numerical instability, since the two terms on either side of the subtraction are of the same order by design (both are on the scale of the variance itself rather than of the much larger raw sums), while we still enjoy the \(O(1)\) time!
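To make the contrast concrete, here is a small sketch comparing the sum-of-squares shortcut with the update in Equation \(\eqref{final_sig_2}\). The data set is a hypothetical one picked for illustration: four values sitting on a large offset of \(10^9\), whose true population variance is 22.5 (the same as that of 4, 7, 13, 16). The function names are my own.

```python
def naive_variance(data):
    """Population variance via the running sum and sum of squares (Equation 2)."""
    n, total, total_sq = 0, 0.0, 0.0
    for x in data:
        n += 1
        total += x
        total_sq += x * x
    mean = total / n
    # Subtraction of two huge, nearly equal numbers: prone to cancellation.
    return total_sq / n - mean * mean


def welford_variance(data):
    """Population variance via the recurrence in Equation (4.2)."""
    n, mean, var = 0, 0.0, 0.0
    for x in data:
        n += 1
        delta = x - mean          # x_{N+1} - mu_N
        mean += delta / n         # Equation (1)
        var += (delta * (x - mean) - var) / n
    return var


# Small spread on top of a large offset; true population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]
naive = naive_variance(data)
stable = welford_variance(data)
print("naive:  ", naive)
print("welford:", stable)
```

With IEEE 754 double precision, the squares here are around \(10^{18}\), where adjacent representable numbers are more than 100 apart, so the naive result cannot resolve a variance of 22.5 and may well come out negative. The Welford update only ever works with small deviations and recovers 22.5.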
How about median?
Other than the mean and variance, another important descriptive statistic is the median. A natural question is: can we update it in an online fashion, with small time complexity as well? On the face of it, this seems to be an impossible task: a new element added to the list can change the order of its elements, and a generic sorting algorithm runs in \(O(N\log{N})\) time.
However, we don’t want to sort the whole list; we just want to know, and keep track of, the median. Here the heap data structure comes to the rescue. One can simply use two heaps to store the list: a max-heap for the smaller half, and a min-heap for the larger half. The median of the list will then be at the top of either heap, which can be queried in \(O(1)\) time. When a new element is added to the list, one just needs to insert it into either the smaller half or the larger half, in \(O(\log{N})\) time. There might be some rebalancing of the two heaps needed, but you get the idea. Overall, one can update the median in \(O(\log{N})\) time. Not too shabby.
Below is the median update implementation in Python:

from heapq import heappush, heappop


class RunningMedian:
    def __init__(self):
        # running_median[i] is the median of the first i + 1 numbers
        # (the lower of the two middle values when the count is even)
        self.running_median = []

    def stream(self, fname):
        """
        stream integers in from a file, one per line
        """
        with open(fname, 'r') as f:
            for line in f:
                yield int(line.strip())

    def get_median(self, fname):
        """
        main function, update the median as numbers stream in
        """
        # first initiate low and high from the first two numbers
        tmp = []
        for i, x in enumerate(self.stream(fname)):
            if i == 0:
                self.running_median.append(x)
            if i >= 2:
                break
            tmp.append(x)
        # with fewer than two numbers there is nothing more to update
        if len(tmp) < 2:
            return None
        # two heaps for the lower and higher half; low stores negated
        # values so that heapq's min-heap acts as a max-heap
        low, high = [-1 * min(tmp)], [max(tmp)]
        self.running_median.append(-1 * low[0])
        for i, x in enumerate(self.stream(fname)):
            if i < 2:
                continue
            # case 1, low and high have equal length
            if len(low) - len(high) == 0:
                if x < high[0]:
                    heappush(low, -1 * x)
                    self.running_median.append(-1 * low[0])
                else:
                    heappush(high, x)
                    self.running_median.append(high[0])
                continue
            # case 2, low has one more element than high
            if len(low) - len(high) == 1:
                if x < high[0]:
                    # push to low, then move its largest element to high
                    heappush(low, -1 * x)
                    tmp = -1 * heappop(low)
                    heappush(high, tmp)
                else:
                    heappush(high, x)
                self.running_median.append(-1 * low[0])
                continue
            # case 3, high has one more element than low
            if len(low) - len(high) == -1:
                if x < high[0]:
                    heappush(low, -1 * x)
                else:
                    # push to high, then move its smallest element to low
                    heappush(high, x)
                    tmp = heappop(high)
                    heappush(low, -1 * tmp)
                self.running_median.append(-1 * low[0])
                continue
            else:
                raise Exception('Error!')
        return None
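For experimenting without a data file, the same two-heap idea can also be sketched as a generator over any iterable. This is a compact variant written for illustration, not the implementation above: the name `running_lower_median` is hypothetical, and it uses a single generic rebalancing step instead of the three explicit cases. Like the class above, it reports the lower of the two middle values when the count is even.

```python
from heapq import heappush, heappop
import statistics


def running_lower_median(values):
    """Yield the (lower) median after each new value, in O(log N) per update."""
    low, high = [], []  # low: max-heap via negated values; high: min-heap
    for x in values:
        # Put x into the half it belongs to.
        if low and x > -low[0]:
            heappush(high, x)
        else:
            heappush(low, -x)
        # Rebalance so that len(low) is len(high) or len(high) + 1.
        if len(low) > len(high) + 1:
            heappush(high, -heappop(low))
        elif len(high) > len(low):
            heappush(low, -heappop(high))
        # The top of low is now the median (lower median for even counts).
        yield -low[0]


data = [5, 2, 8, 1, 9, 3]
medians = list(running_lower_median(data))
print(medians)  # [5, 2, 5, 2, 5, 3]
```

Keeping the smaller half at least as large as the larger half means the answer is always at the top of one heap, which avoids the case analysis at the cost of one extra push and pop on some updates.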